Compare commits
	
		
			8 Commits
		
	
	
		
			57ef4c06df
			...
			6bb4a64624
		
	
| Author | SHA1 | Date | |
|---|---|---|---|
| | 6bb4a64624 | | |
| | 4500c1b483 | | |
| | c111678bb8 | | |
| | 4c36e07085 | | |
| | 4e2e160072 | | |
| | 9083e9d6e1 | | |
| | 1e38ce04c9 | | |
| | 96f0c3a313 | | |
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -1,4 +1,5 @@ | |||||||
| checkpoints* | checkpoints* | ||||||
| plots* | plots* | ||||||
| *.csv | *.csv | ||||||
| output/ | output/ | ||||||
|  | .requirements_installed | ||||||
|  | |||||||
							
								
								
									
										9
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								Makefile
									
									
									
									
									
								
							| @ -2,7 +2,7 @@ all: install data train eval | |||||||
| 
 | 
 | ||||||
| city_distances_full.csv: check generate_data.py | city_distances_full.csv: check generate_data.py | ||||||
| 	@echo "Generating distance data..." | 	@echo "Generating distance data..." | ||||||
| 	@bash -c 'time python generate_data.py' | 	@bash -c 'time python generate_data.py -w 8 -c US -s 10000' | ||||||
| 
 | 
 | ||||||
| data: city_distances_full.csv | data: city_distances_full.csv | ||||||
| 
 | 
 | ||||||
| @ -18,7 +18,7 @@ lint: | |||||||
| 	@echo "Auto-linting files and performing final style checks..." | 	@echo "Auto-linting files and performing final style checks..." | ||||||
| 	@isort --profile=black . | 	@isort --profile=black . | ||||||
| 	@black . | 	@black . | ||||||
| 	@flake8 --max-line-length=88 . | 	@flake8 --max-line-length=88 --ignore E203 . | ||||||
| 
 | 
 | ||||||
| check: lint | check: lint | ||||||
| 	@echo "Checking for unstaged or untracked changes..." | 	@echo "Checking for unstaged or untracked changes..." | ||||||
| @ -34,7 +34,10 @@ compress: plots/progress_136013_sm.png | |||||||
| plots/progress_136013_sm.png: plots/progress_136013.png | plots/progress_136013_sm.png: plots/progress_136013.png | ||||||
| 	@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png | 	@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png | ||||||
| 
 | 
 | ||||||
| install: | install: .requirements_installed | ||||||
|  | 
 | ||||||
|  | .requirements_installed: requirements.txt | ||||||
| 	pip install -r requirements.txt | 	pip install -r requirements.txt | ||||||
|  | 	@echo "installed requirements" > .requirements_installed | ||||||
| 
 | 
 | ||||||
| .PHONY: data train eval lint check clean all | .PHONY: data train eval lint check clean all | ||||||
| @ -1,3 +1,4 @@ | |||||||
|  | import argparse | ||||||
| import concurrent.futures | import concurrent.futures | ||||||
| import csv | import csv | ||||||
| import itertools | import itertools | ||||||
| @ -7,15 +8,41 @@ from functools import lru_cache | |||||||
| import geonamescache | import geonamescache | ||||||
| import numpy as np | import numpy as np | ||||||
| from geopy.distance import geodesic | from geopy.distance import geodesic | ||||||
|  | from tqdm import tqdm | ||||||
| 
 | 
 | ||||||
| MAX_DISTANCE = 20_037.5 | MAX_DISTANCE = 20_037.5 | ||||||
| 
 | 
 | ||||||
