From 294d4bb1cd68bb6f91f7d2a7c0f4b8aca18655ad Mon Sep 17 00:00:00 2001 From: mm Date: Fri, 5 May 2023 06:42:14 +0000 Subject: [PATCH] shuffle option --- Makefile | 2 +- generate_data.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index e1d361f..4d410e0 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ all: install data train eval city_distances_full.csv: check generate_data.py @echo "Generating distance data..." - @bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 8000' + @bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200' data: city_distances_full.csv diff --git a/generate_data.py b/generate_data.py index 2fdb3da..d255a8e 100644 --- a/generate_data.py +++ b/generate_data.py @@ -34,6 +34,11 @@ parser.add_argument( type=str, default="city_distances_full.csv", ) +parser.add_argument( + "--shuffle", + action="store_true", + help="Option to shuffle combinations list before iterating over it", +) args = parser.parse_args() @@ -116,7 +121,8 @@ def main(): cities = list(us_cities.values()) print(f"Num cities: {len(cities)}") city_combinations = list(itertools.combinations(cities, 2)) - # np.random.shuffle(city_combinations) # will this help or hurt caching? 1.03it/s + if args.shuffle: + np.random.shuffle(city_combinations) chunk_size = args.chunk_size num_chunks = len(city_combinations) // chunk_size + 1 output_file = args.output_file