diff --git a/Makefile b/Makefile index 0afdc16..d646f65 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ all: install data train eval city_distances_full.csv: check generate_data.py @echo "Generating distance data..." - @bash -c 'time python generate_data.py' + @bash -c 'time python generate_data.py -w 8 -c US' data: city_distances_full.csv diff --git a/generate_data.py b/generate_data.py index b16f3e4..865d06d 100644 --- a/generate_data.py +++ b/generate_data.py @@ -1,3 +1,4 @@ +import argparse import concurrent.futures import csv import itertools @@ -10,12 +11,20 @@ from geopy.distance import geodesic MAX_DISTANCE = 20_037.5 +# Add argparse +parser = argparse.ArgumentParser() +parser.add_argument("-c", "--country", help="Specify the country code", type=str, default="US") +parser.add_argument("-w", "--workers", help="Specify the number of workers", type=int, default=1) +parser.add_argument("-o", "--output-file", help="Specify the name of the output file (file.csv)", type=str, default="city_distances_full.csv") +args = parser.parse_args() + + gc = geonamescache.GeonamesCache() cities = gc.get_cities() us_cities = { k: c for k, c in cities.items() - if (c.get("countrycode") == "US") # & (c.get("population", 0) > 5e4) + if (c.get("countrycode") == args.country) # & (c.get("population", 0) > 5e4) } @@ -92,13 +101,13 @@ def main(): print(f"Num cities: {len(cities)}") city_combinations = list(itertools.combinations(cities, 2)) - with open("city_distances_full.csv", "w", newline="") as csvfile: + with open(args.output_file, "w", newline="") as csvfile: fieldnames = ["city_from", "city_to", "distance"] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() try: - executor = concurrent.futures.ProcessPoolExecutor(max_workers=8) + executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) # results = executor.map(calculate_distance, city_combinations) futures = { executor.submit(calculate_distance, pair): pair