diff --git a/Makefile b/Makefile
index 13be5b8..0fa7ddc 100644
--- a/Makefile
+++ b/Makefile
@@ -1,8 +1,11 @@
 all: install data train eval
 
-city_distances.csv: check generate_data.py
+city_distances.csv: generate_data.py
 	@echo "Generating distance data..."
 	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
+	@echo "Calculating range of generated data..."
+	@cat city_distances.csv | tail -n +2 | sort -t',' -k3n | head -n1
+	@cat city_distances.csv | tail -n +2 | sort -t',' -k3nr | head -n1
 
 data: city_distances.csv
 
@@ -29,10 +32,10 @@ clean:
 	@rm -rf output/
 	@rm -rf checkpoints/
 
-compress: plots/progress_136013_sm.png
+compress: progress_sample.png
 
-plots/progress_136013_sm.png: plots/progress_136013.png
-	@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png
+progress_sample.png: plots/progress_12474.png
+	@convert -resize 33% plots/progress_12474.png progress_sample.png
 
 install: .requirements_installed
diff --git a/README.md b/README.md
index 80116bd..d877868 100644
--- a/README.md
+++ b/README.md
@@ -59,9 +59,9 @@ The approach demonstrated can be extended to other metrics or features beyond ge
 
 After training, the model should be able to understand the similarity between cities based on their geodesic distances. You can inspect the evaluation plots generated by the `eval.py` script to see the improvement in similarity scores before and after training.
 
-After one epoch, we can see the model has learned to correlate our desired quantities:
+After even just one epoch, we can see the model has learned to correlate our desired quantities:
 
-![Evaluation plot](./plots/progress_136013_sm.png)
+![Evaluation plot](./progress_sample.png)
 
 *The above plot is an example showing the relationship between geodesic distance and the similarity between the embedded vectors (1 = more similar), for 10,000 randomly selected pairs of US cities (re-sampled for each image).*
 
@@ -82,7 +82,7 @@ There are several potential improvements and extensions to the current model:
 
 # Notes
 
-- Generating the data took about 10-15 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8-cores (Intel 9700K), yielding 3,991,725 (combinations of cities) with size 150MB.
+- Generating the data took about 10 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8 cores (Intel 9700K), yielding 3,991,725 city pairs in a 150 MB file.
 - For cities with the same name, the one with the larger population is selected (had to make some sort of choice...).
 - Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16. At batch size 16 times larger, each epoch took about 5-6 minutes.
 - Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
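The two `sort` pipelines added to the `city_distances.csv` target print the rows with the smallest and largest distance as a quick range check on the generated data. For reference, a minimal pandas sketch of the same check (illustrative only, not part of the patch; it assumes the `city_from,city_to,distance` header that generate_data.py writes):

    import pandas as pd

    # Same check as the Makefile pipelines: show the rows with the
    # smallest and largest geodesic distance in the generated CSV.
    df = pd.read_csv("city_distances.csv")
    print(df.loc[df["distance"].idxmin()])  # closest pair of cities
    print(df.loc[df["distance"].idxmax()])  # farthest pair of cities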
diff --git a/generate_data.py b/generate_data.py
index a85e8b2..349f997 100644
--- a/generate_data.py
+++ b/generate_data.py
@@ -11,44 +11,39 @@ from geopy.distance import geodesic
 from tqdm import tqdm
 
 MAX_DISTANCE = 20_037.5
+CACHE = geonamescache.GeonamesCache()
+
 
 # Add argparse
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "-c", "--country", help="Specify the country code", type=str, default="US"
-)
-parser.add_argument(
-    "-w", "--workers", help="Specify the number of workers", type=int, default=1
-)
-parser.add_argument(
-    "-s",
-    "--chunk-size",
-    help="Specify chunk size for batching calculations",
-    type=int,
-    default=1000,
-)
-parser.add_argument(
-    "-o",
-    "--output-file",
-    help="Specify the name of the output file (file.csv)",
-    type=str,
-    default="city_distances.csv",
-)
-parser.add_argument(
-    "--shuffle",
-    action="store_true",
-    help="Option to shuffle combinations list before iterating over it",
-)
-args = parser.parse_args()
-
-
-gc = geonamescache.GeonamesCache()
-cities = gc.get_cities()
-us_cities = {
-    k: c
-    for k, c in cities.items()
-    if (c.get("countrycode") == args.country)  # & (c.get("population", 0) > 5e4)
-}
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c", "--country", help="Specify the country code", type=str, default="US"
+    )
+    parser.add_argument(
+        "-w", "--workers", help="Specify the number of workers", type=int, default=1
+    )
+    parser.add_argument(
+        "-s",
+        "--chunk-size",
+        help="Specify chunk size for batching calculations",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "-o",
+        "--output-file",
+        help="Specify the name of the output file (file.csv)",
+        type=str,
+        default="city_distances.csv",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="Option to shuffle combinations list before iterating over it",
+    )
+    args = parser.parse_args()
+    return args
 
 
 @lru_cache(maxsize=None)
@@ -69,16 +64,52 @@ def get_coordinates(city_name, country_code="US"):
         A tuple containing the latitude and longitude of the city, or None if
         the city is not found.
     """
-    search_results = gc.search_cities(city_name, case_sensitive=True)
+    city = find_city(city_name, country_code)
+    if city is None:
+        return None
+    return city.get("latitude"), city.get("longitude")
+
+
+@lru_cache(maxsize=None)
+def find_city(city_name, country_code="US"):
+    """
+    Finds the matching city.
+
+    Parameters
+    ----------
+    city_name : str
+        The name of the city.
+    country_code : str, optional
+        The country code of the city, by default 'US'.
+
+    Returns
+    -------
+    city
+        A dict containing the raw data about the city.
+ + """ + search_results = CACHE.get_cities_by_name(city_name) + # search_results = [ + # list(c.values())[0] for c in search_results + # ] + search_results = [inner_dict for d in search_results for inner_dict in d.values()] + if not search_results: # if not found by name, search alternatenames + search_results = CACHE.search_cities( + city_name, attribute="alternatenames", case_sensitive=True + ) + # filter search results to match requested country + # and avoid wasted computation if coordinates missing search_results = [ - d for d in search_results if (d.get("countrycode") == country_code) + d + for d in search_results + if (d.get("countrycode") == country_code) & (d.get("longitude") is not None) ] if not search_results: - return None, None + return None populations = [city.get("population") for city in search_results] city = search_results[np.argmax(populations)] - return city.get("latitude"), city.get("longitude") + return city def get_distance(city1, city2, country1="US", country2="US"): @@ -117,25 +148,39 @@ def calculate_distance(pair): return city1, city2, distance -def main(): +def main(args): + output_file = args.output_file + shuffle = args.shuffle + country_code = args.country + chunk_size = args.chunk_size + max_workers = args.workers + + cities = CACHE.get_cities() + us_cities = { + k: c + for k, c in cities.items() + if (c.get("countrycode") == country_code) & (c.get("longitude") is not None) + } + # & (c.get("population", 0) > 5e4) + cities = list(us_cities.values()) unique_names = set([c.get("name") for c in cities]) + unique_names = sorted(list(unique_names)) # unique_cities = [c for c in cities if c.get("name") in unique_names] print(f"Num cities: {len(cities)}, unique names: {len(unique_names)}") city_combinations = list(itertools.combinations(unique_names, 2)) - if args.shuffle: + if shuffle: np.random.shuffle(city_combinations) - chunk_size = args.chunk_size - num_chunks = len(city_combinations) // chunk_size + 1 - output_file = args.output_file + # chunk size, city_combinations, max_workers, output_file + num_chunks = len(city_combinations) // chunk_size + 1 with open(output_file, "w", newline="") as csvfile: fieldnames = ["city_from", "city_to", "distance"] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() try: - executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) + executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) for i in tqdm( range(num_chunks), total=num_chunks, @@ -163,6 +208,20 @@ def main(): executor.shutdown(wait=False) raise SystemExit("Execution terminated by user.") + print(f"Wrote {output_file}") + if __name__ == "__main__": - main() + # preliminary check + assert find_city("New York City") is not None + assert find_city("NYC") is not None + assert round(get_distance("NYC", "Jamaica"), 2) == 17.11 + args = parse_args() + main(args) + # perform check + print("Performing a quick validation...") + import pandas as pd + + df = pd.read_csv(args.output_file) + assert df["distance"].min() > 0 + assert df["distance"].max() < MAX_DISTANCE diff --git a/plots/progress_136013_sm.png b/plots/progress_136013_sm.png deleted file mode 100644 index b70a651..0000000 Binary files a/plots/progress_136013_sm.png and /dev/null differ diff --git a/progress_sample.png b/progress_sample.png new file mode 100644 index 0000000..22de9c4 Binary files /dev/null and b/progress_sample.png differ