diff --git a/Makefile b/Makefile
index 1b2d059..cb5d8a0 100644
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@ all: install data train eval
 
 city_distances_full.csv: check generate_data.py
 	@echo "Generating distance data..."
-	@bash -c 'time python generate_data.py -w 8 -c US'
+	@bash -c 'time python generate_data.py -w 8 -c US -s 10000'
 
 data: city_distances_full.csv
 
diff --git a/generate_data.py b/generate_data.py
index c0d7899..f580a6e 100644
--- a/generate_data.py
+++ b/generate_data.py
@@ -19,6 +19,13 @@ parser.add_argument(
 parser.add_argument(
     "-w", "--workers", help="Specify the number of workers", type=int, default=1
 )
+parser.add_argument(
+    "-s",
+    "--chunk-size",
+    help="Specify chunk size for batching calculations",
+    type=int,
+    default=1000,
+)
 parser.add_argument(
     "-o",
     "--output-file",
@@ -108,7 +115,7 @@ def main():
     cities = list(us_cities.values())
     print(f"Num cities: {len(cities)}")
     city_combinations = list(itertools.combinations(cities, 2))
-    chunk_size = 800 # adjust this as needed
+    chunk_size = args.chunk_size
     num_chunks = len(city_combinations) // chunk_size + 1
 
     output_file = args.output_file
@@ -120,6 +127,7 @@ def main():
     try:
         executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
         for i in range(num_chunks):
+            print(f"Processing chunk {i}...")
             chunk = city_combinations[(i * chunk_size) : (i + 1) * chunk_size]
             futures = {
                 executor.submit(calculate_distance, pair): pair for pair in chunk
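
For context, the pattern this diff parameterizes is chunked submission of city pairs to a ProcessPoolExecutor. The following is a minimal, self-contained sketch of that pattern only; calculate_distance and the toy "cities" are placeholders, not the real implementation in generate_data.py.

# Sketch of the chunked ProcessPoolExecutor pattern with a configurable chunk size.
# calculate_distance and the integer "cities" below are placeholders for illustration.
import concurrent.futures
import itertools

def calculate_distance(pair):
    # Placeholder for the real per-pair distance calculation.
    a, b = pair
    return abs(a - b)

def process_in_chunks(cities, chunk_size, workers):
    city_combinations = list(itertools.combinations(cities, 2))
    num_chunks = len(city_combinations) // chunk_size + 1
    results = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
        for i in range(num_chunks):
            print(f"Processing chunk {i}...")
            chunk = city_combinations[i * chunk_size : (i + 1) * chunk_size]
            futures = {executor.submit(calculate_distance, pair): pair for pair in chunk}
            for future in concurrent.futures.as_completed(futures):
                results.append((futures[future], future.result()))
    return results

if __name__ == "__main__":
    print(process_in_chunks(range(6), chunk_size=5, workers=2))

A larger -s/--chunk-size trades memory for fewer, larger batches; the Makefile change above simply passes -s 10000 instead of relying on the old hard-coded value of 800.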