diff --git a/Makefile b/Makefile index 13be5b8..0fa7ddc 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,11 @@ all: install data train eval -city_distances.csv: check generate_data.py +city_distances.csv: generate_data.py @echo "Generating distance data..." @bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200' + @echo "Calculating range of generated data..." + @cat city_distances.csv | tail -n +2 | sort -t',' -k3n | head -n1 + @cat city_distances.csv | tail -n +2 | sort -t',' -k3nr | head -n1 data: city_distances.csv @@ -29,10 +32,10 @@ clean: @rm -rf output/ @rm -rf checkpoints/ -compress: plots/progress_136013_sm.png +compress: plots/progress_sample.png -plots/progress_136013_sm.png: plots/progress_136013.png - @convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png +plots/progress_sample.png: plots/progress_12474.png + @convert -resize 33% plots/progress_12474.png plots/progress_sample.png install: .requirements_installed diff --git a/README.md b/README.md index 80116bd..ab56641 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ You can inspect the evaluation plots generated by the `eval.py` script to see th After one epoch, we can see the model has learned to correlate our desired quantities: -![Evaluation plot](./plots/progress_136013_sm.png) +![Evaluation plot](./plots/progress_sample.png) *The above plot is an example showing the relationship between geodesic distance and the similarity between the embedded vectors (1 = more similar), for 10,000 randomly selected pairs of US cities (re-sampled for each image).* @@ -82,7 +82,7 @@ There are several potential improvements and extensions to the current model: # Notes -- Generating the data took about 10-15 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8-cores (Intel 9700K), yielding 3,991,725 (combinations of cities) with size 150MB. 
+- Generating the data took about 10 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8-cores (Intel 9700K), yielding 3,991,725 (combinations of cities) with size 150MB. - For cities with the same name, the one with the larger population is selected (had to make some sort of choice...). - Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16. At batch size 16 times larger, each epoch took about 5-6 minutes. - Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each. diff --git a/generate_data.py b/generate_data.py index a85e8b2..ebef959 100644 --- a/generate_data.py +++ b/generate_data.py @@ -11,44 +11,39 @@ from geopy.distance import geodesic from tqdm import tqdm MAX_DISTANCE = 20_037.5 +CACHE = geonamescache.GeonamesCache() + # Add argparse -parser = argparse.ArgumentParser() -parser.add_argument( - "-c", "--country", help="Specify the country code", type=str, default="US" -) -parser.add_argument( - "-w", "--workers", help="Specify the number of workers", type=int, default=1 -) -parser.add_argument( - "-s", - "--chunk-size", - help="Specify chunk size for batching calculations", - type=int, - default=1000, -) -parser.add_argument( - "-o", - "--output-file", - help="Specify the name of the output file (file.csv)", - type=str, - default="city_distances.csv", -) -parser.add_argument( - "--shuffle", - action="store_true", - help="Option to shuffle combinations list before iterating over it", -) -args = parser.parse_args() - - -gc = geonamescache.GeonamesCache() -cities = gc.get_cities() -us_cities = { - k: c - for k, c in cities.items() - if (c.get("countrycode") == args.country) # & (c.get("population", 0) > 5e4) -} +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", "--country", help="Specify the country code", type=str, default="US" + ) + parser.add_argument( + "-w", "--workers", 
help="Specify the number of workers", type=int, default=1 + ) + parser.add_argument( + "-s", + "--chunk-size", + help="Specify chunk size for batching calculations", + type=int, + default=1000, + ) + parser.add_argument( + "-o", + "--output-file", + help="Specify the name of the output file (file.csv)", + type=str, + default="city_distances.csv", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="Option to shuffle combinations list before iterating over it", + ) + args = parser.parse_args() + return args @lru_cache(maxsize=None) @@ -69,16 +64,52 @@ def get_coordinates(city_name, country_code="US"): A tuple containing the latitude and longitude of the city, or None if the city is not found. """ - search_results = gc.search_cities(city_name, case_sensitive=True) + city = find_city(city_name, country_code) + if city is None: + return None + return city.get("latitude"), city.get("longitude") + + +@lru_cache(maxsize=None) +def find_city(city_name, country_code="US"): + """ + Finds the matching city. + + Parameters + ---------- + city_name : str + The name of the city. + country_code : str, optional + The country code of the city, by default 'US'. + + Returns + ------- + city + A dict containing the raw data about the city. 
+ + """ + search_results = CACHE.get_cities_by_name(city_name) + # search_results = [ + # list(c.values())[0] for c in search_results + # ] + search_results = [inner_dict for d in search_results for inner_dict in d.values()] + if not search_results: # if not found by name, search alternatenames + search_results = CACHE.search_cities( + city_name, attribute="alternatenames", case_sensitive=True + ) + # filter search results to match requested country + # and avoid wasted computation if coordinates missing search_results = [ - d for d in search_results if (d.get("countrycode") == country_code) + d + for d in search_results + if (d.get("countrycode") == country_code) & (d.get("longitude") is not None) ] if not search_results: - return None, None + return None populations = [city.get("population") for city in search_results] city = search_results[np.argmax(populations)] - return city.get("latitude"), city.get("longitude") + return city def get_distance(city1, city2, country1="US", country2="US"): @@ -117,25 +148,39 @@ def calculate_distance(pair): return city1, city2, distance -def main(): +def main(args): + output_file = args.output_file + shuffle = args.shuffle + country_code = args.country + chunk_size = args.chunk_size + max_workers = args.workers + + cities = CACHE.get_cities() + us_cities = { + k: c + for k, c in cities.items() + if (c.get("countrycode") == country_code) & (c.get("longitude") is not None) + } + # & (c.get("population", 0) > 5e4) + cities = list(us_cities.values()) unique_names = set([c.get("name") for c in cities]) + unique_names = sorted(list(unique_names)) # unique_cities = [c for c in cities if c.get("name") in unique_names] print(f"Num cities: {len(cities)}, unique names: {len(unique_names)}") city_combinations = list(itertools.combinations(unique_names, 2)) - if args.shuffle: + if shuffle: np.random.shuffle(city_combinations) - chunk_size = args.chunk_size - num_chunks = len(city_combinations) // chunk_size + 1 - output_file = 
args.output_file + # chunk size, city_combinations, max_workers, output_file + num_chunks = len(city_combinations) // chunk_size + 1 with open(output_file, "w", newline="") as csvfile: fieldnames = ["city_from", "city_to", "distance"] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() try: - executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) + executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) for i in tqdm( range(num_chunks), total=num_chunks, @@ -163,6 +208,20 @@ def main(): executor.shutdown(wait=False) raise SystemExit("Execution terminated by user.") + print(f"Wrote {output_file}") + if __name__ == "__main__": - main() + # preliminary check + assert find_city("New York City") is not None + assert find_city("NYC") is not None + assert round(get_distance("NYC", "Jamaica"), 2) == 17.11 + args = parse_args() + main(args) + # perform check + print("Performing a quick validation...") + import pandas as pd + + df = pd.read_csv(args.output_file) + assert df["distance"].min() > 0 + assert df["distance"].max() < MAX_DISTANCE diff --git a/plots/progress_136013_sm.png b/plots/progress_136013_sm.png deleted file mode 100644 index b70a651..0000000 Binary files a/plots/progress_136013_sm.png and /dev/null differ