Compare commits
2 commits: 5c76ddc778 ... e03740732b

| Author | SHA1 | Date |
| --- | --- | --- |
| | e03740732b | |
| | ab26735c82 | |
Makefile (11 changed lines)
@@ -1,8 +1,11 @@
 all: install data train eval

-city_distances.csv: check generate_data.py
+city_distances.csv: generate_data.py
 	@echo "Generating distance data..."
 	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
+	@echo "Calculating range of generated data..."
+	@cat city_distances.csv | tail -n +2 | sort -t',' -k3n | head -n1
+	@cat city_distances.csv | tail -n +2 | sort -t',' -k3nr | head -n1

 data: city_distances.csv
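The two new recipes report the range of the generated data: sort -t',' -k3n sorts the CSV's third column (the distance) numerically, so head -n1 prints the closest pair, and the r suffix reverses the order to print the farthest. A rough pandas equivalent of this check, assuming the city_from,city_to,distance header that generate_data.py writes:

    import pandas as pd

    df = pd.read_csv("city_distances.csv")
    print(df.loc[df["distance"].idxmin()])  # closest pair, like sort -k3n | head -n1
    print(df.loc[df["distance"].idxmax()])  # farthest pair, like sort -k3nr | head -n1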
@@ -29,10 +32,10 @@ clean:
 	@rm -rf output/
 	@rm -rf checkpoints/

-compress: plots/progress_136013_sm.png
+compress: plots/progress_12474_sm.png

-plots/progress_136013_sm.png: plots/progress_136013.png
-	@convert -resize 33% plots/progress_136013.png plots/progress_136013_sm.png
+plots/progress_12474_sm.png: plots/progress_12474.png
+	@convert -resize 33% plots/progress_12474.png plots/progress_12474_sm.png

 install: .requirements_installed
README.md

@@ -82,7 +82,7 @@ There are several potential improvements and extensions to the current model:

 # Notes

-- Generating the data took about 10-15 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8 cores (Intel 9700K), yielding 3,991,725 combinations of cities, about 150 MB on disk.
+- Generating the data took about 10 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8 cores (Intel 9700K), yielding 3,991,725 combinations of cities, about 150 MB on disk.
 - For cities with the same name, the one with the larger population is selected (had to make some sort of choice...).
 - Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16. At a batch size 16 times larger, each epoch took about 5-6 minutes.
 - Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
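The pair count quoted in the first note is exactly "2826 unique names choose 2", which is easy to verify:

    from math import comb

    # 2826 unique city names taken two at a time
    assert comb(2826, 2) == 3_991_725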
eval.py (2 changed lines)
@@ -68,7 +68,7 @@ if __name__ == "__main__":

     data = pd.read_csv("city_distances.csv")
     # data_sample = data.sample(1_000)
-    checkpoint_dir = "checkpoints_absmax_split"  # no slash
+    checkpoint_dir = "checkpoints"  # no slash
     for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
         print(f"Evaluating {checkpoint}")
         data_sample = data.sample(1_000)
generate_data.py

@@ -11,8 +11,11 @@ from geopy.distance import geodesic
 from tqdm import tqdm

 MAX_DISTANCE = 20_037.5
+CACHE = geonamescache.GeonamesCache()


 # Add argparse
+def parse_args():
 parser = argparse.ArgumentParser()
 parser.add_argument(
     "-c", "--country", help="Specify the country code", type=str, default="US"
@@ -40,15 +43,7 @@ parser.add_argument(
     help="Option to shuffle combinations list before iterating over it",
 )
 args = parser.parse_args()
+    return args

-gc = geonamescache.GeonamesCache()
-cities = gc.get_cities()
-us_cities = {
-    k: c
-    for k, c in cities.items()
-    if (c.get("countrycode") == args.country)  # & (c.get("population", 0) > 5e4)
-}


 @lru_cache(maxsize=None)
@@ -69,16 +64,57 @@ def get_coordinates(city_name, country_code="US"):
     A tuple containing the latitude and longitude of the city,
     or None if the city is not found.
     """
-    search_results = gc.search_cities(city_name, case_sensitive=True)
+    city = find_city(city_name, country_code)
+    if city is None:
+        return None
+    return city.get("latitude"), city.get("longitude")
+
+
+@lru_cache(maxsize=None)
+def find_city(city_name, country_code="US"):
+    """
+    Finds the matching city.
+
+    Parameters
+    ----------
+    city_name : str
+        The name of the city.
+    country_code : str, optional
+        The country code of the city, by default 'US'.
+
+    Returns
+    -------
+    city
+        A dict containing the raw data about the city.
+
+    """
+    search_results = CACHE.get_cities_by_name(city_name)
+    # search_results = [
+    #     list(c.values())[0] for c in search_results
+    # ]
     search_results = [
-        d for d in search_results if (d.get("countrycode") == country_code)
+        {k: v}
+        for d in search_results
+        for inner_dict in d.values()
+        for k, v in inner_dict.items()
+    ]
+    if not search_results:  # if not found by name, search alternatenames
+        search_results = CACHE.search_cities(
+            city_name, attribute="alternatenames", case_sensitive=True
+        )
+    # filter search results to match requested country
+    # and avoid wasted computation if coordinates missing
+    search_results = [
+        d
+        for d in search_results
+        if (d.get("countrycode") == country_code) & (d.get("longitude") is not None)
     ]

     if not search_results:
-        return None, None
+        return None
     populations = [city.get("population") for city in search_results]
     city = search_results[np.argmax(populations)]
-    return city.get("latitude"), city.get("longitude")
+    return city


 def get_distance(city1, city2, country1="US", country2="US"):
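The nested flattening in find_city mirrors the shapes geonamescache returns: get_cities_by_name yields a list of single-entry mappings keyed by geonameid, whereas search_cities yields plain city dicts, so only the first branch needs unwrapping. A small sketch of the lookup-and-disambiguate idea ("Springfield" is just an example; the largest-population rule matches the README note):

    import geonamescache

    gc = geonamescache.GeonamesCache()
    results = gc.get_cities_by_name("Springfield")
    # e.g. [{"<geonameid>": {"name": "Springfield", "countrycode": "US", ...}}, ...]
    cities = [inner for d in results for inner in d.values()]
    largest = max(cities, key=lambda c: c.get("population", 0))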
@@ -117,25 +153,38 @@ def calculate_distance(pair):
     return city1, city2, distance


-def main():
+def main(args):
+    output_file = args.output_file
+    shuffle = args.shuffle
+    country_code = args.country
+    chunk_size = args.chunk_size
+    max_workers = args.workers
+
+    cities = CACHE.get_cities()
+    us_cities = {
+        k: c
+        for k, c in cities.items()
+        if (c.get("countrycode") == country_code) & (c.get("longitude") is not None)
+    }
+    # & (c.get("population", 0) > 5e4)
+
     cities = list(us_cities.values())
     unique_names = set([c.get("name") for c in cities])
     # unique_cities = [c for c in cities if c.get("name") in unique_names]
     print(f"Num cities: {len(cities)}, unique names: {len(unique_names)}")
     city_combinations = list(itertools.combinations(unique_names, 2))
-    if args.shuffle:
+    if shuffle:
         np.random.shuffle(city_combinations)
-    chunk_size = args.chunk_size
-    num_chunks = len(city_combinations) // chunk_size + 1
-    output_file = args.output_file

+    # chunk size, city_combinations, max_workers, output_file
+    num_chunks = len(city_combinations) // chunk_size + 1
     with open(output_file, "w", newline="") as csvfile:
         fieldnames = ["city_from", "city_to", "distance"]
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
         writer.writeheader()

         try:
-            executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.workers)
+            executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
             for i in tqdm(
                 range(num_chunks),
                 total=num_chunks,
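main(args) now unpacks everything it needs up front, so the function no longer reads module-level state. The body of the chunked loop falls outside this hunk; a minimal sketch of the pattern it implies, assuming each chunk is fanned out with executor.map over calculate_distance (write_chunks and its loop body are hypothetical, the other names come from the hunks above):

    import concurrent.futures

    def write_chunks(city_combinations, writer, chunk_size, max_workers):
        # Processing one slice of pairs at a time keeps memory bounded for ~4M pairs.
        num_chunks = len(city_combinations) // chunk_size + 1
        executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
        for i in range(num_chunks):
            chunk = city_combinations[i * chunk_size : (i + 1) * chunk_size]
            for city_from, city_to, distance in executor.map(calculate_distance, chunk):
                if distance is not None:  # a failed lookup yields no distance
                    writer.writerow(
                        {"city_from": city_from, "city_to": city_to, "distance": distance}
                    )
        executor.shutdown()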
@@ -163,6 +212,15 @@ def main():
             executor.shutdown(wait=False)
             raise SystemExit("Execution terminated by user.")

+    print(f"Wrote {output_file}")


 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)
+
+    # perform check
+    import pandas as pd
+
+    df = pd.read_csv(args.output_file)
+    assert df["distance"].min() > 0
+    assert df["distance"].max() < MAX_DISTANCE
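The post-run check bounds every value in the CSV: distances must be positive, and none can exceed MAX_DISTANCE = 20_037.5 km, because half of Earth's roughly 40,075 km circumference is the farthest apart two surface points can be:

    # Half the (approximate) equatorial circumference bounds any geodesic distance.
    EARTH_CIRCUMFERENCE_KM = 40_075
    assert EARTH_CIRCUMFERENCE_KM / 2 == 20_037.5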
Binary file not shown (image diff; before: 230 KiB).