add argparse
This commit is contained in:
parent
57ef4c06df
commit
96f0c3a313
2
Makefile
2
Makefile
@ -2,7 +2,7 @@ all: install data train eval
|
|||||||
|
|
||||||
city_distances_full.csv: check generate_data.py
|
city_distances_full.csv: check generate_data.py
|
||||||
@echo "Generating distance data..."
|
@echo "Generating distance data..."
|
||||||
@bash -c 'time python generate_data.py'
|
@bash -c 'time python generate_data.py -w 8 -c US'
|
||||||
|
|
||||||
data: city_distances_full.csv
|
data: city_distances_full.csv
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import argparse
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import csv
|
import csv
|
||||||
import itertools
|
import itertools
|
||||||
@ -10,12 +11,20 @@ from geopy.distance import geodesic
|
|||||||
|
|
||||||
MAX_DISTANCE = 20_037.5
|
MAX_DISTANCE = 20_037.5
|
||||||
|
|
||||||
|
# Add argparse
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-c", "--country", help="Specify the country code", type=str, default="US")
|
||||||
|
parser.add_argument("-w", "--workers", help="Specify the number of workers", type=int, default=1)
|
||||||
|
parser.add_argument("-o", "--output-file", help="Specify the name of the output file (file.csv)", type=str, default="city_distances_full.csv")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
gc = geonamescache.GeonamesCache()
|
gc = geonamescache.GeonamesCache()
|
||||||
cities = gc.get_cities()
|
cities = gc.get_cities()
|
||||||
us_cities = {
|
us_cities = {
|
||||||
k: c
|
k: c
|
||||||
for k, c in cities.items()
|
for k, c in cities.items()
|
||||||
if (c.get("countrycode") == "US") # & (c.get("population", 0) > 5e4)
|
if (c.get("countrycode") == args.country) # & (c.get("population", 0) > 5e4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -92,13 +101,13 @@ def main():
|
|||||||
print(f"Num cities: {len(cities)}")
|
print(f"Num cities: {len(cities)}")
|
||||||
city_combinations = list(itertools.combinations(cities, 2))
|
city_combinations = list(itertools.combinations(cities, 2))
|
||||||
|
|
||||||
with open("city_distances_full.csv", "w", newline="") as csvfile:
|
with open(args.output_file, "w", newline="") as csvfile:
|
||||||
fieldnames = ["city_from", "city_to", "distance"]
|
fieldnames = ["city_from", "city_to", "distance"]
|
||||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
|
executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
|
||||||
# results = executor.map(calculate_distance, city_combinations)
|
# results = executor.map(calculate_distance, city_combinations)
|
||||||
futures = {
|
futures = {
|
||||||
executor.submit(calculate_distance, pair): pair
|
executor.submit(calculate_distance, pair): pair
|
||||||
|
Loading…
Reference in New Issue
Block a user