Teaching a transformer to understand how far apart (common) cities are.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

79 lines
2.3 KiB

import glob
import logging
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sentence_transformers import LoggingHandler, SentenceTransformer
# from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
# from sklearn.model_selection import train_test_split
# Output directory for the per-checkpoint progress plots saved in __main__.
# exist_ok=True replaces the os.path.exists() + os.mkdir() pair, which could
# race (directory created between the check and the mkdir call).
os.makedirs("./plots", exist_ok=True)

# Configure the root logger: INFO level, timestamped format, with records
# emitted through sentence-transformers' LoggingHandler.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
def evaluate(model, city_from, city_to):
    """Return the cosine similarity between the embeddings of two city names.

    Args:
        model: object exposing ``encode(text)`` that returns a 1-D vector
            (a ``SentenceTransformer`` in this script).
        city_from: first city name.
        city_to: second city name.

    Returns:
        ``dot(v_to, v_from) / (|v_to| * |v_from|)`` for the two embeddings.
        NOTE(review): no guard against a zero-norm embedding.
    """
    # city_to is encoded before city_from, matching the original call order.
    vec_to = model.encode(city_to)
    vec_from = model.encode(city_from)
    denominator = np.linalg.norm(vec_to) * np.linalg.norm(vec_from)
    return np.dot(vec_to, vec_from) / denominator
def _cosine_similarities(model, cities_from, cities_to):
    """Return row-wise cosine similarities between paired city-name embeddings.

    Each distinct name is encoded exactly once, in a single batched
    ``model.encode`` call, instead of once per occurrence per row.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_cities = list(dict.fromkeys([*cities_from, *cities_to]))
    if not unique_cities:
        # Nothing to encode; avoid calling the model with an empty batch.
        return np.empty(0)
    embeddings = np.asarray(model.encode(unique_cities))
    position = {city: i for i, city in enumerate(unique_cities)}
    vec_from = embeddings[[position[c] for c in cities_from]]
    vec_to = embeddings[[position[c] for c in cities_to]]
    # Row-wise dot product of the paired embeddings.
    dots = np.einsum("ij,ij->i", vec_to, vec_from)
    norms = np.linalg.norm(vec_to, axis=1) * np.linalg.norm(vec_from, axis=1)
    return dots / norms


def calculate_similarity(data, base_model, trained_model):
    """Add base- and trained-model cosine similarities for each city pair.

    Args:
        data: DataFrame with ``city_from`` and ``city_to`` columns.
        base_model: model used for the ``similarity_before`` column.
        trained_model: model used for the ``similarity_after`` column.

    Returns:
        ``data`` itself, modified in place with two new float columns,
        ``similarity_before`` and ``similarity_after``. An empty frame gets
        empty columns instead of raising.
    """
    cities_from = data["city_from"].tolist()
    cities_to = data["city_to"].tolist()
    # Plain arrays are assigned positionally, so a non-default index
    # (e.g. from DataFrame.sample) lines up row-for-row.
    data["similarity_before"] = _cosine_similarities(
        base_model, cities_from, cities_to
    )
    data["similarity_after"] = _cosine_similarities(
        trained_model, cities_from, cities_to
    )
    return data
def make_plot(data):
    """Scatter vector similarity against city distance, before vs. after training.

    Args:
        data: DataFrame with ``distance``, ``similarity_before`` and
            ``similarity_after`` columns (as produced by calculate_similarity).

    Returns:
        The new matplotlib Figure; the caller is responsible for saving and
        closing it.
    """
    figure, axes = plt.subplots()
    # Red = base model, blue = trained model; low alpha keeps dense regions
    # readable.
    series = (
        ("similarity_before", "r", "before"),
        ("similarity_after", "b", "after"),
    )
    for column, colour, tag in series:
        axes.scatter(
            data["distance"], data[column], color=colour, alpha=0.1, label=tag
        )
    axes.set_xlabel("distance between cities (km)")
    axes.set_ylabel("similarity between vectors\n(cosine)")
    axes.legend(loc="center right")
    return figure
if __name__ == "__main__":
    # Plot base vs. fine-tuned similarity for every saved checkpoint.
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    base_model = SentenceTransformer(model_name, device="cuda")
    data = pd.read_csv("city_distances_full.csv")
    # DataFrame.sample draws without replacement and raises ValueError when
    # asked for more rows than exist, so cap the size at the dataset length.
    sample_size = min(1_000, len(data))
    checkpoint_dir = "checkpoints_absmax_split"
    # Only directories are loadable checkpoints; ignore stray files.
    checkpoints = sorted(
        path
        for path in glob.glob(os.path.join(checkpoint_dir, "*"))
        if os.path.isdir(path)
    )
    for checkpoint in checkpoints:
        print(f"Evaluating {checkpoint}")
        # NOTE: a fresh random sample is drawn for each checkpoint, so the
        # plots are not computed on identical city pairs.
        data_sample = data.sample(sample_size)
        trained_model = SentenceTransformer(checkpoint, device="cuda")
        data_sample = calculate_similarity(data_sample, base_model, trained_model)
        fig = make_plot(data_sample)
        # basename() is separator-agnostic, unlike split('/')[1], which broke
        # on Windows paths or when checkpoint_dir itself contained a slash.
        checkpoint_name = os.path.basename(os.path.normpath(checkpoint))
        fig.savefig(f"./plots/progress_{checkpoint_name}.png", dpi=600)
        # pyplot keeps every figure alive until closed; without this, memory
        # grows with each checkpoint processed.
        plt.close(fig)