fix combinatorial problem, cleanup
parent 9c439bb6c8
commit ab26735c82
@@ -82,7 +82,8 @@ There are several potential improvements and extensions to the current model:
 
 # Notes
 
-- Generating the data took about 13 minutes (for 3269 US cities) on 8 cores (Intel 9700K), yielding 2,720,278 records (combinations of cities).
-- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 train/test split and batch size 16, so there were 136,014 steps per epoch. At a batch size 16 times larger, each epoch took about 14 minutes.
+- Generating the data took about 10-15 minutes (for 3269 US cities, of which 2826 had unique names), in parallel on 8 cores (Intel 9700K), yielding 3,991,725 records (combinations of cities) at about 150 MB.
+- For cities with the same name, the one with the larger population is selected (some choice had to be made...).
+- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 train/test split and batch size 16. At a batch size 16 times larger, each epoch took about 5-6 minutes.
 - Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
 - **WARNING**: _It is unclear how the model performs on sentences, as it was trained and evaluated only on word-pairs._ See improvement (5) above.
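As a sanity check on the figures above: the new record count is exactly the number of unordered pairs of the 2826 unique names, and the removed steps-per-epoch figure is consistent with the old record count. A minimal arithmetic sketch, independent of the repo's code:

```python
from math import comb

# New dataset: unordered pairs of the 2826 unique city names.
print(comb(2826, 2))  # 3991725, matching the 3,991,725 records above

# Old README figure: steps per epoch at batch size 16 with an 80/20 split.
print(round(2_720_278 * 0.8 / 16))  # 136014, matching the removed bullet
```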
eval.py (2 changes)
@@ -68,7 +68,7 @@ if __name__ == "__main__":
 
     data = pd.read_csv("city_distances.csv")
     # data_sample = data.sample(1_000)
-    checkpoint_dir = "checkpoints_absmax_split"  # no slash
+    checkpoint_dir = "checkpoints"  # no slash
     for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
         print(f"Evaluating {checkpoint}")
         data_sample = data.sample(1_000)
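One caveat with this loop (an observation, not part of the diff): `sorted(glob.glob(...))` orders paths lexicographically, so a checkpoint for epoch 10 would be evaluated before epoch 2 unless filenames are zero-padded. A sketch of a numeric sort key, assuming hypothetical names like `epoch_2.pt` (the naming scheme is not confirmed by the diff):

```python
import glob
import re

def epoch_key(path: str) -> int:
    # Pull the first integer out of the path; -1 sorts unmatched paths first.
    match = re.search(r"(\d+)", path)
    return int(match.group(1)) if match else -1

# Visits epoch_2.pt before epoch_10.pt, unlike a plain lexicographic sorted().
for checkpoint in sorted(glob.glob("checkpoints/*"), key=epoch_key):
    print(f"Evaluating {checkpoint}")
```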
@@ -75,7 +75,7 @@ def get_coordinates(city_name, country_code="US"):
     ]
 
     if not search_results:
-        return None
+        return None, None
     populations = [city.get("population") for city in search_results]
     city = search_results[np.argmax(populations)]
     return city.get("latitude"), city.get("longitude")
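The `return None, None` change keeps the failure path unpackable: a caller doing `lat, lon = get_coordinates(...)` would raise `TypeError` on a bare `None`. A minimal illustration (the caller shown is hypothetical, not the repo's code):

```python
def lookup_old(name):
    return None        # failure path returned a single None

def lookup_new(name):
    return None, None  # failure path now matches the success shape (lat, lon)

# lat, lon = lookup_old("Nowhere")  # TypeError: cannot unpack non-iterable NoneType
lat, lon = lookup_new("Nowhere")    # lat and lon are both None, easy to check
print(lat, lon)
```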
@@ -113,14 +113,16 @@ def get_distance(city1, city2, country1="US", country2="US"):
 
 
 def calculate_distance(pair):
     city1, city2 = pair
-    distance = get_distance(city1["name"], city2["name"])
-    return city1["name"], city2["name"], distance
+    distance = get_distance(city1, city2)
+    return city1, city2, distance
 
 
 def main():
     cities = list(us_cities.values())
-    print(f"Num cities: {len(cities)}")
-    city_combinations = list(itertools.combinations(cities, 2))
+    unique_names = set([c.get("name") for c in cities])
+    # unique_cities = [c for c in cities if c.get("name") in unique_names]
+    print(f"Num cities: {len(cities)}, unique names: {len(unique_names)}")
+    city_combinations = list(itertools.combinations(unique_names, 2))
     if args.shuffle:
         np.random.shuffle(city_combinations)
     chunk_size = args.chunk_size
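This hunk is the "combinatorial problem" from the commit title: taking combinations over the full city records repeats name pairs whenever two records share a name, while combinations over the deduplicated name set do not. A small self-contained sketch with made-up records (the real `us_cities` has 3269 entries and 2826 unique names):

```python
import itertools

# Illustrative records; two share a name, as duplicate city names do in the data.
cities = [
    {"name": "Springfield", "population": 150_000},
    {"name": "Springfield", "population": 60_000},
    {"name": "Portland", "population": 650_000},
]

# Old behavior: C(3, 2) = 3 record pairs, including a Springfield/Springfield
# pair and two copies of the Springfield/Portland name pair.
record_pairs = list(itertools.combinations(cities, 2))
print(len(record_pairs))  # 3

# New behavior: pairs of unique names only, so a single Springfield/Portland pair.
unique_names = set(c["name"] for c in cities)
print(list(itertools.combinations(unique_names, 2)))
```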