Compare commits

...

2 Commits

Author SHA1 Message Date
mm
294d4bb1cd shuffle option 2023-05-05 06:42:32 +00:00
mm
b8ac59d942 progress bar 2023-05-05 06:37:35 +00:00
2 changed files with 16 additions and 3 deletions

View File

@@ -2,7 +2,7 @@ all: install data train eval
city_distances_full.csv: check generate_data.py
@echo "Generating distance data..."
-@bash -c 'time python generate_data.py -w 8 -c US -s 10000'
+@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
data: city_distances_full.csv

View File

@@ -8,6 +8,7 @@ from functools import lru_cache
import geonamescache
import numpy as np
from geopy.distance import geodesic
+from tqdm import tqdm
MAX_DISTANCE = 20_037.5
@@ -33,6 +34,11 @@ parser.add_argument(
type=str,
default="city_distances_full.csv",
)
+parser.add_argument(
+"--shuffle",
+action="store_true",
+help="Option to shuffle combinations list before iterating over it",
+)
args = parser.parse_args()
@@ -115,6 +121,8 @@ def main():
cities = list(us_cities.values())
print(f"Num cities: {len(cities)}")
city_combinations = list(itertools.combinations(cities, 2))
+if args.shuffle:
+np.random.shuffle(city_combinations)
chunk_size = args.chunk_size
num_chunks = len(city_combinations) // chunk_size + 1
output_file = args.output_file
@@ -126,8 +134,13 @@ def main():
try:
executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
-for i in range(num_chunks):
-print(f"Processing chunk {i}...")
+for i in tqdm(
+range(num_chunks),
+total=num_chunks,
+desc="Processing chunks",
+ncols=100,
+bar_format="{l_bar}{bar}{r_bar}",
+):
chunk = city_combinations[(i * chunk_size) : (i + 1) * chunk_size]
futures = {
executor.submit(calculate_distance, pair): pair for pair in chunk