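"""Evaluate embedding checkpoints on the city-distance task.

For each saved checkpoint, a random sample of city pairs is scored with the
base model and the fine-tuned checkpoint, and a scatter plot of cosine
similarity vs. geographic distance is saved to ./plots/.
"""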
import glob
import logging
import os

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sentence_transformers import LoggingHandler, SentenceTransformer

# from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
# from sklearn.model_selection import train_test_split

if not os.path.exists("./plots"):
    os.mkdir("./plots")

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)


def evaluate(model, city_from, city_to):
    """Cosine similarity between the embeddings of two city names."""
    city_to = model.encode(city_to)
    city_from = model.encode(city_from)
    return np.dot(city_to, city_from) / (
        np.linalg.norm(city_to) * np.linalg.norm(city_from)
    )


def calculate_similarity(data, base_model, trained_model):
    """Add before/after similarity columns for each city pair in `data`."""
    # MAX_DISTANCE = 20_037.5
    # data["distance"] /= MAX_DISTANCE
    data["similarity_before"] = data.apply(
        lambda x: evaluate(base_model, x["city_from"], x["city_to"]), axis=1
    )

    data["similarity_after"] = data.apply(
        lambda x: evaluate(trained_model, x["city_from"], x["city_to"]), axis=1
    )
    return data


def make_plot(data):
    """Scatter plot of similarity vs. distance, before and after training."""
    fig, ax = plt.subplots()

    ax.scatter(
        data["distance"],
        data["similarity_before"],
        color="r",
        alpha=0.1,
        label="before",
    )
    ax.scatter(
        data["distance"], data["similarity_after"], color="b", alpha=0.1, label="after"
    )
    ax.set_xlabel("distance between cities (km)")
    ax.set_ylabel("similarity between vectors\n(cosine)")
    fig.legend(loc="upper right")
    return fig


if __name__ == "__main__":
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    base_model = SentenceTransformer(model_name, device="cuda")

    data = pd.read_csv("city_distances_sample.csv")
    # data_sample = data.sample(1_000)
    checkpoint_dir = "checkpoints_absmax_split"  # no slash
    for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
        # Score a fresh random sample of city pairs against each checkpoint
        data_sample = data.sample(1_000)
        trained_model = SentenceTransformer(checkpoint, device="cuda")

        data_sample = calculate_similarity(data_sample, base_model, trained_model)
        fig = make_plot(data_sample)
        fig.savefig(f"./plots/progress_{checkpoint.split('/')[1]}.png", dpi=600)
        plt.close(fig)  # free the figure so they don't pile up across the loop