citybert/eval.py
2023-05-04 10:03:15 +00:00

79 lines
2.3 KiB
Python

import glob
import logging
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sentence_transformers import LoggingHandler, SentenceTransformer
# from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
# from sklearn.model_selection import train_test_split
# Output directory for the per-checkpoint plots saved by the __main__ driver.
# makedirs(exist_ok=True) creates it in one step and avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs("./plots", exist_ok=True)

# Configure logging: timestamped INFO-level messages routed through
# sentence-transformers' LoggingHandler.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
def evaluate(model, city_from, city_to):
    """Return the cosine similarity between the embeddings of two city names.

    Args:
        model: encoder exposing ``encode(text)`` (here a SentenceTransformer).
        city_from: first city name.
        city_to: second city name.

    Returns:
        ``dot(u, v) / (|u| * |v|)`` for the two embedding vectors.
    """
    vec_to = model.encode(city_to)
    vec_from = model.encode(city_from)
    denominator = np.linalg.norm(vec_to) * np.linalg.norm(vec_from)
    return np.dot(vec_to, vec_from) / denominator
def calculate_similarity(data, base_model, trained_model):
    """Add before/after cosine-similarity columns for each city pair in *data*.

    For every row, the ``city_from`` and ``city_to`` names are embedded with
    each model. The cosine similarity of the two embeddings is stored in
    ``similarity_before`` (base_model) and ``similarity_after``
    (trained_model).

    Each model encodes a whole column in a single ``encode`` call, rather
    than two single-string calls per row, so the encoder can batch the work.

    Args:
        data: DataFrame with ``city_from`` and ``city_to`` columns. It is
            modified in place.
        base_model: encoder (``encode(list_of_str)`` -> 2-D array) used for
            ``similarity_before``.
        trained_model: encoder used for ``similarity_after``.

    Returns:
        The same DataFrame object, with the two similarity columns added.
    """
    cities_from = data["city_from"].tolist()
    cities_to = data["city_to"].tolist()
    # Arrays are assigned positionally, so any (e.g. sampled) index is kept.
    data["similarity_before"] = _pairwise_cosine(base_model, cities_from, cities_to)
    data["similarity_after"] = _pairwise_cosine(trained_model, cities_from, cities_to)
    return data


def _pairwise_cosine(model, cities_from, cities_to):
    """Return the row-wise cosine similarity of the embeddings of two name lists."""
    if not cities_from:
        # Nothing to encode; yield an empty float column.
        return np.empty(0, dtype=float)
    emb_from = np.asarray(model.encode(cities_from))
    emb_to = np.asarray(model.encode(cities_to))
    dots = (emb_to * emb_from).sum(axis=1)
    norms = np.linalg.norm(emb_to, axis=1) * np.linalg.norm(emb_from, axis=1)
    return dots / norms
def make_plot(data):
    """Scatter pair similarity against inter-city distance, before and after training.

    Args:
        data: DataFrame with ``distance`` (plotted as km), ``similarity_before``
            and ``similarity_after`` columns.

    Returns:
        The matplotlib Figure, left open; the caller saves and/or closes it.
    """
    fig, ax = plt.subplots()
    # (column, colour, legend label): red for the base model's similarities,
    # blue for the trained model's. alpha=0.1 keeps overlapping points visible.
    series = (
        ("similarity_before", "r", "before"),
        ("similarity_after", "b", "after"),
    )
    for column, colour, label in series:
        ax.scatter(
            data["distance"], data[column], color=colour, alpha=0.1, label=label
        )
    ax.set_xlabel("distance between cities (km)")
    ax.set_ylabel("similarity between vectors\n(cosine)")
    fig.legend(loc="upper right")
    return fig
if __name__ == "__main__":
    # Compare the pretrained base model against every saved fine-tuning
    # checkpoint and write one "progress" scatter plot per checkpoint.
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    base_model = SentenceTransformer(model_name, device="cuda")
    data = pd.read_csv("city_distances_sample.csv")
    checkpoint_dir = "checkpoints_absmax_split"
    # DataFrame.sample (without replacement) raises ValueError when n exceeds
    # the row count, so never request more rows than the CSV holds.
    sample_size = min(1_000, len(data))
    for checkpoint in sorted(glob.glob(os.path.join(checkpoint_dir, "*"))):
        # Checkpoints are loaded from directories; skip stray files in the
        # checkpoint folder instead of passing them to SentenceTransformer.
        if not os.path.isdir(checkpoint):
            continue
        # A fresh random sample is drawn for each checkpoint.
        data_sample = data.sample(sample_size)
        trained_model = SentenceTransformer(checkpoint, device="cuda")
        data_sample = calculate_similarity(data_sample, base_model, trained_model)
        fig = make_plot(data_sample)
        # basename() is independent of the path separator and of a trailing
        # slash on checkpoint_dir, unlike checkpoint.split('/')[1].
        checkpoint_name = os.path.basename(os.path.normpath(checkpoint))
        fig.savefig(f"./plots/progress_{checkpoint_name}.png", dpi=600)
        # Release the figure; otherwise pyplot keeps every checkpoint's
        # figure open and memory grows with the number of checkpoints.
        plt.close(fig)