initial commit, working code
This commit is contained in:
commit
b14a33c984
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
checkpoints*
|
||||
plots*
|
||||
*.csv
|
||||
output/
|
17
Makefile
Normal file
17
Makefile
Normal file
@ -0,0 +1,17 @@
|
||||
# Command-style targets that do not produce a file of the same name.
# Declaring them phony keeps them runnable even if a file or directory
# called "lint", "train", "eval" or "clean" ever appears in the repo.
.PHONY: lint train eval clean

# Generate the pairwise city-distance data set.
# NOTE(review): generate_data.py writes "city_distances_full.csv", so this
# target file is never created and the recipe re-runs on every invocation.
city_distances.csv: lint generate_data.py
	bash -c 'time python generate_data.py'

# Sort imports, format, then lint the whole repository.
lint:
	isort --profile=black .
	black .
	flake8 --max-line-length=88 .

# Fine-tune the sentence-transformer model.
train: lint train.py
	bash -c 'time python train.py'

# Plot similarity-vs-distance for every saved checkpoint.
eval: lint eval.py
	bash -c 'time python eval.py'

# Remove training artefacts.
# NOTE(review): train.py saves checkpoints under "checkpoints_absmax_split",
# which this rule does not remove — confirm whether that is intended.
clean:
	rm -rf output/
	rm -rf checkpoints/
|
45
debug_distance.py
Normal file
45
debug_distance.py
Normal file
@ -0,0 +1,45 @@
|
||||
import geonamescache
|
||||
from geopy.distance import geodesic
|
||||
|
||||
# Load the city records exposed by geonamescache once for the whole script.
gc = geonamescache.GeonamesCache()
cities = gc.get_cities()
# Subset of the records whose "countrycode" field equals "US".
us_cities = {k: c for k, c in cities.items() if c.get("countrycode") == "US"}

# Debug output: raw search results for two names, then the size of both tables.
print(gc.search_cities("Jamaica"), "\n")
print(gc.search_cities("Manhattan"), "\n")
print("lengths:", len(cities), len(us_cities))
|
||||
|
||||
|
||||
def get_coordinates(city_name, country_code="US"):
    """
    Look up the coordinates of a city by name within one country.

    Each search result is printed for debugging.  A result is accepted only
    when ``city_name`` equals the record's primary name or one of its
    alternate names AND the record's country code equals ``country_code``.

    Parameters
    ----------
    city_name : str
        The (case-sensitive) name to search for.
    country_code : str, optional
        Country code the matching record must carry, by default 'US'.

    Returns
    -------
    tuple or None
        ``(latitude, longitude)`` of the first accepted record, or None when
        no record qualifies.
    """
    search_results = gc.search_cities(city_name, case_sensitive=True)
    for city in search_results:
        print(f"searching {city}")
        # Compare against the record's own names.  The previous code appended
        # the query string itself to this list, which made the membership
        # test below always true.  ``alternatenames`` may be absent/None.
        possible_matches = list(city.get("alternatenames") or []) + [city.get("name")]
        if city_name in possible_matches and city.get("countrycode") == country_code:
            return city.get("latitude"), city.get("longitude")
    return None
|
||||
|
||||
|
||||
def get_distance(city1, city2, country1="US", country2="US"):
    """
    Geodesic distance in kilometres between two named cities.

    Both cities are resolved with ``get_coordinates`` (first city first).
    Returns None when either lookup yields no coordinates.
    """
    endpoints = (
        get_coordinates(city1, country1),
        get_coordinates(city2, country2),
    )
    # Bail out if either city could not be resolved.
    if any(point is None for point in endpoints):
        return None
    return geodesic(*endpoints).km
|
||||
|
||||
|
||||
MAX_DISTANCE = 20_037.5

# The pair of places (and their country codes) to probe.
city1, city2 = "New York", "Jamaica"
country1, country2 = "US", "US"

