Browse Source

add argparse

main
mm 2 years ago
parent
commit
96f0c3a313
  1. 2
      Makefile
  2. 15
      generate_data.py

2
Makefile

@ -2,7 +2,7 @@ all: install data train eval
city_distances_full.csv: check generate_data.py city_distances_full.csv: check generate_data.py
@echo "Generating distance data..." @echo "Generating distance data..."
@bash -c 'time python generate_data.py' @bash -c 'time python generate_data.py -w 8 -c US'
data: city_distances_full.csv data: city_distances_full.csv

15
generate_data.py

@ -1,3 +1,4 @@
import argparse
import concurrent.futures import concurrent.futures
import csv import csv
import itertools import itertools
@ -10,12 +11,20 @@ from geopy.distance import geodesic
MAX_DISTANCE = 20_037.5 MAX_DISTANCE = 20_037.5
# Add argparse
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--country", help="Specify the country code", type=str, default="US")
parser.add_argument("-w", "--workers", help="Specify the number of workers", type=int, default=1)
parser.add_argument("-o", "--output-file", help="Specify the name of the output file (file.csv)", type=str, default="city_distances_full.csv")
args = parser.parse_args()
gc = geonamescache.GeonamesCache() gc = geonamescache.GeonamesCache()
cities = gc.get_cities() cities = gc.get_cities()
us_cities = { us_cities = {
k: c k: c
for k, c in cities.items() for k, c in cities.items()
if (c.get("countrycode") == "US") # & (c.get("population", 0) > 5e4) if (c.get("countrycode") == args.country) # & (c.get("population", 0) > 5e4)
} }
@ -92,13 +101,13 @@ def main():
print(f"Num cities: {len(cities)}") print(f"Num cities: {len(cities)}")
city_combinations = list(itertools.combinations(cities, 2)) city_combinations = list(itertools.combinations(cities, 2))
with open("city_distances_full.csv", "w", newline="") as csvfile: with open(args.output_file, "w", newline="") as csvfile:
fieldnames = ["city_from", "city_to", "distance"] fieldnames = ["city_from", "city_to", "distance"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader() writer.writeheader()
try: try:
executor = concurrent.futures.ProcessPoolExecutor(max_workers=8) executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
# results = executor.map(calculate_distance, city_combinations) # results = executor.map(calculate_distance, city_combinations)
futures = { futures = {
executor.submit(calculate_distance, pair): pair executor.submit(calculate_distance, pair): pair

Loading…
Cancel
Save