diff --git a/Makefile b/Makefile
index e1d361f..4d410e0 100644
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@ all: install data train eval
 
 city_distances_full.csv: check generate_data.py
 	@echo "Generating distance data..."
-	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 8000'
+	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
 
 data: city_distances_full.csv
diff --git a/generate_data.py b/generate_data.py
index 2fdb3da..d255a8e 100644
--- a/generate_data.py
+++ b/generate_data.py
@@ -34,6 +34,11 @@ parser.add_argument(
     type=str,
     default="city_distances_full.csv",
 )
+parser.add_argument(
+    "--shuffle",
+    action="store_true",
+    help="Option to shuffle combinations list before iterating over it",
+)
 
 args = parser.parse_args()
 
@@ -116,7 +121,8 @@ def main():
     cities = list(us_cities.values())
     print(f"Num cities: {len(cities)}")
     city_combinations = list(itertools.combinations(cities, 2))
-    # np.random.shuffle(city_combinations) # will this help or hurt caching? 1.03it/s
+    if args.shuffle:
+        np.random.shuffle(city_combinations)
     chunk_size = args.chunk_size
     num_chunks = len(city_combinations) // chunk_size + 1
     output_file = args.output_file