Compare commits
No commits in common. "6bb4a64624a5c7c17675c899003826cd3172dc66" and "57ef4c06df05ed4a7d18784ec8b0d3250515b51d" have entirely different histories.
6bb4a64624
...
57ef4c06df
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,4 +2,3 @@ checkpoints*
|
||||
plots*
|
||||
*.csv
|
||||
output/
|
||||
.requirements_installed
|
||||
|
9
Makefile
9
Makefile
@ -2,7 +2,7 @@ all: install data train eval
|
||||
|
||||
city_distances_full.csv: check generate_data.py
|
||||
@echo "Generating distance data..."
|
||||
@bash -c 'time python generate_data.py -w 8 -c US -s 10000'
|
||||
@bash -c 'time python generate_data.py'
|
||||
|
||||
data: city_distances_full.csv
|
||||
|
||||
@ -18,7 +18,7 @@ lint:
|
||||
@echo "Auto-linting files and performing final style checks..."
|
||||
@isort --profile=black .
|
||||
@black .
|
||||
@flake8 --max-line-length=88 --ignore E203 .
|
||||
@flake8 --max-line-length=88 .
|
||||
|
||||
check: lint
|
||||
@echo "Checking for unstaged or untracked changes..."
|
||||
@ -34,10 +34,7 @@ compress: plots/progress_136013_sm.png
|
||||
plots/progress_136013_sm.png: plots/progress_136013.png
|
||||
@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png
|
||||
|
||||
install: .requirements_installed
|
||||
|
||||
.requirements_installed: requirements.txt
|
||||
install:
|
||||
pip install -r requirements.txt
|
||||
@echo "installed requirements" > .requirements_installed
|
||||
|
||||
.PHONY: data train eval lint check clean all
|
@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import csv
|
||||
import itertools
|
||||
@ -8,41 +7,15 @@ from functools import lru_cache
|
||||
import geonamescache
|
||||
import numpy as np
|
||||
from geopy.distance import geodesic
|
||||
from tqdm import tqdm
|
||||
|
||||
MAX_DISTANCE = 20_037.5
|
||||
|
||||
# Add argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-c", "--country", help="Specify the country code", type=str, default="US"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w", "--workers", help="Specify the number of workers", type=int, default=1
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--chunk-size",
|
||||
help="Specify chunk size for batching calculations",
|
||||
type=int,
|
||||
default=1000,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
help="Specify the name of the output file (file.csv)",
|
||||
type=str,
|
||||
default="city_distances_full.csv",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
gc = geonamescache.GeonamesCache()
|
||||
cities = gc.get_cities()
|
||||
us_cities = {
|
||||
k: c
|
||||
for k, c in cities.items()
|
||||
if (c.get("countrycode") == args.country) # & (c.get("population", 0) > 5e4)
|
||||
if (c.get("countrycode") == "US") # & (c.get("population", 0) > 5e4)
|
||||
}
|
||||
|
||||
|
||||
@ -65,9 +38,11 @@ def get_coordinates(city_name, country_code="US"):
|
||||
or None if the city is not found.
|
||||
"""
|
||||
search_results = gc.search_cities(city_name, case_sensitive=True)
|
||||
search_results = [
|
||||
d for d in search_results if (d.get("countrycode") == country_code)
|
||||
]
|
||||
search_results = {
|
||||
k: c
|
||||
for k, c in search_results.items()
|
||||
if (c.get("countrycode") == country_code)
|
||||
}
|
||||
|
||||
if not search_results:
|
||||
return None
|
||||
@ -116,27 +91,18 @@ def main():
|
||||
cities = list(us_cities.values())
|
||||
print(f"Num cities: {len(cities)}")
|
||||
city_combinations = list(itertools.combinations(cities, 2))
|
||||
chunk_size = args.chunk_size
|
||||
num_chunks = len(city_combinations) // chunk_size + 1
|
||||
output_file = args.output_file
|
||||
|
||||
with open(output_file, "w", newline="") as csvfile:
|
||||
with open("city_distances_full.csv", "w", newline="") as csvfile:
|
||||
fieldnames = ["city_from", "city_to", "distance"]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
try:
|
||||
executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
|
||||
for i in tqdm(
|
||||
range(num_chunks),
|
||||
total=num_chunks,
|
||||
desc="Processing chunks",
|
||||
ncols=100,
|
||||
bar_format="{l_bar}{bar:30}{r_bar}",
|
||||
):
|
||||
chunk = city_combinations[(i * chunk_size) : (i + 1) * chunk_size]
|
||||
executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
|
||||
# results = executor.map(calculate_distance, city_combinations)
|
||||
futures = {
|
||||
executor.submit(calculate_distance, pair): pair for pair in chunk
|
||||
executor.submit(calculate_distance, pair): pair
|
||||
for pair in city_combinations
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
city_from, city_to, distance = future.result()
|
||||
@ -148,7 +114,6 @@ def main():
|
||||
"distance": distance,
|
||||
}
|
||||
)
|
||||
csvfile.flush() # write to disk immediately
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted. Terminating processes...")
|
||||
executor.shutdown(wait=False)
|
||||
|
Loading…
Reference in New Issue
Block a user