distance = get_distance(city1, city2, country1, country2)

# Report the failure case first, then the successful measurement.
if distance is None:
    print("One or both city names were not found.")
else:
    print(f"Distance between {city1} and {city2} is {distance:.2f} km.")
|
78
eval.py
Normal file
78
eval.py
Normal file
@ -0,0 +1,78 @@
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
from sentence_transformers import LoggingHandler, SentenceTransformer
|
||||
|
||||
# from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
|
||||
# from sklearn.model_selection import train_test_split
|
||||
|
||||
# Make sure the output directory for the figures exists.  ``exist_ok=True``
# replaces the separate exists-check + mkdir pair, so the call also succeeds
# when the directory is created by something else in between.
os.makedirs("./plots", exist_ok=True)

# Configure logging: timestamped messages at INFO level, routed through the
# handler provided by sentence_transformers.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
|
||||
|
||||
|
||||
def evaluate(model, city_from, city_to):
    """
    Cosine similarity between the embeddings of two city names.

    ``model`` only needs an ``encode(text)`` method returning a vector.
    The destination name is encoded first, then the origin name.
    """
    target_vec = model.encode(city_to)
    source_vec = model.encode(city_from)
    norm_product = np.linalg.norm(target_vec) * np.linalg.norm(source_vec)
    return np.dot(target_vec, source_vec) / norm_product
|
||||
|
||||
|
||||
def calculate_similarity(data, base_model, trained_model):
    """
    Add per-row cosine similarities for the base and the trained model.

    For every row, the ``city_from`` / ``city_to`` names are embedded with
    each model and compared via ``evaluate``.  The results are stored in the
    new columns ``similarity_before`` (base model) and ``similarity_after``
    (trained model).  ``data`` is modified in place and also returned.
    """

    def _row_similarities(model):
        # One similarity score per row, computed with the given model.
        return data.apply(
            lambda row: evaluate(model, row["city_from"], row["city_to"]),
            axis=1,
        )

    data["similarity_before"] = _row_similarities(base_model)
    data["similarity_after"] = _row_similarities(trained_model)
    return data
|
||||
|
||||
|
||||
def make_plot(data):
    """
    Scatter similarity against distance, before and after training.

    Reads the ``distance``, ``similarity_before`` and ``similarity_after``
    columns of ``data`` and returns the resulting matplotlib figure.
    """
    fig, ax = plt.subplots()

    # (column, colour, legend label) — drawn in this order.
    series = (
        ("similarity_before", "r", "before"),
        ("similarity_after", "b", "after"),
    )
    for column, colour, legend_label in series:
        ax.scatter(
            data["distance"],
            data[column],
            color=colour,
            alpha=0.1,
            label=legend_label,
        )

    ax.set_xlabel("distance between cities (km)")
    ax.set_ylabel("similarity between vectors\n(cosine)")
    fig.legend(loc="upper right")
    return fig
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Compare the untouched base model against every saved checkpoint and
    # write one progress plot per checkpoint into ./plots.
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    base_model = SentenceTransformer(model_name, device="cuda")

    data = pd.read_csv("city_distances_sample.csv")
    # Never request more rows than the file holds; ``DataFrame.sample``
    # raises when n exceeds the number of rows (without replacement).
    sample_size = min(1_000, len(data))

    checkpoint_dir = "checkpoints_absmax_split"  # no slash
    for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
        # A fresh random sample is drawn for every checkpoint.
        data_sample = data.sample(sample_size)
        trained_model = SentenceTransformer(checkpoint, device="cuda")

        data_sample = calculate_similarity(data_sample, base_model, trained_model)
        fig = make_plot(data_sample)

        # Take the last path component instead of splitting on "/", which
        # breaks when glob returns a different path separator.
        checkpoint_name = os.path.basename(os.path.normpath(checkpoint))
        fig.savefig(f"./plots/progress_{checkpoint_name}.png", dpi=600)
        # Release the figure; otherwise every checkpoint leaves one open.
        plt.close(fig)
