Teaching a transformer to understand how far apart (common) cities are.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

154 lines
4.4 KiB

2 years ago
import argparse
import concurrent.futures
import csv
import itertools
from concurrent.futures import as_completed
from functools import lru_cache
import geonamescache
import numpy as np
from geopy.distance import geodesic
# Presumably the largest possible surface distance in km (half of Earth's
# ~40,075 km equatorial circumference); not referenced in this part of the file.
MAX_DISTANCE = 20_037.5


def _positive_int(value):
    """
    Argparse ``type`` callable that accepts only integers >= 1.

    Parameters
    ----------
    value : str
        Raw command-line string.

    Returns
    -------
    int
        The parsed integer.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not an integer or is smaller than 1.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-c", "--country", help="Specify the country code", type=str, default="US"
)
parser.add_argument(
    "-w",
    "--workers",
    help="Specify the number of workers",
    # Zero or negative would make ProcessPoolExecutor raise later.
    type=_positive_int,
    default=1,
)
parser.add_argument(
    "-s",
    "--chunk-size",
    help="Specify chunk size for batching calculations",
    # 0 would divide by zero and a negative size would silently yield no
    # chunks, so reject both up front.
    type=_positive_int,
    default=1000,
)
parser.add_argument(
    "-o",
    "--output-file",
    help="Specify the name of the output file (file.csv)",
    type=str,
    default="city_distances_full.csv",
)

args = parser.parse_args()

gc = geonamescache.GeonamesCache()
cities = gc.get_cities()

# Cities of the requested country, keyed the same way as gc.get_cities().
# NOTE: the name ``us_cities`` is historical; the filter uses ``--country``.
# To restrict to larger cities, extend the condition with e.g.
# ``and c.get("population", 0) > 5e4``.
us_cities = {
    k: c
    for k, c in cities.items()
    if c.get("countrycode") == args.country
}
@lru_cache(maxsize=None)
def get_coordinates(city_name, country_code="US"):
    """
    Get the coordinates of a city.

    When several cities in ``country_code`` carry the same name, the most
    populous one is used.

    Parameters
    ----------
    city_name : str
        The name of the city, matched exactly (case-sensitive) against a
        city's primary name or one of its alternate names.
    country_code : str, optional
        The country code of the city, by default 'US'.

    Returns
    -------
    tuple
        A tuple containing the latitude and longitude of the city,
        or None if the city is not found.
    """
    # geonamescache's search_cities performs a "contains" (substring) search,
    # so a query such as "York" can also return "New York". Keep only exact
    # name matches so the population tie-break below cannot pick a different,
    # larger city.
    candidates = [
        city
        for city in gc.search_cities(city_name, case_sensitive=True)
        if city.get("countrycode") == country_code
        and (
            city.get("name") == city_name
            or city_name in (city.get("alternatenames") or ())
        )
    ]
    if not candidates:
        return None
    # max() returns the first of equally populous candidates (as np.argmax
    # did); a missing/None population counts as 0 instead of raising.
    best = max(candidates, key=lambda city: city.get("population") or 0)
    return best.get("latitude"), best.get("longitude")
def get_distance(city1, city2, country1="US", country2="US"):
    """
    Compute the distance in kilometers between two named cities.

    Each name is resolved through ``get_coordinates`` within its own
    country, and the two points are measured with geopy's ``geodesic``.

    Parameters
    ----------
    city1, city2 : str
        Names of the two cities.
    country1, country2 : str, optional
        Country codes used to resolve ``city1`` and ``city2``; both
        default to 'US'.

    Returns
    -------
    float or None
        The distance in kilometers, or None when either name cannot be
        resolved.
    """
    # Both lookups run (city1 first) before checking for a miss.
    endpoints = (
        get_coordinates(city1, country1),
        get_coordinates(city2, country2),
    )
    if any(point is None for point in endpoints):
        return None
    origin, destination = endpoints
    return geodesic(origin, destination).km
def calculate_distance(pair):
    """
    Compute the distance for one pair of city records.

    This is the function ``main`` dispatches to the process pool.

    Parameters
    ----------
    pair : tuple of (dict, dict)
        Two city records (values of ``gc.get_cities()``); each needs a
        ``"name"`` key and may carry a ``"countrycode"`` key.

    Returns
    -------
    tuple
        ``(name_from, name_to, distance)``, where ``distance`` is in
        kilometers, or None if either city could not be resolved.
    """
    city1, city2 = pair
    # Resolve each name in the city's own country. Relying on get_distance's
    # 'US' defaults made every ``--country`` other than US look the names up
    # among US cities, yielding None or unrelated coordinates.
    distance = get_distance(
        city1["name"],
        city2["name"],
        city1.get("countrycode", "US"),
        city2.get("countrycode", "US"),
    )
    return city1["name"], city2["name"], distance
def main():
    """
    Write the distance between every unordered pair of selected cities to
    ``args.output_file`` as CSV (columns: city_from, city_to, distance).

    Pairs are generated lazily and processed in batches of
    ``args.chunk_size`` on a pool of ``args.workers`` processes. Pairs for
    which either city cannot be resolved (distance None) are skipped. On
    Ctrl-C the pool is shut down without waiting and the program exits.
    """
    cities = list(us_cities.values())
    print(f"Num cities: {len(cities)}")

    chunk_size = args.chunk_size
    if chunk_size < 1:
        raise SystemExit("--chunk-size must be a positive integer")

    # Generate pairs lazily: materializing all n*(n-1)/2 tuples up front
    # costs a lot of memory for countries with many cities.
    pairs = itertools.combinations(cities, 2)
    num_pairs = len(cities) * (len(cities) - 1) // 2
    # Ceiling division, so an exact multiple does not add an empty chunk.
    num_chunks = (num_pairs + chunk_size - 1) // chunk_size
    output_file = args.output_file

    with open(output_file, "w", newline="") as csvfile:
        fieldnames = ["city_from", "city_to", "distance"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
        try:
            for i in range(num_chunks):
                print(f"Processing chunk {i}...")
                chunk = list(itertools.islice(pairs, chunk_size))
                # Executor.map yields results in submission order, so the
                # output file is identical from run to run.
                for city_from, city_to, distance in executor.map(
                    calculate_distance, chunk
                ):
                    if distance is not None:
                        writer.writerow(
                            {
                                "city_from": city_from,
                                "city_to": city_to,
                                "distance": distance,
                            }
                        )
                # Push each finished chunk to disk so progress survives an
                # abort, without paying for a flush on every row.
                csvfile.flush()
        except KeyboardInterrupt:
            print("Interrupted. Terminating processes...")
            executor.shutdown(wait=False)
            raise SystemExit("Execution terminated by user.")
        else:
            # Normal completion: release the worker processes explicitly.
            executor.shutdown(wait=True)


if __name__ == "__main__":
    main()