|  | # Add argparse | ||||||
|  | parser = argparse.ArgumentParser() | ||||||
|  | parser.add_argument( | ||||||
|  |     "-c", "--country", help="Specify the country code", type=str, default="US" | ||||||
|  | ) | ||||||
|  | parser.add_argument( | ||||||
|  |     "-w", "--workers", help="Specify the number of workers", type=int, default=1 | ||||||
|  | ) | ||||||
|  | parser.add_argument( | ||||||
|  |     "-s", | ||||||
|  |     "--chunk-size", | ||||||
|  |     help="Specify chunk size for batching calculations", | ||||||
|  |     type=int, | ||||||
|  |     default=1000, | ||||||
|  | ) | ||||||
|  | parser.add_argument( | ||||||
|  |     "-o", | ||||||
|  |     "--output-file", | ||||||
|  |     help="Specify the name of the output file (file.csv)", | ||||||
|  |     type=str, | ||||||
|  |     default="city_distances_full.csv", | ||||||
|  | ) | ||||||
|  | args = parser.parse_args() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| gc = geonamescache.GeonamesCache() | gc = geonamescache.GeonamesCache() | ||||||
| cities = gc.get_cities() | cities = gc.get_cities() | ||||||
| us_cities = { | us_cities = { | ||||||
|     k: c |     k: c | ||||||
|     for k, c in cities.items() |     for k, c in cities.items() | ||||||
|     if (c.get("countrycode") == "US")  # & (c.get("population", 0) > 5e4) |     if (c.get("countrycode") == args.country)  # & (c.get("population", 0) > 5e4) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -38,11 +65,9 @@ def get_coordinates(city_name, country_code="US"): | |||||||
|         or None if the city is not found. |         or None if the city is not found. | ||||||
|     """ |     """ | ||||||
|     search_results = gc.search_cities(city_name, case_sensitive=True) |     search_results = gc.search_cities(city_name, case_sensitive=True) | ||||||
|     search_results = { |     search_results = [ | ||||||
|         k: c |         d for d in search_results if (d.get("countrycode") == country_code) | ||||||
|         for k, c in search_results.items() |     ] | ||||||
|         if (c.get("countrycode") == country_code) |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     if not search_results: |     if not search_results: | ||||||
|         return None |         return None | ||||||
| @ -91,29 +116,39 @@ def main(): | |||||||
|     cities = list(us_cities.values()) |     cities = list(us_cities.values()) | ||||||
|     print(f"Num cities: {len(cities)}") |     print(f"Num cities: {len(cities)}") | ||||||
|     city_combinations = list(itertools.combinations(cities, 2)) |     city_combinations = list(itertools.combinations(cities, 2)) | ||||||
|  |     chunk_size = args.chunk_size | ||||||
|  |     num_chunks = len(city_combinations) // chunk_size + 1 | ||||||
|  |     output_file = args.output_file | ||||||
| 
 | 
 | ||||||
|     with open("city_distances_full.csv", "w", newline="") as csvfile: |     with open(output_file, "w", newline="") as csvfile: | ||||||
|         fieldnames = ["city_from", "city_to", "distance"] |         fieldnames = ["city_from", "city_to", "distance"] | ||||||
|         writer = csv.DictWriter(csvfile, fieldnames=fieldnames) |         writer = csv.DictWriter(csvfile, fieldnames=fieldnames) | ||||||
|         writer.writeheader() |         writer.writeheader() | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             executor = concurrent.futures.ProcessPoolExecutor(max_workers=8) |             executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) | ||||||
|             # results = executor.map(calculate_distance, city_combinations) |             for i in tqdm( | ||||||
|             futures = { |                 range(num_chunks), | ||||||
|                 executor.submit(calculate_distance, pair): pair |                 total=num_chunks, | ||||||
|                 for pair in city_combinations |                 desc="Processing chunks", | ||||||
|             } |                 ncols=100, | ||||||
|             for future in as_completed(futures): |                 bar_format="{l_bar}{bar:30}{r_bar}", | ||||||
|                 city_from, city_to, distance = future.result() |             ): | ||||||
|                 if distance is not None: |                 chunk = city_combinations[(i * chunk_size) : (i + 1) * chunk_size] | ||||||
|                     writer.writerow( |                 futures = { | ||||||
|                         { |                     executor.submit(calculate_distance, pair): pair for pair in chunk | ||||||
|                             "city_from": city_from, |                 } | ||||||
|                             "city_to": city_to, |                 for future in as_completed(futures): | ||||||
|                             "distance": distance, |                     city_from, city_to, distance = future.result() | ||||||
|                         } |                     if distance is not None: | ||||||
|                     ) |                         writer.writerow( | ||||||
|  |                             { | ||||||
|  |                                 "city_from": city_from, | ||||||
|  |                                 "city_to": city_to, | ||||||
|  |                                 "distance": distance, | ||||||
|  |                             } | ||||||
|  |                         ) | ||||||
|  |                         csvfile.flush()  # write to disk immediately | ||||||
|         except KeyboardInterrupt: |         except KeyboardInterrupt: | ||||||
|             print("Interrupted. Terminating processes...") |             print("Interrupted. Terminating processes...") | ||||||
|             executor.shutdown(wait=False) |             executor.shutdown(wait=False) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user