You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
6.7 KiB
227 lines
6.7 KiB
import argparse
|
|
import concurrent.futures
|
|
import csv
|
|
import itertools
|
|
from concurrent.futures import as_completed
|
|
from functools import lru_cache
|
|
|
|
import geonamescache
|
|
import numpy as np
|
|
from geopy.distance import geodesic
|
|
from tqdm import tqdm
|
|
|
|
# Upper bound (km) used to sanity-check written distances: half of Earth's
# equatorial circumference (~40,075 km / 2). No distance between two points
# on Earth's surface should reach it.
MAX_DISTANCE = 20037.5

# Single geonamescache instance shared by every lookup in this module.
CACHE = geonamescache.GeonamesCache()
|
|
|
|
|
|
def _positive_int(text):
    """Argparse ``type=`` converter that accepts only integers >= 1.

    A chunk size of 0 would divide by zero in ``main`` and
    ``ProcessPoolExecutor`` rejects ``max_workers < 1``, so both options are
    validated up front with a readable error instead of a traceback.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """
    Parse the command-line options for the pairwise-distance run.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; None (the default) parses ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        With attributes ``country``, ``workers``, ``chunk_size``,
        ``output_file`` and ``shuffle``.
    """
    parser = argparse.ArgumentParser(
        description="Compute distances between every pair of city names in a country."
    )
    parser.add_argument(
        "-c", "--country", help="Specify the country code", type=str, default="US"
    )
    parser.add_argument(
        "-w",
        "--workers",
        help="Specify the number of workers",
        type=_positive_int,
        default=1,
    )
    parser.add_argument(
        "-s",
        "--chunk-size",
        help="Specify chunk size for batching calculations",
        type=_positive_int,
        default=1000,
    )
    parser.add_argument(
        "-o",
        "--output-file",
        help="Specify the name of the output file (file.csv)",
        type=str,
        default="distances.csv",
    )
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="Option to shuffle combinations list before iterating over it",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_coordinates(city_name, country_code="US"):
    """
    Look up the position of a named city.

    Parameters
    ----------
    city_name : str
        Name to resolve; the lookup is delegated to ``find_city``.
    country_code : str, optional
        Country the city must belong to, by default 'US'.

    Returns
    -------
    tuple or None
        ``(latitude, longitude)`` of the resolved city, or None when
        ``find_city`` finds no matching city.
    """
    match = find_city(city_name, country_code)
    if match is not None:
        return match.get("latitude"), match.get("longitude")
    return None
|
|
|
|
|
|
@lru_cache(maxsize=None)
def find_city(city_name, country_code="US"):
    """
    Find the best-matching city record for a name within one country.

    Cities whose primary name equals ``city_name`` are considered first. Only
    when none of them is usable in ``country_code`` is a case-sensitive
    ``alternatenames`` search performed. Among the usable candidates, the most
    populous one is returned.

    Parameters
    ----------
    city_name : str
        The name of the city.
    country_code : str, optional
        The country code of the city, by default 'US'.

    Returns
    -------
    dict or None
        The raw geonamescache record of the chosen city, or None if no city
        in ``country_code`` with both coordinates matches the name. Results
        are cached by ``lru_cache``, so callers must not mutate the dict.
    """

    def _usable(candidates):
        # Keep only cities in the requested country that carry both
        # coordinates, so distance code never receives missing values.
        return [
            city
            for city in candidates
            if city.get("countrycode") == country_code
            and city.get("latitude") is not None
            and city.get("longitude") is not None
        ]

    # get_cities_by_name returns a list of {geonameid: record} mappings;
    # flatten it into a plain list of records.
    by_name = [
        record
        for mapping in CACHE.get_cities_by_name(city_name)
        for record in mapping.values()
    ]
    matches = _usable(by_name)
    if not matches:
        # The fallback is decided *after* the country filter: a name that only
        # matches cities in other countries must still try alternate names.
        matches = _usable(
            CACHE.search_cities(
                city_name, attribute="alternatenames", case_sensitive=True
            )
        )
    if not matches:
        return None
    # Prefer the most populous namesake; a missing population counts as 0.
    # max() returns the first of several equal maxima.
    return max(matches, key=lambda city: city.get("population") or 0)
|
|
|
|
|
|
def get_distance(city1, city2, country1="US", country2="US"):
    """
    Compute the geodesic distance between two named cities.

    Parameters
    ----------
    city1 : str
        Name of the first city.
    city2 : str
        Name of the second city.
    country1 : str, optional
        Country code used to resolve ``city1``, by default 'US'.
    country2 : str, optional
        Country code used to resolve ``city2``, by default 'US'.

    Returns
    -------
    float or None
        Distance in kilometers, or None when either name cannot be resolved
        to coordinates.
    """
    # Both lookups run before the check, first city1 then city2.
    endpoints = (get_coordinates(city1, country1), get_coordinates(city2, country2))
    if any(point is None for point in endpoints):
        return None
    return geodesic(*endpoints).km
|
|
|
|
|
|
def calculate_distance(pair, country_code="US"):
    """
    Compute the distance for one ``(city1, city2)`` name pair.

    Parameters
    ----------
    pair : tuple of (str, str)
        The two city names.
    country_code : str, optional
        Country in which both names are resolved, by default 'US'.

    Returns
    -------
    tuple
        ``(city1, city2, distance)``, where distance is in kilometers, or
        None if either name could not be resolved.
    """
    city1, city2 = pair
    return city1, city2, get_distance(city1, city2, country_code, country_code)
|
|
|
|
|
|
def main(args):
    """
    Write the distance between every pair of city names in one country to CSV.

    Parameters
    ----------
    args : argparse.Namespace
        Options produced by ``parse_args``: ``country``, ``workers``,
        ``chunk_size``, ``output_file`` and ``shuffle``.

    Raises
    ------
    SystemExit
        If the run is interrupted with Ctrl+C. Rows written before the
        interrupt remain in the output file.
    """
    output_file = args.output_file
    country_code = args.country
    chunk_size = args.chunk_size
    max_workers = args.workers

    # Cities of the requested country that have coordinates to measure from.
    country_cities = [
        city
        for city in CACHE.get_cities().values()
        if city.get("countrycode") == country_code
        and city.get("latitude") is not None
        and city.get("longitude") is not None
    ]
    # Pairs are built from *names*: find_city resolves each name to its most
    # populous namesake, so duplicate names would only repeat identical rows.
    unique_names = sorted({city.get("name") for city in country_cities})
    print(f"Num cities: {len(country_cities)}, unique names: {len(unique_names)}")

    city_combinations = list(itertools.combinations(unique_names, 2))
    if args.shuffle:
        np.random.shuffle(city_combinations)

    # Ceiling division, so an exact multiple of chunk_size does not yield a
    # trailing empty chunk (and zero pairs yields zero chunks).
    num_chunks = -(-len(city_combinations) // chunk_size)

    with open(output_file, "w", newline="") as csvfile:
        writer = csv.DictWriter(
            csvfile, fieldnames=["city_from", "city_to", "distance"]
        )
        writer.writeheader()

        executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
        futures = {}
        interrupted = False
        try:
            for i in tqdm(
                range(num_chunks),
                total=num_chunks,
                desc="Processing chunks",
                ncols=100,
                bar_format="{l_bar}{bar}{r_bar}",
            ):
                chunk = city_combinations[i * chunk_size : (i + 1) * chunk_size]
                # Pass the country code explicitly so both names are resolved
                # in the requested country rather than get_distance's "US"
                # default.
                futures = {
                    executor.submit(
                        get_distance, city_from, city_to, country_code, country_code
                    ): (city_from, city_to)
                    for city_from, city_to in chunk
                }
                for future in as_completed(futures):
                    distance = future.result()
                    if distance is None:
                        continue  # at least one name could not be resolved
                    city_from, city_to = futures[future]
                    writer.writerow(
                        {
                            "city_from": city_from,
                            "city_to": city_to,
                            "distance": distance,
                        }
                    )
                # Persist each finished chunk so an interrupted run keeps
                # its results.
                csvfile.flush()
        except KeyboardInterrupt:
            interrupted = True
            print("Interrupted. Terminating processes...")
            for future in futures:
                future.cancel()  # drop queued work that has not started
            raise SystemExit("Execution terminated by user.")
        finally:
            # Always release the worker processes; after Ctrl+C do not block
            # on tasks that are already running.
            executor.shutdown(wait=not interrupted)

    print(f"Wrote {output_file}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse options first so --help and invalid arguments return immediately,
    # before the slower lookup self-checks below run.
    args = parse_args()

    # Preliminary check: the name lookup (including the alternate-name
    # fallback) and the distance computation work before the long run starts.
    assert find_city("New York City") is not None
    assert find_city("NYC") is not None
    assert round(get_distance("NYC", "Jamaica"), 2) == 17.11

    main(args)

    # Perform check: every written distance must be positive and below
    # MAX_DISTANCE. An output with no rows (e.g. a country with fewer than
    # two named cities) has nothing to validate.
    print("Performing a quick validation...")
    with open(args.output_file, newline="") as result_file:
        distances = [float(row["distance"]) for row in csv.DictReader(result_file)]
    if distances:
        assert min(distances) > 0
        assert max(distances) < MAX_DISTANCE
|
|
|