
fix bug in city lookups

main · mm · 2 years ago · commit 6c5d71e2d9
  1. Makefile (11 lines changed)
  2. README.md (6 lines changed)
  3. generate_data.py (153 lines changed)
  4. plots/progress_136013_sm.png (BIN)
  5. progress_sample.png (BIN)

Makefile (11 lines changed)

@@ -1,8 +1,11 @@
 all: install data train eval
 
-city_distances.csv: check generate_data.py
+city_distances.csv: generate_data.py
 	@echo "Generating distance data..."
 	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
+	@echo "Calculating range of generated data..."
+	@cat city_distances.csv | tail -n +2 | sort -t',' -k3n | head -n1
+	@cat city_distances.csv | tail -n +2 | sort -t',' -k3nr | head -n1
 
 data: city_distances.csv
@@ -29,10 +32,10 @@ clean:
 	@rm -rf output/
 	@rm -rf checkpoints/
 
-compress: plots/progress_136013_sm.png
+compress: plots/progress_12474_sm.png
 
-plots/progress_136013_sm.png: plots/progress_136013.png
-	@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png
+plots/progress_12474_sm.png: plots/progress_12474.png
+	@convert -resize 33% plots/progress_12474.png progress_sample.png
 
 install: .requirements_installed
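
The two `@cat` recipe lines added above are a quick sanity check on the generated CSV: the first prints the row with the smallest distance, the second the row with the largest. A minimal Python equivalent (assuming `city_distances.csv` has a header row and the distance in its third column):

```python
import csv

# Mirror of the Makefile check: find the closest and farthest city pairs.
with open("city_distances.csv") as f:
    rows = list(csv.reader(f))[1:]        # tail -n +2: skip the header

rows.sort(key=lambda r: float(r[2]))      # sort -t',' -k3n: numeric sort on column 3
print("min:", rows[0])                    # head -n1 of the ascending sort
print("max:", rows[-1])                   # head -n1 of the descending sort (-k3nr)
```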

README.md (6 lines changed)

@@ -59,9 +59,9 @@ The approach demonstrated can be extended to other metrics or features beyond ge
 After training, the model should be able to understand the similarity between cities based on their geodesic distances.
 You can inspect the evaluation plots generated by the `eval.py` script to see the improvement in similarity scores before and after training.
 
-After one epoch, we can see the model has learned to correlate our desired quantities:
+After even just one epoch, we can see the model has learned to correlate our desired quantities:
 
-![Evaluation plot](./plots/progress_136013_sm.png)
+![Evaluation plot](./progress_sample.png)
 
 *The above plot is an example showing the relationship between geodesic distance and the similarity between the embedded vectors (1 = more similar), for 10,000 randomly selected pairs of US cities (re-sampled for each image).*
@@ -82,7 +82,7 @@ There are several potential improvements and extensions to the current model:
 # Notes
 
-- Generating the data took about 10-15 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8-cores (Intel 9700K), yielding 3,991,725 (combinations of cities) with size 150MB.
+- Generating the data took about 10 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8-cores (Intel 9700K), yielding 3,991,725 (combinations of cities) with size 150MB.
 - For cities with the same name, the one with the larger population is selected (had to make some sort of choice...).
 - Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16. At batch size 16 times larger, each epoch took about 5-6 minutes.
 - Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
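
The pair count in the first note is straightforward to verify: 2826 unique city names yield "2826 choose 2" unordered combinations, matching the 3,991,725 figure:

```python
import math

# 2826 unique names -> 2826 * 2825 / 2 unordered city pairs
assert math.comb(2826, 2) == 3_991_725
```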

generate_data.py (153 lines changed)

@@ -11,44 +11,39 @@ from geopy.distance import geodesic
 from tqdm import tqdm
 
+MAX_DISTANCE = 20_037.5
+CACHE = geonamescache.GeonamesCache()
+
-# Add argparse
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "-c", "--country", help="Specify the country code", type=str, default="US"
-)
-parser.add_argument(
-    "-w", "--workers", help="Specify the number of workers", type=int, default=1
-)
-parser.add_argument(
-    "-s",
-    "--chunk-size",
-    help="Specify chunk size for batching calculations",
-    type=int,
-    default=1000,
-)
-parser.add_argument(
-    "-o",
-    "--output-file",
-    help="Specify the name of the output file (file.csv)",
-    type=str,
-    default="city_distances.csv",
-)
-parser.add_argument(
-    "--shuffle",
-    action="store_true",
-    help="Option to shuffle combinations list before iterating over it",
-)
-args = parser.parse_args()
-
-gc = geonamescache.GeonamesCache()
-cities = gc.get_cities()
-us_cities = {
-    k: c
-    for k, c in cities.items()
-    if (c.get("countrycode") == args.country)  # & (c.get("population", 0) > 5e4)
-}
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c", "--country", help="Specify the country code", type=str, default="US"
+    )
+    parser.add_argument(
+        "-w", "--workers", help="Specify the number of workers", type=int, default=1
+    )
+    parser.add_argument(
+        "-s",
+        "--chunk-size",
+        help="Specify chunk size for batching calculations",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "-o",
+        "--output-file",
+        help="Specify the name of the output file (file.csv)",
+        type=str,
+        default="distances.csv",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="Option to shuffle combinations list before iterating over it",
+    )
+    args = parser.parse_args()
+    return args
 
 
 @lru_cache(maxsize=None)
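
One side effect of this refactor worth noting: the `--output-file` default changes from `city_distances.csv` to `distances.csv`, and argument parsing no longer runs at import time. A minimal sketch (assuming `generate_data.py` is importable from the working directory) exercising the new defaults:

```python
import sys

sys.argv = ["generate_data.py"]          # simulate a run with no CLI flags

from generate_data import parse_args     # parsing no longer happens on import

args = parse_args()
assert args.output_file == "distances.csv"
assert args.country == "US" and args.workers == 1 and args.chunk_size == 1000
```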
@@ -69,16 +64,52 @@ def get_coordinates(city_name, country_code="US"):
         A tuple containing the latitude and longitude of the city,
         or None if the city is not found.
     """
-    search_results = gc.search_cities(city_name, case_sensitive=True)
+    city = find_city(city_name, country_code)
+    if city is None:
+        return None
+    return city.get("latitude"), city.get("longitude")
+
+
+@lru_cache(maxsize=None)
+def find_city(city_name, country_code="US"):
+    """
+    Finds the matching city.
+
+    Parameters
+    ----------
+    city_name : str
+        The name of the city.
+    country_code : str, optional
+        The country code of the city, by default 'US'.
+
+    Returns
+    -------
+    city
+        A dict containing the raw data about the city.
+    """
+    search_results = CACHE.get_cities_by_name(city_name)
+    # search_results = [
+    #     list(c.values())[0] for c in search_results
+    # ]
+    search_results = [inner_dict for d in search_results for inner_dict in d.values()]
+    if not search_results:  # if not found by name, search alternatenames
+        search_results = CACHE.search_cities(
+            city_name, attribute="alternatenames", case_sensitive=True
+        )
     # filter search results to match requested country
+    # and avoid wasted computation if coordinates missing
     search_results = [
-        d for d in search_results if (d.get("countrycode") == country_code)
+        d
+        for d in search_results
+        if (d.get("countrycode") == country_code) & (d.get("longitude") is not None)
     ]
     if not search_results:
-        return None, None
+        return None
     populations = [city.get("population") for city in search_results]
     city = search_results[np.argmax(populations)]
-    return city.get("latitude"), city.get("longitude")
+    return city
 
 
 def get_distance(city1, city2, country1="US", country2="US"):
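
This is the heart of the lookup fix: `find_city` first tries geonamescache's exact-name index and only then falls back to a case-sensitive search over `alternatenames`, which is how an alias like "NYC" resolves. A standalone sketch of that two-stage lookup (assuming `geonamescache` is installed; `get_cities_by_name` returns a list of `{geonameid: city_dict}` mappings):

```python
import geonamescache

cache = geonamescache.GeonamesCache()

def lookup(name):
    # Stage 1: exact-name match, flattening the {geonameid: dict} entries.
    results = cache.get_cities_by_name(name)
    flat = [city for entry in results for city in entry.values()]
    if not flat:
        # Stage 2: fall back to alternate names (case-sensitive).
        flat = cache.search_cities(
            name, attribute="alternatenames", case_sensitive=True
        )
    return flat

print(len(lookup("New York City")))  # found by its primary name
print(len(lookup("NYC")))            # found only via alternatenames
```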
@@ -117,25 +148,39 @@ def calculate_distance(pair):
     return city1, city2, distance
 
 
-def main():
+def main(args):
+    output_file = args.output_file
+    shuffle = args.shuffle
+    country_code = args.country
+    chunk_size = args.chunk_size
+    max_workers = args.workers
+
+    cities = CACHE.get_cities()
+    us_cities = {
+        k: c
+        for k, c in cities.items()
+        if (c.get("countrycode") == country_code) & (c.get("longitude") is not None)
+    }
+    # & (c.get("population", 0) > 5e4)
     cities = list(us_cities.values())
     unique_names = set([c.get("name") for c in cities])
     unique_names = sorted(list(unique_names))
     # unique_cities = [c for c in cities if c.get("name") in unique_names]
     print(f"Num cities: {len(cities)}, unique names: {len(unique_names)}")
     city_combinations = list(itertools.combinations(unique_names, 2))
-    if args.shuffle:
+    if shuffle:
         np.random.shuffle(city_combinations)
-    chunk_size = args.chunk_size
-    num_chunks = len(city_combinations) // chunk_size + 1
-    output_file = args.output_file
+    # chunk size, city_combinations, max_workers, output_file
+    num_chunks = len(city_combinations) // chunk_size + 1
     with open(output_file, "w", newline="") as csvfile:
         fieldnames = ["city_from", "city_to", "distance"]
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
         writer.writeheader()
         try:
-            executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
+            executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
             for i in tqdm(
                 range(num_chunks),
                 total=num_chunks,
@@ -163,6 +208,20 @@ def main():
             executor.shutdown(wait=False)
             raise SystemExit("Execution terminated by user.")
     print(f"Wrote {output_file}")
 
 
 if __name__ == "__main__":
-    main()
+    # preliminary check
+    assert find_city("New York City") is not None
+    assert find_city("NYC") is not None
+    assert round(get_distance("NYC", "Jamaica"), 2) == 17.11
+
+    args = parse_args()
+    main(args)
+
+    # perform check
+    print("Performing a quick validation...")
+    import pandas as pd
+
+    df = pd.read_csv(args.output_file)
+    assert df["distance"].min() > 0
+    assert df["distance"].max() < MAX_DISTANCE

BIN plots/progress_136013_sm.png (binary file not shown; removed, was 230 KiB)

BIN progress_sample.png (binary file not shown; added, 237 KiB)