|
118
generate_data.py
Normal file
118
generate_data.py
Normal file
@ -0,0 +1,118 @@
|
||||
import concurrent.futures
|
||||
import csv
|
||||
import itertools
|
||||
from concurrent.futures import as_completed
|
||||
from functools import lru_cache
|
||||
|
||||
import geonamescache
|
||||
import numpy as np
|
||||
from geopy.distance import geodesic
|
||||
|
||||
# NOTE(review): presumably the largest possible distance between two points
# on Earth, in km — confirm.  Not referenced by the functions shown below.
MAX_DISTANCE = 20_037.5

# City records are loaded once at import time and shared by the lookups below.
gc = geonamescache.GeonamesCache()
cities = gc.get_cities()
# Only records whose "countrycode" is "US" are kept; the population filter in
# the trailing comment is currently disabled.
us_cities = {
    k: c
    for k, c in cities.items()
    if (c.get("countrycode") == "US")  # & (c.get("population", 0) > 5e4)
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def get_coordinates(city_name, country_code="US"):
    """
    Get the coordinates of a city.

    Search results are restricted to records whose country code equals
    ``country_code``.  Records whose primary name equals ``city_name`` are
    preferred over other hits; among the remaining candidates the most
    populous one wins (the first one on ties).  Results are memoised per
    ``(city_name, country_code)``.

    Parameters
    ----------
    city_name : str
        The name of the city (case-sensitive).
    country_code : str, optional
        The country code of the city, by default 'US'.

    Returns
    -------
    tuple
        A tuple containing the latitude and longitude of the city,
        or None if the city is not found.
    """
    search_results = gc.search_cities(city_name, case_sensitive=True)
    # Honour the requested country.  It used to be ignored, so a larger
    # foreign city with the same name could supply the coordinates.
    candidates = [
        city for city in search_results if city.get("countrycode") == country_code
    ]
    if not candidates:
        return None
    # Prefer records actually called ``city_name`` over alternate-name hits.
    exact_matches = [city for city in candidates if city.get("name") == city_name]
    # A missing/None population counts as 0 instead of breaking the comparison.
    best = max(
        exact_matches or candidates,
        key=lambda city: city.get("population") or 0,
    )
    return best.get("latitude"), best.get("longitude")
|
||||
|
||||
|
||||
def get_distance(city1, city2, country1="US", country2="US"):
    """
    Compute the geodesic distance between two cities in kilometers.

    Parameters
    ----------
    city1 : str
        Name of the first city.
    city2 : str
        Name of the second city.
    country1 : str, optional
        Country code of the first city, by default 'US'.
    country2 : str, optional
        Country code of the second city, by default 'US'.

    Returns
    -------
    float
        Distance between the two cities in kilometers,
        or None if one or both city names were not found.
    """
    start = get_coordinates(city1, country1)
    end = get_coordinates(city2, country2)

    # Only measure when both lookups produced coordinates.
    if start is not None and end is not None:
        return geodesic(start, end).km
    return None
|
||||
|
||||
|
||||
def calculate_distance(pair):
    """
    Worker: measure the distance for one pair of city records.

    ``pair`` holds two city dicts; only their ``"name"`` entries are used.
    Returns ``(name_of_first, name_of_second, distance)`` where distance is
    whatever ``get_distance`` yields for the two names.
    """
    first, second = pair
    names = (first["name"], second["name"])
    return (*names, get_distance(*names))
|
||||
|
||||
|
||||
def main():
    """
    Write the distance for every pair of US cities to a CSV file.

    Pairs are distributed over 8 worker processes.  Rows with the columns
    ``city_from``, ``city_to`` and ``distance`` are written to
    "city_distances_full.csv" in pair-generation order; pairs without a
    distance are skipped.

    Raises
    ------
    SystemExit
        When the run is interrupted from the keyboard.
    """
    cities = list(us_cities.values())
    print(f"Num cities: {len(cities)}")
    # Keep the pairs as an iterator instead of building a list of all of them.
    city_combinations = itertools.combinations(cities, 2)

    with open("city_distances_full.csv", "w", newline="") as csvfile:
        fieldnames = ["city_from", "city_to", "distance"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        # Created outside the ``try`` so the handler below can always refer to it.
        executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
        try:
            # ``map`` with a chunk size ships pairs to the workers in batches,
            # avoiding one Future object (and one dict entry) per pair.
            results = executor.map(
                calculate_distance, city_combinations, chunksize=1_000
            )
            for city_from, city_to, distance in results:
                if distance is not None:
                    writer.writerow(
                        {
                            "city_from": city_from,
                            "city_to": city_to,
                            "distance": distance,
                        }
                    )
        except KeyboardInterrupt:
            print("Interrupted. Terminating processes...")
            executor.shutdown(wait=False)
            raise SystemExit("Execution terminated by user.")
        else:
            # Normal completion: release the worker processes as well.
            executor.shutdown()


if __name__ == "__main__":
    main()
|
96
train.py
Normal file
96
train.py
Normal file
@ -0,0 +1,96 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sentence_transformers import (
|
||||
InputExample,
|
||||
LoggingHandler,
|
||||
SentenceTransformer,
|
||||
losses,
|
||||
)
|
||||
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Configure logging
|
||||
# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(model_name, device="cuda")

# For a quick smoke test without the CSV, random (city, city, score) triples
# of the same types can be substituted for ``train_data`` (e.g. via Faker).
data = pd.read_csv("city_distances_sample.csv")
MAX_DISTANCE = 20_037.5  # global max distance
# Alternative: MAX_DISTANCE = data["distance"].max()  # about 5k

print(f"{MAX_DISTANCE=}")

# Turn each distance into a similarity label: 1 - distance / MAX_DISTANCE.
# Iterating the columns directly avoids building a Series object per row.
train_data = [
    (city_from, city_to, 1 - distance / MAX_DISTANCE)
    for city_from, city_to, distance in zip(
        data["city_from"], data["city_to"], data["distance"]
    )
]

np.random.seed(1992)
np.random.shuffle(train_data)
examples = [
    InputExample(texts=[city_from, city_to], label=dist)
    for city_from, city_to, dist in train_data
]

# Perform train-test split
train_examples, val_examples = train_test_split(
    examples, test_size=0.2, random_state=21
)
# validation examples can be something like templated sentences
# that maintain the same distance as the cities (same context)
# should probably add training examples like that too if needed
batch_size = 16
num_examples = len(train_examples)
steps_per_epoch = num_examples // batch_size

print(f"\nHead of training data (size: {num_examples}):")
print(train_data[:10], "\n")

# Create the DataLoader for the training set.  It uses ``batch_size`` so it
# cannot drift from ``steps_per_epoch`` (the value was hard-coded before).
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)

print("TRAINING")
# Configure the training arguments
training_args = {
    "output_path": "./output",
    # "evaluation_steps": steps_per_epoch,  # already evaluates at the end of each epoch
    "epochs": 5,
    "warmup_steps": 500,
    "optimizer_params": {"lr": 2e-5},
    # "weight_decay": 0,  # not sure if this helps but works fine without setting it.
    "scheduler": "WarmupLinear",
    "save_best_model": True,
    "checkpoint_path": "./checkpoints_absmax_split",
    # One checkpoint per epoch.
    "checkpoint_save_steps": steps_per_epoch,
    "checkpoint_save_total_limit": 20,
}
print(f"TRAINING ARGUMENTS:\n {training_args}")

train_loss = losses.CosineSimilarityLoss(model)

# Create an evaluator for validation dataset
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    val_examples, write_csv=True
)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    **training_args,
)
|
Loading…
Reference in New Issue
Block a user