diff --git a/README.md b/README.md index d99b5a9..80116bd 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,8 @@ There are several potential improvements and extensions to the current model: # Notes -- Generating the data took about 13 minutes (for 3269 US cities) on 8-cores (Intel 9700K), yielding 2,720,278 records (combinations of cities). -- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16, so there were 136,014 steps per epoch. At batch size 16 times larger, each epoch took about 14 minutes. +- Generating the data took about 10-15 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8-cores (Intel 9700K), yielding 3,991,725 (combinations of cities) with size 150MB. +- For cities with the same name, the one with the larger population is selected (had to make some sort of choice...). +- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16. At batch size 16 times larger, each epoch took about 5-6 minutes. - Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each. - **WARNING**: _It is unclear how the model performs on sentences, as it was trained and evaluated only on word-pairs._ See improvement (5) above. diff --git a/eval.py b/eval.py index 771df64..d7824ce 100644 --- a/eval.py +++ b/eval.py @@ -68,7 +68,7 @@ if __name__ == "__main__": data = pd.read_csv("city_distances.csv") # data_sample = data.sample(1_000) - checkpoint_dir = "checkpoints_absmax_split" # no slash + checkpoint_dir = "checkpoints" # no slash for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")): print(f"Evaluating {checkpoint}") data_sample = data.sample(1_000) diff --git a/generate_data.py b/generate_data.py index 5a0e0f1..a85e8b2 100644 --- a/generate_data.py +++ b/generate_data.py @@ -75,7 +75,7 @@ def get_coordinates(city_name, country_code="US"): ] if not search_results: - return None + return None, None populations = [city.get("population") for city in search_results] city = search_results[np.argmax(populations)] return city.get("latitude"), city.get("longitude") @@ -113,14 +113,16 @@ def get_distance(city1, city2, country1="US", country2="US"): def calculate_distance(pair): city1, city2 = pair - distance = get_distance(city1["name"], city2["name"]) - return city1["name"], city2["name"], distance + distance = get_distance(city1, city2) + return city1, city2, distance def main(): cities = list(us_cities.values()) - print(f"Num cities: {len(cities)}") - city_combinations = list(itertools.combinations(cities, 2)) + unique_names = set([c.get("name") for c in cities]) + # unique_cities = [c for c in cities if c.get("name") in unique_names] + print(f"Num cities: {len(cities)}, unique names: {len(unique_names)}") + city_combinations = list(itertools.combinations(unique_names, 2)) if args.shuffle: np.random.shuffle(city_combinations) chunk_size = args.chunk_size