79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
|
import glob
|
||
|
import logging
|
||
|
import os
|
||
|
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
from matplotlib import pyplot as plt
|
||
|
from sentence_transformers import LoggingHandler, SentenceTransformer
|
||
|
|
||
|
# from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
|
||
|
# from sklearn.model_selection import train_test_split
|
||
|
|
||
|
# Ensure the directory that the progress plots are saved into exists.
# os.makedirs(exist_ok=True) replaces the os.path.exists() + os.mkdir() pair,
# which could race (directory created between the check and the mkdir); it is
# a no-op when the directory is already there.
os.makedirs("./plots", exist_ok=True)
|
||
|
|
||
|
# Configure root logging: INFO level, timestamped "<time> - <message>" lines,
# routed through sentence-transformers' LoggingHandler as the sole handler
# passed to logging.basicConfig.
_logging_options = dict(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(message)s",
    handlers=[LoggingHandler()],
)
logging.basicConfig(**_logging_options)
|
||
|
|
||
|
|
||
|
def evaluate(model, city_from, city_to):
    """Return the cosine similarity between the embeddings of two texts.

    Args:
        model: object with an ``encode(text)`` method returning a 1-D vector
            (in this script, a SentenceTransformer).
        city_from: first city name to embed.
        city_to: second city name to embed.

    Returns:
        ``dot(e_to, e_from) / (|e_to| * |e_from|)`` for the two embeddings.
        A zero-norm embedding produces NaN via numpy division.
    """
    emb_to = model.encode(city_to)
    emb_from = model.encode(city_from)

    numerator = np.dot(emb_to, emb_from)
    denominator = np.linalg.norm(emb_to) * np.linalg.norm(emb_from)
    return numerator / denominator
|
||
|
|
||
|
|
||
|
def _rowwise_cosine(left, right):
    """Return the cosine similarity of each pair of corresponding rows.

    Args:
        left: 2-D array, one embedding per row.
        right: 2-D array with the same shape as ``left``.

    Returns:
        1-D array whose i-th entry is ``cos(left[i], right[i])``. A row with
        zero norm yields NaN (numpy 0/0), matching a scalar cosine formula.
    """
    left = np.asarray(left)
    right = np.asarray(right)
    dots = np.sum(left * right, axis=1)
    norms = np.linalg.norm(left, axis=1) * np.linalg.norm(right, axis=1)
    return dots / norms


def calculate_similarity(data, base_model, trained_model):
    """Add per-row cosine similarities from a base and a trained model.

    For every row, ``city_from`` and ``city_to`` are embedded by each model
    and the cosine similarity of the two embeddings is stored in the new
    columns ``similarity_before`` (``base_model``) and ``similarity_after``
    (``trained_model``).

    Args:
        data: DataFrame with ``city_from`` and ``city_to`` columns. It is
            modified in place.
        base_model: model whose ``encode(list_of_str)`` returns one embedding
            row per input string (e.g. a SentenceTransformer).
        trained_model: model with the same ``encode`` interface.

    Returns:
        ``data`` itself, with the two similarity columns added.
    """
    if len(data) == 0:
        # encode([]) is not guaranteed to return a 2-D array, so skip the
        # models and add empty columns directly.
        data["similarity_before"] = np.array([], dtype=float)
        data["similarity_after"] = np.array([], dtype=float)
        return data

    # Batch each column through the model in a single encode() call instead
    # of one call per row and city; results are assigned positionally, which
    # matches the row order of the lists built here.
    cities_from = data["city_from"].tolist()
    cities_to = data["city_to"].tolist()
    for column, model in (
        ("similarity_before", base_model),
        ("similarity_after", trained_model),
    ):
        data[column] = _rowwise_cosine(
            model.encode(cities_from), model.encode(cities_to)
        )
    return data
|
||
|
|
||
|
|
||
|
def make_plot(data):
    """Scatter similarity against distance for the base and trained models.

    Draws ``similarity_before`` in red and ``similarity_after`` in blue, both
    semi-transparent (alpha 0.1), against the ``distance`` column, labels the
    axes, and adds a figure-level legend in the upper right.

    Args:
        data: table with ``distance``, ``similarity_before`` and
            ``similarity_after`` columns.

    Returns:
        The new matplotlib Figure; it is neither shown nor saved here.
    """
    fig, ax = plt.subplots()

    # (column to plot, marker colour, legend label), drawn in this order.
    series = (
        ("similarity_before", "r", "before"),
        ("similarity_after", "b", "after"),
    )
    for column, colour, tag in series:
        ax.scatter(
            data["distance"],
            data[column],
            color=colour,
            alpha=0.1,
            label=tag,
        )

    ax.set_xlabel("distance between cities (km)")
    ax.set_ylabel("similarity between vectors\n(cosine)")
    fig.legend(loc="upper right")
    return fig
|
||
|
|
||
|
|
||
|
def main(
    checkpoint_dir="checkpoints_absmax_split",
    data_path="city_distances_sample.csv",
    sample_size=1_000,
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    device="cuda",
):
    """Save one similarity-vs-distance plot per training checkpoint.

    Loads the base model and the CSV once. Then, for every directory directly
    inside ``checkpoint_dir`` (visited in sorted order), it loads that
    checkpoint, scores a random sample of rows with both models, and writes
    ``./plots/progress_<checkpoint name>.png``.

    Args:
        checkpoint_dir: directory whose subdirectories are model checkpoints.
        data_path: CSV read with ``pd.read_csv``; must provide the columns
            used by ``calculate_similarity`` and ``make_plot``.
        sample_size: rows sampled per checkpoint, capped at the CSV length.
        model_name: base SentenceTransformer name or path.
        device: device passed to every SentenceTransformer.
    """
    base_model = SentenceTransformer(model_name, device=device)
    data = pd.read_csv(data_path)
    # DataFrame.sample(n) raises when n exceeds the number of rows.
    n_rows = min(sample_size, len(data))

    for checkpoint in sorted(glob.glob(os.path.join(checkpoint_dir, "*"))):
        if not os.path.isdir(checkpoint):
            continue  # stray files are not loadable checkpoints
        # Each checkpoint is scored on its own random sample.
        data_sample = data.sample(n_rows)
        trained_model = SentenceTransformer(checkpoint, device=device)

        data_sample = calculate_similarity(data_sample, base_model, trained_model)
        fig = make_plot(data_sample)
        # basename is independent of the path separator and of how deeply
        # checkpoint_dir is nested (splitting on "/" and taking [1] is not).
        name = os.path.basename(checkpoint)
        fig.savefig(f"./plots/progress_{name}.png", dpi=600)
        # Close the figure so open figures do not accumulate across iterations.
        plt.close(fig)


if __name__ == "__main__":
    main()
|