import glob
import logging
import os

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sentence_transformers import LoggingHandler, SentenceTransformer

# Output directory for the per-checkpoint plots. makedirs(exist_ok=True)
# avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs("./plots", exist_ok=True)

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)


def evaluate(model, city_from, city_to):
    """Return the cosine similarity between the embeddings of two strings.

    Args:
        model: a SentenceTransformer used to embed both strings.
        city_from: first string to embed.
        city_to: second string to embed.

    Returns:
        Cosine similarity of the two embedding vectors.
    """
    city_to = model.encode(city_to)
    city_from = model.encode(city_from)
    return np.dot(city_to, city_from) / (
        np.linalg.norm(city_to) * np.linalg.norm(city_from)
    )


def _pairwise_cosine(model, texts_from, texts_to):
    """Return cosine(embed(texts_from[i]), embed(texts_to[i])) for every i.

    Each column is encoded in one batched call instead of one model call per
    row, which is far faster for transformer models.
    """
    texts_from = list(texts_from)
    texts_to = list(texts_to)
    if not texts_from:
        # Avoid calling encode() on an empty batch.
        return np.empty(0, dtype=float)
    emb_from = np.asarray(model.encode(texts_from))
    emb_to = np.asarray(model.encode(texts_to))
    # Row-wise dot products divided by the product of row norms.
    dots = np.einsum("ij,ij->i", emb_from, emb_to)
    norms = np.linalg.norm(emb_from, axis=1) * np.linalg.norm(emb_to, axis=1)
    return dots / norms


def calculate_similarity(data, base_model, trained_model):
    """Add cosine similarities from both models as new columns of ``data``.

    Reads the ``city_from`` and ``city_to`` columns and writes
    ``similarity_before`` (base_model) and ``similarity_after``
    (trained_model). ``data`` is modified in place and also returned.
    """
    # Assigning NumPy arrays aligns positionally, so the DataFrame's
    # (possibly non-contiguous, sampled) index does not matter.
    data["similarity_before"] = _pairwise_cosine(
        base_model, data["city_from"], data["city_to"]
    )
    data["similarity_after"] = _pairwise_cosine(
        trained_model, data["city_from"], data["city_to"]
    )
    return data


def make_plot(data):
    """Scatter ``distance`` against similarity before (red) and after (blue).

    Returns the new matplotlib Figure. Closing it is the caller's job.
    """
    fig, ax = plt.subplots()
    ax.scatter(
        data["distance"],
        data["similarity_before"],
        color="r",
        alpha=0.1,
        label="before",
    )
    ax.scatter(
        data["distance"], data["similarity_after"], color="b", alpha=0.1, label="after"
    )
    ax.set_xlabel("distance between cities (km)")
    ax.set_ylabel("similarity between vectors\n(cosine)")
    fig.legend(loc="upper right")
    return fig


if __name__ == "__main__":
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    base_model = SentenceTransformer(model_name, device="cuda")
    data = pd.read_csv("city_distances_full.csv")
    # DataFrame.sample raises ValueError if asked for more rows than exist.
    sample_size = min(1_000, len(data))

    checkpoint_dir = "checkpoints_absmax_split"
    # Only directories can be loaded as SentenceTransformer checkpoints.
    checkpoints = sorted(
        path
        for path in glob.glob(os.path.join(checkpoint_dir, "*"))
        if os.path.isdir(path)
    )
    for checkpoint in checkpoints:
        data_sample = data.sample(sample_size)
        trained_model = SentenceTransformer(checkpoint, device="cuda")
        data_sample = calculate_similarity(data_sample, base_model, trained_model)
        fig = make_plot(data_sample)
        # basename(normpath(...)) is robust to nested dirs, trailing slashes
        # and OS-specific separators, unlike split('/')[1].
        checkpoint_name = os.path.basename(os.path.normpath(checkpoint))
        fig.savefig(f"./plots/progress_{checkpoint_name}.png", dpi=600)
        # Release the figure; otherwise every iteration leaks one.
        plt.close(fig)
        # Drop this checkpoint's model before the next one is loaded.
        del trained_model