Compare commits
	
		
			8 Commits
		
	
	
		
			57ef4c06df
			...
			6bb4a64624
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 6bb4a64624 | |||
| 4500c1b483 | |||
| c111678bb8 | |||
| 4c36e07085 | |||
| 4e2e160072 | |||
| 9083e9d6e1 | |||
| 1e38ce04c9 | |||
| 96f0c3a313 | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -2,3 +2,4 @@ checkpoints* | ||||
| plots* | ||||
| *.csv | ||||
| output/ | ||||
| .requirements_installed | ||||
|  | ||||
							
								
								
									
										9
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								Makefile
									
									
									
									
									
								
							| @ -2,7 +2,7 @@ all: install data train eval | ||||
| 
 | ||||
| city_distances_full.csv: check generate_data.py | ||||
| 	@echo "Generating distance data..." | ||||
| 	@bash -c 'time python generate_data.py' | ||||
| 	@bash -c 'time python generate_data.py -w 8 -c US -s 10000' | ||||
| 
 | ||||
| data: city_distances_full.csv | ||||
| 
 | ||||
| @ -18,7 +18,7 @@ lint: | ||||
| 	@echo "Auto-linting files and performing final style checks..." | ||||
| 	@isort --profile=black . | ||||
| 	@black . | ||||
| 	@flake8 --max-line-length=88 . | ||||
| 	@flake8 --max-line-length=88 --ignore E203 . | ||||
| 
 | ||||
| check: lint | ||||
| 	@echo "Checking for unstaged or untracked changes..." | ||||
| @ -34,7 +34,10 @@ compress: plots/progress_136013_sm.png | ||||
| plots/progress_136013_sm.png: plots/progress_136013.png | ||||
| 	@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png | ||||
| 
 | ||||
| install: | ||||
| install: .requirements_installed | ||||
| 
 | ||||
| .requirements_installed: requirements.txt | ||||
| 	pip install -r requirements.txt | ||||
| 	@echo "installed requirements" > .requirements_installed | ||||
| 
 | ||||
| .PHONY: data train eval lint check clean all | ||||
| @ -1,3 +1,4 @@ | ||||
| import argparse | ||||
| import concurrent.futures | ||||
| import csv | ||||
| import itertools | ||||
| @ -7,15 +8,41 @@ from functools import lru_cache | ||||
| import geonamescache | ||||
| import numpy as np | ||||
| from geopy.distance import geodesic | ||||
| from tqdm import tqdm | ||||
| 
 | ||||
| MAX_DISTANCE = 20_037.5 | ||||
| 
 | ||||
| # Add argparse | ||||
| parser = argparse.ArgumentParser() | ||||
| parser.add_argument( | ||||
|     "-c", "--country", help="Specify the country code", type=str, default="US" | ||||
| ) | ||||
| parser.add_argument( | ||||
|     "-w", "--workers", help="Specify the number of workers", type=int, default=1 | ||||
| ) | ||||
| parser.add_argument( | ||||
|     "-s", | ||||
|     "--chunk-size", | ||||
|     help="Specify chunk size for batching calculations", | ||||
|     type=int, | ||||
|     default=1000, | ||||
| ) | ||||
| parser.add_argument( | ||||
|     "-o", | ||||
|     "--output-file", | ||||
|     help="Specify the name of the output file (file.csv)", | ||||
|     type=str, | ||||
|     default="city_distances_full.csv", | ||||
| ) | ||||
| args = parser.parse_args() | ||||
| 
 | ||||
| 
 | ||||
| gc = geonamescache.GeonamesCache() | ||||
| cities = gc.get_cities() | ||||
| us_cities = { | ||||
|     k: c | ||||
|     for k, c in cities.items() | ||||
|     if (c.get("countrycode") == "US")  # & (c.get("population", 0) > 5e4) | ||||
|     if (c.get("countrycode") == args.country)  # & (c.get("population", 0) > 5e4) | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| @ -38,11 +65,9 @@ def get_coordinates(city_name, country_code="US"): | ||||
|         or None if the city is not found. | ||||
|     """ | ||||
|     search_results = gc.search_cities(city_name, case_sensitive=True) | ||||
|     search_results = { | ||||
|         k: c | ||||
|         for k, c in search_results.items() | ||||
|         if (c.get("countrycode") == country_code) | ||||
|     } | ||||
|     search_results = [ | ||||
|         d for d in search_results if (d.get("countrycode") == country_code) | ||||
|     ] | ||||
| 
 | ||||
|     if not search_results: | ||||
|         return None | ||||
| @ -91,29 +116,39 @@ def main(): | ||||
|     cities = list(us_cities.values()) | ||||
|     print(f"Num cities: {len(cities)}") | ||||
|     city_combinations = list(itertools.combinations(cities, 2)) | ||||
|     chunk_size = args.chunk_size | ||||
|     num_chunks = len(city_combinations) // chunk_size + 1 | ||||
|     output_file = args.output_file | ||||
| 
 | ||||
|     with open("city_distances_full.csv", "w", newline="") as csvfile: | ||||
|     with open(output_file, "w", newline="") as csvfile: | ||||
|         fieldnames = ["city_from", "city_to", "distance"] | ||||
|         writer = csv.DictWriter(csvfile, fieldnames=fieldnames) | ||||
|         writer.writeheader() | ||||
| 
 | ||||
|         try: | ||||
|             executor = concurrent.futures.ProcessPoolExecutor(max_workers=8) | ||||
|             # results = executor.map(calculate_distance, city_combinations) | ||||
|             futures = { | ||||
|                 executor.submit(calculate_distance, pair): pair | ||||
|                 for pair in city_combinations | ||||
|             } | ||||
|             for future in as_completed(futures): | ||||
|                 city_from, city_to, distance = future.result() | ||||
|                 if distance is not None: | ||||
|                     writer.writerow( | ||||
|                         { | ||||
|                             "city_from": city_from, | ||||
|                             "city_to": city_to, | ||||
|                             "distance": distance, | ||||
|                         } | ||||
|                     ) | ||||
|             executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) | ||||
|             for i in tqdm( | ||||
|                 range(num_chunks), | ||||
|                 total=num_chunks, | ||||
|                 desc="Processing chunks", | ||||
|                 ncols=100, | ||||
|                 bar_format="{l_bar}{bar:30}{r_bar}", | ||||
|             ): | ||||
|                 chunk = city_combinations[(i * chunk_size) : (i + 1) * chunk_size] | ||||
|                 futures = { | ||||
|                     executor.submit(calculate_distance, pair): pair for pair in chunk | ||||
|                 } | ||||
|                 for future in as_completed(futures): | ||||
|                     city_from, city_to, distance = future.result() | ||||
|                     if distance is not None: | ||||
|                         writer.writerow( | ||||
|                             { | ||||
|                                 "city_from": city_from, | ||||
|                                 "city_to": city_to, | ||||
|                                 "distance": distance, | ||||
|                             } | ||||
|                         ) | ||||
|                         csvfile.flush()  # write to disk immediately | ||||
|         except KeyboardInterrupt: | ||||
|             print("Interrupted. Terminating processes...") | ||||
|             executor.shutdown(wait=False) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user