Compare commits

...

8 Commits

Author  SHA1        Message             Date
mm      6bb4a64624  progress bar        2023-05-05 06:21:28 +00:00
mm      4500c1b483  chunking            2023-05-05 06:17:16 +00:00
mm      c111678bb8  install efficiency  2023-05-05 05:50:51 +00:00
mm      4c36e07085  bugfix              2023-05-05 05:45:01 +00:00
mm      4e2e160072  linting             2023-05-05 05:43:15 +00:00
mm      9083e9d6e1  chunk data gen      2023-05-05 05:41:52 +00:00
mm      1e38ce04c9  linting             2023-05-05 03:41:59 +00:00
mm      96f0c3a313  add argparse        2023-05-05 03:41:51 +00:00
3 changed files with 66 additions and 27 deletions

.gitignore

@@ -2,3 +2,4 @@ checkpoints*
 plots*
 *.csv
 output/
+.requirements_installed

Makefile

@@ -2,7 +2,7 @@ all: install data train eval

 city_distances_full.csv: check generate_data.py
 	@echo "Generating distance data..."
-	@bash -c 'time python generate_data.py'
+	@bash -c 'time python generate_data.py -w 8 -c US -s 10000'

 data: city_distances_full.csv

@@ -18,7 +18,7 @@ lint:
 	@echo "Auto-linting files and performing final style checks..."
 	@isort --profile=black .
 	@black .
-	@flake8 --max-line-length=88 .
+	@flake8 --max-line-length=88 --ignore E203 .

 check: lint
 	@echo "Checking for unstaged or untracked changes..."
@@ -34,7 +34,10 @@ compress: plots/progress_136013_sm.png
 plots/progress_136013_sm.png: plots/progress_136013.png
 	@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png

-install:
+install: .requirements_installed
+
+.requirements_installed: requirements.txt
 	pip install -r requirements.txt
+	@echo "installed requirements" > .requirements_installed

 .PHONY: data train eval lint check clean all
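
The install change is the "install efficiency" commit: a sentinel file (.requirements_installed, also newly git-ignored above) records when pip install last ran, and make re-runs the rule only when requirements.txt is newer than the sentinel. A minimal Python sketch of the same timestamp check, purely illustrative since the repo relies on make itself:

# Sketch of the sentinel-file idiom the Makefile now uses: reinstall only
# when requirements.txt has changed since the last recorded install.
import os
import subprocess

SENTINEL = ".requirements_installed"

def install_requirements(requirements: str = "requirements.txt") -> None:
    # make treats a target as up to date when it is newer than its prerequisites
    if os.path.exists(SENTINEL) and os.path.getmtime(SENTINEL) >= os.path.getmtime(
        requirements
    ):
        return  # skip, mirroring make skipping the rule
    subprocess.run(["pip", "install", "-r", requirements], check=True)
    with open(SENTINEL, "w") as f:
        f.write("installed requirements\n")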

generate_data.py

@@ -1,3 +1,4 @@
+import argparse
 import concurrent.futures
 import csv
 import itertools
@@ -7,15 +8,41 @@ from functools import lru_cache
 import geonamescache
 import numpy as np
 from geopy.distance import geodesic
+from tqdm import tqdm

 MAX_DISTANCE = 20_037.5

+# Add argparse
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "-c", "--country", help="Specify the country code", type=str, default="US"
+)
+parser.add_argument(
+    "-w", "--workers", help="Specify the number of workers", type=int, default=1
+)
+parser.add_argument(
+    "-s",
+    "--chunk-size",
+    help="Specify chunk size for batching calculations",
+    type=int,
+    default=1000,
+)
+parser.add_argument(
+    "-o",
+    "--output-file",
+    help="Specify the name of the output file (file.csv)",
+    type=str,
+    default="city_distances_full.csv",
+)
+args = parser.parse_args()
+
 gc = geonamescache.GeonamesCache()
 cities = gc.get_cities()
 us_cities = {
     k: c
     for k, c in cities.items()
-    if (c.get("countrycode") == "US")  # & (c.get("population", 0) > 5e4)
+    if (c.get("countrycode") == args.country)  # & (c.get("population", 0) > 5e4)
 }
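
These flags match the updated Makefile invocation (python generate_data.py -w 8 -c US -s 10000). A quick sketch of how that invocation resolves, assuming the parser defined in the diff above is in scope:

# How the Makefile's invocation resolves against the parser above.
args = parser.parse_args(["-w", "8", "-c", "US", "-s", "10000"])
assert args.workers == 8          # -w is converted to int
assert args.country == "US"
assert args.chunk_size == 10000   # --chunk-size becomes args.chunk_size
assert args.output_file == "city_distances_full.csv"  # default is kept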
@@ -38,11 +65,9 @@ def get_coordinates(city_name, country_code="US"):
     or None if the city is not found.
     """
     search_results = gc.search_cities(city_name, case_sensitive=True)
-    search_results = {
-        k: c
-        for k, c in search_results.items()
-        if (c.get("countrycode") == country_code)
-    }
+    search_results = [
+        d for d in search_results if (d.get("countrycode") == country_code)
+    ]

     if not search_results:
         return None
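
This looks like the "bugfix" commit: geonamescache's search_cities returns a list of city dicts, so the old dict comprehension over .items() would raise an AttributeError; the new list comprehension filters the list directly. For example, with hypothetical sample data:

# Hypothetical sample of what gc.search_cities(...) returns:
search_results = [
    {"name": "Springfield", "countrycode": "US"},
    {"name": "Springfield", "countrycode": "CA"},
]
country_code = "US"
search_results = [d for d in search_results if d.get("countrycode") == country_code]
print(search_results)  # [{'name': 'Springfield', 'countrycode': 'US'}]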
@@ -91,18 +116,27 @@ def main():
     cities = list(us_cities.values())
     print(f"Num cities: {len(cities)}")
     city_combinations = list(itertools.combinations(cities, 2))
+    chunk_size = args.chunk_size
+    num_chunks = len(city_combinations) // chunk_size + 1
+    output_file = args.output_file

-    with open("city_distances_full.csv", "w", newline="") as csvfile:
+    with open(output_file, "w", newline="") as csvfile:
         fieldnames = ["city_from", "city_to", "distance"]
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
         writer.writeheader()
         try:
-            executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
-            # results = executor.map(calculate_distance, city_combinations)
-            futures = {
-                executor.submit(calculate_distance, pair): pair
-                for pair in city_combinations
-            }
+            executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
+            for i in tqdm(
+                range(num_chunks),
+                total=num_chunks,
+                desc="Processing chunks",
+                ncols=100,
+                bar_format="{l_bar}{bar:30}{r_bar}",
+            ):
+                chunk = city_combinations[(i * chunk_size) : (i + 1) * chunk_size]
+                futures = {
+                    executor.submit(calculate_distance, pair): pair for pair in chunk
+                }
             for future in as_completed(futures):
                 city_from, city_to, distance = future.result()
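
This is the heart of the "chunking" commit: rather than submitting every city pair to the pool at once and holding one Future per pair in memory, the loop submits and drains one chunk at a time, with tqdm reporting per-chunk progress. A self-contained sketch of the same submit-and-drain pattern, using a hypothetical stand-in for calculate_distance:

# Standalone sketch of the chunked submit-and-drain pattern.
# `work` is a hypothetical stand-in for calculate_distance.
import concurrent.futures
from concurrent.futures import as_completed

def work(pair):
    a, b = pair
    return a + b

def run_chunked(pairs, chunk_size=1000, max_workers=4):
    num_chunks = len(pairs) // chunk_size + 1
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        for i in range(num_chunks):
            chunk = pairs[i * chunk_size : (i + 1) * chunk_size]
            # at most chunk_size futures are alive at any one time
            futures = [executor.submit(work, pair) for pair in chunk]
            for future in as_completed(futures):
                yield future.result()

if __name__ == "__main__":  # guard required by ProcessPoolExecutor on spawn platforms
    pairs = [(i, i + 1) for i in range(5000)]
    print(sum(run_chunked(pairs)))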
@@ -114,6 +148,7 @@ def main():
                             "distance": distance,
                         }
                     )
+                    csvfile.flush()  # write to disk immediately
         except KeyboardInterrupt:
             print("Interrupted. Terminating processes...")
             executor.shutdown(wait=False)
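
The per-row csvfile.flush() pushes Python's userspace buffer to the OS after every write, so rows computed before a KeyboardInterrupt are not stranded in a partially filled buffer. If durability across an OS crash also mattered, an fsync could be added as well; this is an extension, not part of the diff:

import os

csvfile.flush()             # what the diff adds: hand the buffer to the OS
os.fsync(csvfile.fileno())  # assumption/extension: force the OS cache to disk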