fix combinatorial problem, cleanup

This commit is contained in:
mm 2023-05-05 07:50:22 +00:00
parent 9c439bb6c8
commit ab26735c82
3 changed files with 11 additions and 8 deletions

View File

@ -82,7 +82,8 @@ There are several potential improvements and extensions to the current model:
# Notes
- Generating the data took about 13 minutes (for 3269 US cities) on 8-cores (Intel 9700K), yielding 2,720,278 records (combinations of cities).
- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16, so there were 136,014 steps per epoch. At batch size 16 times larger, each epoch took about 14 minutes.
- Generating the data took about 10-15 minutes (for 3269 US cities, with 2826 unique names among them), running in parallel on 8 cores (Intel 9700K), and yielded 3,991,725 records (combinations of cities) totalling 150MB.
- For cities that share the same name, the one with the largest population is selected (some choice had to be made...).
- Training on an Nvidia 3090 FE took about an hour per epoch with an 80/20 test/train split and a batch size of 16. With a batch size 16 times larger (256), each epoch took about 5-6 minutes.
- Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
- **WARNING**: _It is unclear how the model performs on sentences, as it was trained and evaluated only on word-pairs._ See improvement (5) above.

View File

@ -68,7 +68,7 @@ if __name__ == "__main__":
data = pd.read_csv("city_distances.csv")
# data_sample = data.sample(1_000)
checkpoint_dir = "checkpoints_absmax_split" # no slash
checkpoint_dir = "checkpoints" # no slash
for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
print(f"Evaluating {checkpoint}")
data_sample = data.sample(1_000)

View File

@ -75,7 +75,7 @@ def get_coordinates(city_name, country_code="US"):
]
if not search_results:
return None
return None, None
populations = [city.get("population") for city in search_results]
city = search_results[np.argmax(populations)]
return city.get("latitude"), city.get("longitude")
@ -113,14 +113,16 @@ def get_distance(city1, city2, country1="US", country2="US"):
def calculate_distance(pair):
city1, city2 = pair
distance = get_distance(city1["name"], city2["name"])
return city1["name"], city2["name"], distance
distance = get_distance(city1, city2)
return city1, city2, distance
def main():
cities = list(us_cities.values())
print(f"Num cities: {len(cities)}")
city_combinations = list(itertools.combinations(cities, 2))
unique_names = set([c.get("name") for c in cities])
# unique_cities = [c for c in cities if c.get("name") in unique_names]
print(f"Num cities: {len(cities)}, unique names: {len(unique_names)}")
city_combinations = list(itertools.combinations(unique_names, 2))
if args.shuffle:
np.random.shuffle(city_combinations)
chunk_size = args.chunk_size