Compare commits

...

11 Commits
main ... ll

Author SHA1 Message Date
Michael Pilosov
fbe4503801 workers 2026-05-25 23:36:06 +00:00
Michael Pilosov
cd9da60893 10 epochs 2026-05-25 22:51:43 +00:00
Michael Pilosov
e9d3d72499 Improve training signal and add honest eval metrics
- prepare_training_data: bag_size 5→8, street signs fill slots first so
  every sample contains the most geographically discriminative texts
- train: HuberLoss replaces MSE (robust to outlier intersections),
  ReduceLROnPlateau scheduler added, split logic extracted to data_utils
- eval: reproduce train/val split to report honest per-bag and
  per-intersection-aggregated metrics separately for train and val sets
- data_utils: shared split_indices() so train and eval use identical splits

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 22:41:14 +00:00
Michael Pilosov
b72cd1b917 device fix (cuda) 2026-05-25 22:10:06 +00:00
Michael Pilosov
40967a303c drop pena 2026-05-25 22:09:57 +00:00
Michael Pilosov
8af279771a frozen layers more epochs 2026-05-25 21:45:10 +00:00
Michael Pilosov
8f4d4c1057 optimizations? 2026-05-25 21:45:08 +00:00
Michael Pilosov
44c7753856 cuda 2026-05-25 21:25:24 +00:00
Michael Pilosov
b5810dd282 more training data + frozen layers options 2026-05-25 15:16:19 -06:00
Michael Pilosov
bfad4547ce no mps 2026-05-25 14:26:37 -06:00
Michael Pilosov
e8419354f4 first train/eval 2026-05-25 14:11:05 -06:00
9 changed files with 1173 additions and 230 deletions

3
.gitignore vendored
View File

@ -3,3 +3,6 @@ plots*
*.csv
output/
.requirements_installed
.DS_Store
__pycache__/
output*

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.10

105
Makefile
View File

@ -1,46 +1,81 @@
all: install data train eval
all: data train eval
city_distances.csv: generate_data.py
@echo "Generating distance data..."
@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
@echo "Calculating range of generated data..."
@cat city_distances.csv | tail -n +2 | sort -t',' -k3n | head -n1
@cat city_distances.csv | tail -n +2 | sort -t',' -k3nr | head -n1
training.csv: prepare_training_data.py training_data_raw.csv
@echo "Preparing bootstrapped sign text bags..."
@bash -c 'source .venv/bin/activate && python prepare_training_data.py'
data: city_distances.csv
data: training.csv
train: check train.py
@echo "Training embeddings..."
@bash -c 'time python train.py'
train: train.py training.csv
@echo "Training coordinate regressor..."
@bash -c 'source .venv/bin/activate && python train.py --epochs 50'
eval: check eval.py
@echo "Evaluating trained model..."
@bash -c 'time python eval.py'
eval: eval.py training.csv
@echo "Evaluating coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py'
train_frozen_encoder: train.py training.csv
@echo "Training coordinate head with frozen encoder..."
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_encoder --freeze-encoder --epochs 10'
eval_frozen_encoder: eval.py training.csv
@echo "Evaluating frozen-encoder coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_encoder --output-file predictions_frozen_encoder.csv --plot-file plots/prediction_map_frozen_encoder.png --scatter-plot-file plots/predicted_vs_actual_frozen_encoder.png'
train_frozen_layers: train.py training.csv
@echo "Training coordinate regressor with first transformer layers frozen..."
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_layers --freeze-transformer-layers 4 --epochs 50 --num-workers 0'
eval_frozen_layers: eval.py training.csv
@echo "Evaluating frozen-layer coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_layers --output-file predictions_frozen_layers.csv --plot-file plots/prediction_map_frozen_layers.png --scatter-plot-file plots/predicted_vs_actual_frozen_layers.png'
train_mpnet: train.py training.csv
@echo "Training with all-mpnet-base-v2 (encoder fully frozen, head only)..."
@bash -c 'source .venv/bin/activate && python train.py \
--model-name sentence-transformers/all-mpnet-base-v2 \
--output-path output_mpnet \
--freeze-encoder \
--hidden-dim 512 \
--head-learning-rate 1e-2 \
--epochs 50'
train_mpnet_finetune: train.py training.csv
@echo "Training with all-mpnet-base-v2 (frozen first 10 of 12 layers)..."
@bash -c 'source .venv/bin/activate && python train.py \
--model-name sentence-transformers/all-mpnet-base-v2 \
--output-path output_mpnet_ft \
--freeze-transformer-layers 10 \
--hidden-dim 512 \
--head-learning-rate 1e-2 \
--batch-size 32 \
--epochs 50'
eval_mpnet_finetune: eval.py training.csv
@echo "Evaluating all-mpnet-base-v2 fine-tuned coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py \
--model-path output_mpnet_ft \
--output-file predictions_mpnet_ft.csv \
--plot-file plots/prediction_map_mpnet_ft.png \
--scatter-plot-file plots/predicted_vs_actual_mpnet_ft.png'
eval_mpnet: eval.py training.csv
@echo "Evaluating all-mpnet-base-v2 coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py \
--model-path output_mpnet \
--output-file predictions_mpnet.csv \
--plot-file plots/prediction_map_mpnet.png \
--scatter-plot-file plots/predicted_vs_actual_mpnet.png'
lint:
@echo "Auto-linting files and performing final style checks..."
@isort --profile=black .
@black .
@flake8 --max-line-length=88 --ignore E203 .
check: lint
@echo "Checking for unstaged or untracked changes..."
@git diff-index --quiet HEAD -- || { echo "Unstaged or untracked changes detected!"; exit 1; }
@bash -c 'source .venv/bin/activate && isort --profile=black *.py'
@bash -c 'source .venv/bin/activate && black *.py'
@bash -c 'source .venv/bin/activate && flake8 --max-line-length=88 --ignore E203 *.py'
clean:
@echo "Removing outputs/ and checkpoints/ directories"
@echo "Removing generated outputs"
@rm -rf output/
@rm -rf checkpoints/
@rm -f training.csv predictions.csv
compress: plots/progress_12474_sm.png
plots/progress_12474_sm.png: plots/progress_12474.png
@convert -resize 33% plots/progress_12474.png progress_sample.png
install: .requirements_installed
.requirements_installed: requirements.txt
pip install -r requirements.txt
@echo "installed requirements" > .requirements_installed
.PHONY: data train eval lint check clean all
.PHONY: data train eval train_frozen_encoder eval_frozen_encoder train_frozen_layers eval_frozen_layers train_mpnet eval_mpnet train_mpnet_finetune eval_mpnet_finetune lint clean all

120
README.md
View File

@ -1,89 +1,83 @@
# CityBert
# Sign Coordinate Regressor
CityBert is a machine learning project that fine-tunes a neural network model to understand the similarity between cities based on their geodesic distances.
This project fine-tunes a sentence-transformer model to predict a latitude and
longitude pair from text observed on signs at an intersection.
The project generates a dataset of US cities and their pair-wise geodesic distances, which are then used to train the model.
The raw dataset is expected at `training_data_raw.csv`. It may include the
pandas-exported index column; the preprocessing script only uses:
The project can be extended to include other distance metrics or additional data, such as airport codes, city aliases, or time zones.
- `intersection`
- `text_on_sign_exact`
- `latitude`
- `longitude`
> Note that this model only considers geographic distances and does not take into account other factors such as political borders or transportation infrastructure.
These factors contribute to a sense of "distance as it pertains to travel difficulty," which is not directly reflected by this model.
## Workflow
## But Why?
Prepare bootstrapped training rows:
### Demonstrate Flexibility
This project showcases how pre-trained language models can be fine-tuned to understand geographic relationships between cities.
```bash
source .venv/bin/activate
python prepare_training_data.py --seed 1992 --bag-size 5 --samples-per-intersection 50
```
### Contribute to the Community
Out-of-the-box neural network models struggle to grasp the spatial relationships, cultural connections, and underlying patterns between cities that humans intuitively understand. Training a specialized model bridges this gap, allowing the network to capture complex relationships and better comprehend geographic data.
This writes `training.csv` with rows shaped like:
By training a model on pairs of city names and geodesic distances, we enhance its ability to infer city similarity based on names alone. This is beneficial in applications like search engines, recommendation systems, or other natural language processing tasks involving geographic context.
```text
intersection,sample_id,text,latitude,longitude,unique_sign_count,raw_sign_count
```
This model will be published for public use, and the code can be adapted for other specialized use-cases.
Each row is a deterministic bootstrap sample of sign texts from one
intersection, joined into a single text field. This trains on "some signs seen
at this coordinate" instead of one sign or every sign at that coordinate.
### Explore Tradeoffs
Using a neural network to understand geographic relationships provides a more robust and flexible representation compared to traditional latitude/longitude lookups. It can capture complex patterns and relationships in natural language tasks that may be difficult to model using traditional methods.
Train the model:
Although there's an initial computational overhead in training the model, benefits include handling various location-based queries and better handling of aliases and alternate city names.
```bash
python train.py
```
This tradeoff allows more efficient and context-aware processing of location-based information, making it valuable in specific applications. In scenarios requiring high precision or quick solutions, traditional methods may still be more suitable.
Evaluate and write predictions:
Ultimately, the efficiency of a neural network compared to traditional methods depends on the specific problem and desired trade-offs between accuracy, efficiency, and flexibility.
```bash
python eval.py
```
### Applicability to Other Tasks
The approach demonstrated can be extended to other metrics or features beyond geographic distance. By adapting dataset generation and fine-tuning processes, models can be trained to learn various relationships and similarities between different entities.
This writes `predictions.csv` and a map-style diagnostic plot:
![Prediction map](./plots/prediction_map.png)
## Overview of Project Files
It also writes a coordinate calibration plot:
- `generate_data.py`: Generates a dataset of US cities and their pairwise geodesic distances.
- `train.py`: Trains the neural network model using the generated dataset.
- `eval.py`: Evaluates the trained model by comparing the similarity between city vectors before and after training.
- `Makefile`: Automates the execution of various tasks, such as generating data, training, and evaluation.
- `README.md`: Provides a description of the project, instructions on how to use it, and expected results.
- `requirements.txt`: Defines requirements used for creating the results.
![Predicted vs actual coordinates](./plots/predicted_vs_actual.png)
Or run the full pipeline:
## How to Use
```bash
make
```
1. Install the required dependencies by running `pip install -r requirements.txt`.
2. Run `make data` to generate the dataset of city distances. This will create `city_distances.csv` by default only with US cities (for now).
3. Run `make train` to train the neural network model.
4. Run `make eval` to evaluate the trained model and generate evaluation plots.
## Useful Options
**You can also just run `make` (i.e., `make all`) which will run through all of those steps.**
`prepare_training_data.py`:
- `--seed`: deterministic bootstrap seed.
- `--bag-size`: number of sign texts per sampled bag.
- `--samples-per-intersection`: number of bags generated per intersection.
## What to Expect
`train.py`:
After training, the model should be able to understand the similarity between cities based on their geodesic distances.
You can inspect the evaluation plots generated by the `eval.py` script to see the improvement in similarity scores before and after training.
- `--model-name`: sentence-transformers base model.
- `--epochs`: training epochs.
- `--batch-size`: batch size.
- `--device`: device override such as `cpu`, `cuda`, or `mps`. Defaults to `cuda`.
After even just one epoch, we can see the model has learned to correlate our desired quantities:
## Outputs
![Evaluation plot](./progress_sample.png)
*The above plot is an example showing the relationship between geodesic distance and the similarity between the embedded vectors (1 = more similar), for 10,000 randomly selected pairs of US cities (re-sampled for each image).*
*Note the (vertical) "gap" we see in the image, corresponding to the size of the continental United States (~5,000 km)*
---
## Future Improvements
There are several potential improvements and extensions to the current model:
1. **Incorporate airport codes**: Train the model to understand the unique codes of airports, which could be useful for search engines and other applications.
2. **Add city aliases**: Enhance the dataset with city aliases, so the model can recognize different names for the same city. The `geonamescache` package already includes these.
3. **Include time zones**: Train the model to understand time zone differences between cities, which could be helpful for various time-sensitive use cases. The `geonamescache` package already includes this data, but how to calculate the hours between them is an open question.
4. **Expand to other distance metrics**: Adapt the model to consider other measures of distance, such as transportation infrastructure or travel time.
5. **Train on sentences**: Improve the model's performance on sentences by adding training and validation examples that involve city names in the context of sentences. Can use generative AI to create template sentences (mad-libs style) to create random and diverse training examples.
6. **Global city support**: Extend the model to support cities outside the US and cover a broader range of geographic locations.
# Notes
- Generating the data took about 10 minutes (for 3269 US cities, of which 2826 had unique names), running in parallel on 8 cores (Intel 9700K), and yielded 3,991,725 city pairs (about 150 MB).
- For cities with the same name, the one with the larger population is selected (had to make some sort of choice...).
- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 train/test split and batch size 16. With a batch size 16 times larger (256), each epoch took about 5-6 minutes.
- Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
- **WARNING**: _It is unclear how the model performs on sentences, as it was trained and evaluated only on word-pairs._ See improvement (5) above.
- `training.csv`: prepared bootstrapped dataset.
- `output/`: saved sentence-transformer encoder, coordinate head, and coordinate
normalization metadata.
- `predictions.csv`: evaluation rows with predicted coordinates and `error_km`.
- `plots/prediction_map.png`: actual vs predicted coordinates with line segments
showing the prediction error.
- `plots/predicted_vs_actual.png`: predicted vs actual latitude and longitude
scatter plots.

27
data_utils.py Normal file
View File

@ -0,0 +1,27 @@
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, train_test_split
DEFAULT_TEST_SIZE = 0.2
DEFAULT_SEED = 1992
def split_indices(data, test_size=DEFAULT_TEST_SIZE, seed=DEFAULT_SEED):
    """Compute positional train/validation indices for ``data``.

    When ``data`` has an ``intersection`` column the split is group-aware:
    all rows sharing an intersection value are placed in the same partition.
    Without that column a plain random split over row positions is used.

    Args:
        data: DataFrame whose rows are to be partitioned.
        test_size: Passed through to the scikit-learn splitter as ``test_size``.
        seed: Passed through to the scikit-learn splitter as ``random_state``.

    Returns:
        ``(train_indices, val_indices)`` as produced by the splitter.
    """
    positions = np.arange(len(data))

    # No grouping key available: fall back to an ordinary shuffle split.
    if "intersection" not in data.columns:
        train_idx, val_idx = train_test_split(
            positions, test_size=test_size, random_state=seed
        )
        return train_idx, val_idx

    group_splitter = GroupShuffleSplit(
        n_splits=1, test_size=test_size, random_state=seed
    )
    # n_splits=1, so the generator yields exactly one (train, val) pair.
    folds = group_splitter.split(positions, groups=data["intersection"])
    train_idx, val_idx = next(folds)
    return train_idx, val_idx

365
eval.py
View File

@ -1,79 +1,328 @@
import glob
import logging
import argparse
import json
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sentence_transformers import LoggingHandler, SentenceTransformer
os.environ.setdefault("MPLCONFIGDIR", "/private/tmp/matplotlib")
os.environ.setdefault("XDG_CACHE_HOME", "/private/tmp")
# from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
# from sklearn.model_selection import train_test_split
import matplotlib # noqa: E402
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
import torch # noqa: E402
from sentence_transformers import SentenceTransformer # noqa: E402
if not os.path.exists("./plots"):
os.mkdir("./plots")
from data_utils import split_indices # noqa: E402
from train import CoordinateRegressor, haversine_km # noqa: E402
# Configure logging
logging.basicConfig(
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
handlers=[LoggingHandler()],
)
matplotlib.use("Agg")
from matplotlib import pyplot as plt # noqa: E402
def evaluate(model, city_from, city_to):
city_to = model.encode(city_to)
city_from = model.encode(city_from)
return np.dot(city_to, city_from) / (
np.linalg.norm(city_to) * np.linalg.norm(city_from)
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``data_file``, ``model_path``, ``output_file``,
        ``plot_file``, ``scatter_plot_file``, ``device``, ``batch_size``,
        ``seed`` and ``test_size``.
    """
    cli = argparse.ArgumentParser(description="Evaluate sign coordinate model.")
    add = cli.add_argument

    # Input data / model and output artefact locations.
    add("--data-file", default="training.csv")
    add("--model-path", default="output")
    add("--output-file", default="predictions.csv")
    add("--plot-file", default="plots/prediction_map.png")
    add(
        "--scatter-plot-file",
        default="plots/predicted_vs_actual.png",
        help="2x1 scatter plot of predicted vs actual latitude and longitude.",
    )

    # Runtime settings.
    add(
        "--device",
        default="cuda",
        help="Device to use for evaluation. Defaults to `cuda`.",
    )
    add("--batch-size", type=int, default=64)

    # Split settings; they have to mirror what training used.
    add(
        "--seed",
        type=int,
        default=1992,
        help="Must match the seed used during training for a consistent split.",
    )
    add(
        "--test-size",
        type=float,
        default=0.2,
        help="Must match the test-size used during training.",
    )

    return cli.parse_args()
def get_device(requested_device):
    """Resolve the torch device to run on.

    Args:
        requested_device: Device string such as ``"cpu"`` or ``"cuda"``. Any
            falsy value (``None``, ``""``) triggers auto-detection.

    Returns:
        ``torch.device`` for the requested device, or, when nothing was
        requested, ``cuda`` if available and ``cpu`` otherwise.
    """
    if not requested_device:
        detected = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(detected)
    # An explicit request is honoured as-is; availability is not checked here.
    return torch.device(requested_device)
def load_head(model_path, filename, encoder, config, device):
    """Rebuild a ``CoordinateRegressor`` head and restore its saved weights.

    Args:
        model_path: Directory containing the saved head.
        filename: Name of the state-dict file inside ``model_path``.
        encoder: Encoder whose sentence-embedding dimension sizes the head input.
        config: Mapping providing ``hidden_dim`` and ``dropout``.
        device: Device the head is moved to and the weights are mapped onto.

    Returns:
        The head with weights loaded, switched to eval mode.
    """
    regressor = CoordinateRegressor(
        embedding_dim=encoder.get_sentence_embedding_dimension(),
        hidden_dim=config["hidden_dim"],
        dropout=config["dropout"],
    )
    regressor = regressor.to(device)

    weights_file = os.path.join(model_path, filename)
    saved_state = torch.load(weights_file, map_location=device)
    regressor.load_state_dict(saved_state)

    regressor.eval()
    return regressor
def load_model(model_path, device):
    """Load the trained model plus its pre-training baseline from ``model_path``.

    Reads ``coordinate_config.json``, the trained encoder and
    ``coordinate_head.pt``. For the baseline it prefers the saved
    ``initial_encoder`` directory and otherwise loads ``config["model_name"]``;
    the baseline head is only loaded when ``initial_coordinate_head.pt`` exists.

    Args:
        model_path: Directory written by training.
        device: Device all models are placed on.

    Returns:
        Tuple ``(base_encoder, initial_head, trained_encoder, trained_head,
        coord_mean, coord_std)``. ``initial_head`` is ``None`` when no initial
        head file is present; the last two entries are numpy arrays built from
        the config's ``coord_mean`` and ``coord_std``.
    """
    config_file = os.path.join(model_path, "coordinate_config.json")
    with open(config_file) as handle:
        config = json.load(handle)

    device_name = str(device)

    # Trained encoder + head.
    trained_encoder = SentenceTransformer(model_path, device=device_name)
    trained_encoder.to(device)
    trained_head = load_head(
        model_path, "coordinate_head.pt", trained_encoder, config, device
    )
    trained_encoder.eval()

    # Baseline encoder: saved snapshot if there is one, else the named base model.
    snapshot_dir = os.path.join(model_path, "initial_encoder")
    baseline_source = (
        snapshot_dir if os.path.exists(snapshot_dir) else config["model_name"]
    )
    base_encoder = SentenceTransformer(baseline_source, device=device_name)
    base_encoder.to(device)

    # Baseline head is optional.
    if os.path.exists(os.path.join(model_path, "initial_coordinate_head.pt")):
        initial_head = load_head(
            model_path, "initial_coordinate_head.pt", base_encoder, config, device
        )
    else:
        initial_head = None
    base_encoder.eval()

    coord_mean = np.array(config["coord_mean"])
    coord_std = np.array(config["coord_std"])
    return (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        coord_mean,
        coord_std,
    )
def calculate_similarity(data, base_model, trained_model):
# MAX_DISTANCE = 20_037.5
# data["distance"] /= MAX_DISTANCE
data["similarity_before"] = data.apply(
lambda x: evaluate(base_model, x["city_from"], x["city_to"]), axis=1
)
@torch.no_grad()
def predict(encoder, head, texts, coord_mean, coord_std, device, batch_size):
predictions = []
for start in range(0, len(texts), batch_size):
batch = texts[start : start + batch_size]
features = encoder.tokenize(batch)
features = {key: value.to(device) for key, value in features.items()}
embeddings = encoder(features)["sentence_embedding"]
batch_predictions = head(embeddings).cpu().numpy()
predictions.append(batch_predictions)
data["similarity_after"] = data.apply(
lambda x: evaluate(trained_model, x["city_from"], x["city_to"]), axis=1
)
return data
normalized = np.vstack(predictions)
return normalized * coord_std + coord_mean
def make_plot(data):
fig, ax = plt.subplots()
def make_prediction_plot(results, plot_file):
os.makedirs(os.path.dirname(plot_file), exist_ok=True)
fig, ax = plt.subplots(figsize=(8, 8))
for _, row in results.iterrows():
ax.plot(
[row["longitude"], row["predicted_longitude"]],
[row["latitude"], row["predicted_latitude"]],
color="0.75",
linewidth=0.4,
alpha=0.35,
zorder=1,
)
ax.scatter(
data["distance"],
data["similarity_before"],
color="r",
alpha=0.1,
label="before",
results["predicted_longitude"],
results["predicted_latitude"],
s=12,
color="#d62728",
alpha=0.45,
label="predicted",
zorder=2,
)
ax.scatter(
data["distance"], data["similarity_after"], color="b", alpha=0.1, label="after"
results["longitude"],
results["latitude"],
s=18,
color="#1f77b4",
alpha=0.75,
label="actual",
zorder=3,
)
ax.set_xlabel("distance between cities (km)")
ax.set_ylabel("similarity between vectors\n(cosine)")
ax.legend(loc="center right")
return fig
ax.set_xlabel("longitude")
ax.set_ylabel("latitude")
ax.set_title("Sign Text Coordinate Predictions")
ax.legend()
ax.set_aspect("equal", adjustable="box")
fig.tight_layout()
fig.savefig(plot_file, dpi=200)
plt.close(fig)
def make_predicted_vs_actual_plot(results, plot_file):
    """Save a 2x1 scatter figure of predicted vs actual latitude and longitude.

    Each panel plots actual values on x and predictions on y, with an identity
    line for reference. If ``results`` also has the
    ``initial_predicted_latitude`` / ``initial_predicted_longitude`` columns,
    those are drawn as a "before training" series.

    Args:
        results: DataFrame with ``latitude``, ``longitude``,
            ``predicted_latitude`` and ``predicted_longitude`` columns.
        plot_file: Output image path. Its parent directory is created when the
            path has one.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when plot_file actually has a directory component
    # (a bare filename such as "scatter.png" has none).
    plot_dir = os.path.dirname(plot_file)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(7, 9))
    plot_specs = [
        ("latitude", "initial_predicted_latitude", "predicted_latitude", "Latitude"),
        (
            "longitude",
            "initial_predicted_longitude",
            "predicted_longitude",
            "Longitude",
        ),
    ]

    for ax, (actual_col, initial_col, predicted_col, title) in zip(axes, plot_specs):
        actual = results[actual_col]
        # The "before training" columns are optional.
        initial = results[initial_col] if initial_col in results else None
        predicted = results[predicted_col]

        # Shared axis limits across every plotted series keep the identity
        # line on the diagonal of a square data range.
        values = [actual, predicted]
        if initial is not None:
            values.append(initial)
        lower = min(series.min() for series in values)
        upper = max(series.max() for series in values)

        if initial is not None:
            ax.scatter(
                actual,
                initial,
                s=14,
                alpha=0.45,
                color="#d62728",
                label="before training",
            )
        ax.scatter(
            actual,
            predicted,
            s=14,
            alpha=0.55,
            color="#1f77b4",
            label="after training",
        )
        ax.plot([lower, upper], [lower, upper], color="0.25", linewidth=1)
        ax.set_xlabel(f"actual {actual_col}")
        ax.set_ylabel(f"predicted {actual_col}")
        ax.set_title(title)
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        ax.legend()

    fig.tight_layout()
    fig.savefig(plot_file, dpi=200)
    plt.close(fig)
def main():
    """Evaluate a saved coordinate model and report error statistics.

    Loads the model from ``--model-path``, predicts coordinates for every row
    of ``--data-file``, writes the rows with predictions, ``error_km`` and a
    ``split`` label to ``--output-file``, renders both diagnostic plots, and
    prints mean/median/p90 of the ``haversine_km`` errors for the train and
    validation partitions. Per-intersection aggregated metrics are printed only
    when the data has an ``intersection`` column.
    """
    args = parse_args()
    device = get_device(args.device)
    (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        coord_mean,
        coord_std,
    ) = load_model(args.model_path, device)

    data = pd.read_csv(args.data_file).dropna(
        subset=["text", "latitude", "longitude"]
    )
    # Positional indices from split_indices must line up with row labels.
    data = data.reset_index(drop=True)
    texts = data["text"].astype(str).tolist()

    predicted = predict(
        trained_encoder,
        trained_head,
        texts,
        coord_mean,
        coord_std,
        device,
        args.batch_size,
    )
    true_coords = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
    errors = haversine_km(predicted, true_coords)

    results = data.copy()
    results["predicted_latitude"] = predicted[:, 0]
    results["predicted_longitude"] = predicted[:, 1]
    if initial_head is not None:
        # Baseline predictions from the untrained encoder/head, for the
        # "before training" series of the scatter plot.
        initial_predicted = predict(
            base_encoder,
            initial_head,
            texts,
            coord_mean,
            coord_std,
            device,
            args.batch_size,
        )
        results["initial_predicted_latitude"] = initial_predicted[:, 0]
        results["initial_predicted_longitude"] = initial_predicted[:, 1]
    results["error_km"] = errors

    # Reproduce the same train/val split used during training.
    _, val_indices = split_indices(data, test_size=args.test_size, seed=args.seed)
    val_mask = np.zeros(len(results), dtype=bool)
    val_mask[val_indices] = True
    results["split"] = np.where(val_mask, "val", "train")

    results.to_csv(args.output_file, index=False)
    make_prediction_plot(results, args.plot_file)
    make_predicted_vs_actual_plot(results, args.scatter_plot_file)
    print(f"Wrote predictions to {args.output_file}")
    print(f"Wrote plot to {args.plot_file}")
    print(f"Wrote scatter plot to {args.scatter_plot_file}")

    def _fmt(errs):
        """Format mean/median/p90 of an error array; tolerate an empty one."""
        if len(errs) == 0:
            return "n/a (no rows)"
        return (
            f"mean={np.mean(errs):.3f} "
            f"median={np.median(errs):.3f} "
            f"p90={np.percentile(errs, 90):.3f}"
        )

    def _intersection_agg_errors(df):
        """Average per-bag predictions to one coordinate per intersection."""
        agg = (
            df.groupby("intersection")
            .agg(
                pred_lat=("predicted_latitude", "mean"),
                pred_lon=("predicted_longitude", "mean"),
                true_lat=("latitude", "first"),
                true_lon=("longitude", "first"),
            )
            .reset_index()
        )
        return haversine_km(
            agg[["pred_lat", "pred_lon"]].to_numpy(dtype=np.float32),
            agg[["true_lat", "true_lon"]].to_numpy(dtype=np.float32),
        )

    val_df = results[results["split"] == "val"]
    train_df = results[results["split"] == "train"]

    val_bag_errors = val_df["error_km"].to_numpy()
    train_bag_errors = train_df["error_km"].to_numpy()

    print()
    print(f"[all bags ({len(errors):>6} rows)] {_fmt(errors)}")
    print(f"[train bags({len(train_bag_errors):>6} rows)] {_fmt(train_bag_errors)}")
    print(f"[val bags ({len(val_bag_errors):>6} rows)] {_fmt(val_bag_errors)}")

    # split_indices accepts data without an `intersection` column, so the
    # grouped metrics must be optional too instead of raising KeyError after
    # all outputs have already been written.
    if "intersection" not in results.columns:
        print("[agg] skipped: data has no `intersection` column")
        return

    val_agg_errors = _intersection_agg_errors(val_df)
    train_agg_errors = _intersection_agg_errors(train_df)
    n_val = val_df["intersection"].nunique()
    n_train = train_df["intersection"].nunique()
    print(f"[train agg ({n_train:>3} intersections)] {_fmt(train_agg_errors)}")
    print(f"[val agg ({n_val:>3} intersections)] {_fmt(val_agg_errors)} ← generalization")
if __name__ == "__main__":
model_name = "sentence-transformers/all-MiniLM-L6-v2"
base_model = SentenceTransformer(model_name, device="cuda")
data = pd.read_csv("city_distances.csv")
# data_sample = data.sample(1_000)
checkpoint_dir = "checkpoints" # no slash
for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
print(f"Evaluating {checkpoint}")
data_sample = data.sample(1_000)
trained_model = SentenceTransformer(checkpoint, device="cuda")
data_sample = calculate_similarity(data_sample, base_model, trained_model)
fig = make_plot(data_sample)
fig.savefig(f"./plots/progress_{checkpoint.split('/')[1]}.png", dpi=600)
main()

143
prepare_training_data.py Normal file
View File

@ -0,0 +1,143 @@
import argparse
import random
import pandas as pd
BASE_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
OPTIONAL_COLUMNS = ["code_type"]
EXCLUDED_INTERSECTIONS = {"56th-pena"}
def parse_args():
    """Parse the command-line options of the data-preparation script.

    Returns:
        argparse.Namespace with ``input_file``, ``output_file``, ``seed``,
        ``bag_size``, ``samples_per_intersection`` and ``separator``.
    """
    cli = argparse.ArgumentParser(
        description="Create bootstrapped sign-text bags for coordinate training."
    )
    add = cli.add_argument

    # File locations.
    add(
        "-i",
        "--input-file",
        default="training_data_raw.csv",
        help="Raw pandas-exported CSV.",
    )
    add(
        "-o",
        "--output-file",
        default="training.csv",
        help="Prepared training CSV.",
    )

    # Sampling controls.
    add(
        "--seed",
        type=int,
        default=1992,
        help="Random seed for deterministic bootstrapping.",
    )
    add(
        "--bag-size",
        type=int,
        default=8,
        help=(
            "Total number of sign texts per training row. Street signs fill "
            "slots first; remaining slots are randomly sampled from all signs."
        ),
    )
    add(
        "--samples-per-intersection",
        type=int,
        default=100,
        help="Bootstrap samples to create for each intersection.",
    )

    # Output text formatting.
    add(
        "--separator",
        default=" | ",
        help="Separator used when joining sampled sign texts.",
    )

    return cli.parse_args()
def load_raw_data(path):
    """Read the raw sign CSV and reduce it to clean, usable rows.

    Keeps the columns in ``BASE_COLUMNS`` plus whichever ``OPTIONAL_COLUMNS``
    are present, drops rows with a missing base value, strips whitespace from
    the sign text, and removes rows with empty text or an intersection listed
    in ``EXCLUDED_INTERSECTIONS``.

    Args:
        path: Location of the raw CSV.

    Returns:
        The filtered DataFrame.

    Raises:
        ValueError: If any column in ``BASE_COLUMNS`` is missing from the file.
    """
    frame = pd.read_csv(path)

    absent = sorted(set(BASE_COLUMNS).difference(frame.columns))
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    optional_present = [name for name in OPTIONAL_COLUMNS if name in frame.columns]
    frame = frame[BASE_COLUMNS + optional_present].copy()
    frame = frame.dropna(subset=BASE_COLUMNS)

    # Normalise the text before testing for emptiness so whitespace-only
    # entries are dropped as well.
    frame["text_on_sign_exact"] = frame["text_on_sign_exact"].astype(str).str.strip()
    has_text = frame["text_on_sign_exact"] != ""
    frame = frame[has_text]

    is_excluded = frame["intersection"].isin(EXCLUDED_INTERSECTIONS)
    return frame[~is_excluded]
def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separator):
    """Build bootstrapped bags of sign texts, one row per bag.

    For every intersection (processed in sorted order),
    ``samples_per_intersection`` bags of ``bag_size`` texts are produced. When
    a ``code_type`` column exists, texts of type ``street_sign`` occupy the
    first slots (up to ``bag_size``); the rest are drawn with replacement from
    all signs of the intersection. Each bag is shuffled and joined with
    ``separator``. The random state of a bag is ``seed`` plus the number of
    rows generated before it, which makes the output deterministic.

    Args:
        data: DataFrame with ``intersection``, ``text_on_sign_exact``,
            ``latitude``, ``longitude`` and optionally ``code_type``.
        bag_size: Number of texts per bag; at least 1.
        samples_per_intersection: Bags per intersection; at least 1.
        seed: Base random seed.
        separator: String placed between the texts of a bag.

    Returns:
        DataFrame with ``intersection``, ``sample_id``, ``text``, the mean
        ``latitude``/``longitude`` of the intersection, ``unique_sign_count``
        and ``raw_sign_count``.

    Raises:
        ValueError: If ``bag_size`` or ``samples_per_intersection`` is below 1.
    """
    if bag_size < 1:
        raise ValueError("--bag-size must be at least 1")
    if samples_per_intersection < 1:
        raise ValueError("--samples-per-intersection must be at least 1")

    typed = "code_type" in data.columns
    records = []

    for name, signs in data.groupby("intersection", sort=True):
        mean_lat = signs["latitude"].mean()
        mean_lon = signs["longitude"].mean()
        every_text = signs["text_on_sign_exact"].tolist()
        distinct_total = len(set(every_text))
        raw_total = len(every_text)

        # Street signs are the most geographically discriminative; always
        # include them first up to bag_size, then fill remaining slots randomly.
        street = []
        if typed:
            is_street = signs["code_type"] == "street_sign"
            street = signs[is_street]["text_on_sign_exact"].tolist()
        fixed_part = street[:bag_size]
        open_slots = bag_size - len(fixed_part)

        for sample_id in range(samples_per_intersection):
            state = seed + len(records)

            drawn = []
            if open_slots > 0:
                picked = signs.sample(n=open_slots, replace=True, random_state=state)
                drawn = picked["text_on_sign_exact"].tolist()

            bag = fixed_part + drawn
            # Shuffle so position in the sequence doesn't encode sign type.
            random.Random(state).shuffle(bag)

            records.append(
                {
                    "intersection": name,
                    "sample_id": sample_id,
                    "text": separator.join(bag),
                    "latitude": mean_lat,
                    "longitude": mean_lon,
                    "unique_sign_count": distinct_total,
                    "raw_sign_count": raw_total,
                }
            )

    return pd.DataFrame(records)
def main():
    """Run the preparation pipeline: load raw signs, bootstrap bags, write CSV.

    Reads the file given by ``--input-file``, builds the bootstrapped rows with
    the sampling options from the command line, writes them to
    ``--output-file`` without the index, and prints a one-line summary.
    """
    options = parse_args()
    raw = load_raw_data(options.input_file)

    bags = make_bootstrap_data(
        raw,
        bag_size=options.bag_size,
        samples_per_intersection=options.samples_per_intersection,
        seed=options.seed,
        separator=options.separator,
    )
    bags.to_csv(options.output_file, index=False)

    row_total = len(bags)
    intersection_total = raw["intersection"].nunique()
    print(
        f"Wrote {row_total:,} rows from "
        f"{intersection_total:,} intersections to {options.output_file}"
    )
if __name__ == "__main__":
main()

View File

@ -10,3 +10,5 @@ sentence-transformers==2.2.2
torch==2.0.0
torchaudio==2.0.1
torchvision==0.15.1
transformers==4.30.2
huggingface-hub==0.14.1

637
train.py
View File

@ -1,18 +1,17 @@
import argparse
import json
import logging
import os
import random
import numpy as np
import pandas as pd
from sentence_transformers import (
InputExample,
LoggingHandler,
SentenceTransformer,
losses,
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch
from sentence_transformers import LoggingHandler, SentenceTransformer
from data_utils import split_indices
from torch import nn
from torch.utils.data import DataLoader, Dataset
# Configure logging
logging.basicConfig(
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
@ -20,77 +19,567 @@ logging.basicConfig(
handlers=[LoggingHandler()],
)
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(model_name, device="cuda")
# num_examples = 10_000
# Perform train-test split
# Example fake data with right types (for testing)
# import faker
# fake = Faker()
# train_data = [
# (fake.city(), fake.city(), np.random.rand())
# for _ in range(num_examples)
# ]
data = pd.read_csv("city_distances.csv")
MAX_DISTANCE = 20_037.5 # global max distance
# MAX_DISTANCE = data["distance"].max() # about 5k
class SignCoordinateDataset(Dataset):
def __init__(self, texts, coordinates):
self.texts = list(texts)
self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
print(f"{MAX_DISTANCE=}")
train_data = [
(row["city_from"], row["city_to"], 1 - row["distance"] / MAX_DISTANCE)
for _, row in data.iterrows()
]
def __len__(self):
return len(self.texts)
np.random.seed(1992)
np.random.shuffle(train_data)
train_examples = examples = [
InputExample(texts=[city_from, city_to], label=dist)
for city_from, city_to, dist in train_data
]
def __getitem__(self, index):
return self.texts[index], self.coordinates[index]
train_examples, val_examples = train_test_split(
examples, test_size=0.2, random_state=21
)
# validation examples can be something like templated sentences
# that maintain the same distance as the cities (same context)
# should probably add training examples like that too if needed
BATCH_SIZE = 16 * 16
num_examples = len(train_examples)
steps_per_epoch = num_examples // BATCH_SIZE
print(f"\nHead of training data (size: {num_examples}):")
print(train_data[:10], "\n")
class EmbeddingCoordinateDataset(Dataset):
def __init__(self, embeddings, coordinates):
self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
# Create DataLoaders for train and validation datasets
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
def __len__(self):
return len(self.embeddings)
print("TRAINING")
# Configure the training arguments
training_args = {
"output_path": "./output",
# "evaluation_steps": steps_per_epoch, # already evaluates at the end of each epoch
"epochs": 10,
"warmup_steps": 500,
"optimizer_params": {"lr": 2e-5},
# "weight_decay": 0, # not sure if this helps but works fine without setting it.
"scheduler": "WarmupLinear",
"save_best_model": True,
"checkpoint_path": "./checkpoints",
"checkpoint_save_steps": steps_per_epoch,
"checkpoint_save_total_limit": 100,
}
print(f"TRAINING ARGUMENTS:\n {training_args}")
def __getitem__(self, index):
return self.embeddings[index], self.coordinates[index]
train_loss = losses.CosineSimilarityLoss(model)
# Create an evaluator for validation dataset
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
val_examples, write_csv=True
)
class CoordinateRegressor(nn.Module):
def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(embedding_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 2),
)
model.fit(
train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
**training_args,
)
def forward(self, embeddings):
return self.layers(embeddings)
def parse_args():
    """Define the training command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options (data/output locations, model
        name, device, optimisation hyper-parameters and encoder-freezing
        switches), each falling back to the default declared below.
    """
    cli = argparse.ArgumentParser(description="Train sign text to lat/lon model.")
    add = cli.add_argument

    # --- data, output and model selection ---------------------------------
    add("--data-file", default="training.csv")
    add("--output-path", default="output")
    add("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
    add(
        "--device",
        default="cuda",
        help="Device to use for training. Defaults to `cuda`.",
    )

    # --- run control ---------------------------------------------------------
    add("--seed", type=int, default=1992)
    add("--epochs", type=int, default=10)
    add("--batch-size", type=int, default=64)
    add(
        "--num-workers",
        type=int,
        default=2,
        help="DataLoader workers to prefetch batches while the GPU trains.",
    )
    add(
        "--save-every-epochs",
        type=int,
        default=5,
        help=(
            "Write the best checkpoint to disk every N epochs (and at the final "
            "epoch). Validation still runs every epoch."
        ),
    )

    # --- optimisation and head architecture -------------------------------
    add("--learning-rate", type=float, default=1e-4)
    add("--head-learning-rate", type=float, default=5e-2)
    add("--weight-decay", type=float, default=0.001)
    add("--test-size", type=float, default=0.2)
    add("--hidden-dim", type=int, default=256)
    add("--dropout", type=float, default=0.1)

    # --- encoder freezing switches ------------------------------------------
    add(
        "--freeze-encoder",
        action="store_true",
        help="Train only the coordinate head; keep the sentence encoder fixed.",
    )
    add(
        "--freeze-transformer-layers",
        type=int,
        default=0,
        help="Freeze the first N transformer layers in the sentence encoder.",
    )
    add(
        "--freeze-attention",
        action="store_true",
        help=(
            "Freeze self-attention parameters while leaving other encoder "
            "params trainable."
        ),
    )

    return cli.parse_args()
def get_device(requested_device):
    """Resolve the torch device to train on.

    Args:
        requested_device: device string such as ``"cuda"`` or ``"cpu"``; a
            falsy value (``None`` / ``""``) asks for auto-detection.

    Returns:
        torch.device: the requested device when one is given, otherwise CUDA
        if ``torch.cuda.is_available()`` and CPU if not.
    """
    if not requested_device:
        detected = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(detected)
    return torch.device(requested_device)
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and
            ``torch``; CUDA generators are seeded too when CUDA is available.
    """
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def normalize_coordinates(coordinates):
    """Standardize each coordinate column to zero mean and unit spread.

    Args:
        coordinates: 2-D NumPy array; statistics are taken along axis 0.

    Returns:
        tuple: ``(normalized, mean, std)`` where ``normalized`` is
        ``(coordinates - mean) / std``. A column whose standard deviation is
        exactly zero gets ``std`` replaced by 1.0 so the division is safe.
    """
    center = coordinates.mean(axis=0)
    spread = coordinates.std(axis=0)
    # Constant columns would otherwise divide by zero.
    constant_columns = spread == 0
    spread[constant_columns] = 1.0
    standardized = (coordinates - center) / spread
    return standardized, center, spread
def move_features_to_device(features, device, non_blocking=False):
    """Return a new dict with every feature tensor moved to ``device``.

    Args:
        features: mapping of feature name to tensor.
        device: target device passed to ``Tensor.to``.
        non_blocking: forwarded to ``Tensor.to`` for each value.

    Returns:
        dict: same keys as ``features``, values on ``device``.
    """
    relocated = {}
    for name, tensor in features.items():
        relocated[name] = tensor.to(device, non_blocking=non_blocking)
    return relocated
def move_batch_to_device(features, labels, device, pin_memory=False):
    """Move a tokenized batch and its labels to ``device``.

    When ``pin_memory`` is set and the target is a CUDA device, every tensor
    is pinned first and the transfer is issued with ``non_blocking=True``.

    Args:
        features: mapping of feature name to tensor.
        labels: label tensor for the batch.
        device: ``torch.device`` to move to.
        pin_memory: whether to pin and copy asynchronously on CUDA targets.

    Returns:
        tuple: ``(features_on_device, labels_on_device)``.
    """
    use_pinned = pin_memory and device.type == "cuda"
    if use_pinned:
        pinned = {}
        for name, tensor in features.items():
            pinned[name] = tensor.pin_memory()
        features = pinned
        labels = labels.pin_memory()
    features_on_device = move_features_to_device(
        features, device, non_blocking=use_pinned
    )
    labels_on_device = labels.to(device, non_blocking=use_pinned)
    return features_on_device, labels_on_device
def move_tensors_to_device(tensors, device, pin_memory=False):
    """Move a sequence of tensors to ``device`` and return them as a list.

    When ``pin_memory`` is set and the target is a CUDA device, each tensor is
    pinned first and the transfer is issued with ``non_blocking=True``.

    Args:
        tensors: iterable of tensors.
        device: ``torch.device`` to move to.
        pin_memory: whether to pin and copy asynchronously on CUDA targets.

    Returns:
        list: the tensors on ``device``, in the input order.
    """
    use_pinned = pin_memory and device.type == "cuda"
    if use_pinned:
        tensors = [item.pin_memory() for item in tensors]
    relocated = []
    for item in tensors:
        relocated.append(item.to(device, non_blocking=use_pinned))
    return relocated
def make_text_collate(tokenize):
    """Build a DataLoader ``collate_fn`` that tokenizes raw text batches.

    Args:
        tokenize: callable taking a list of strings and returning the
            tokenized features for that list.

    Returns:
        callable: ``collate(batch)`` that accepts ``(text, label_tensor)``
        pairs and returns ``(tokenize(texts), stacked_labels)``.
    """

    def collate(batch):
        # Transpose [(text, label), ...] into (texts, labels).
        text_column, label_column = zip(*batch)
        tokenized = tokenize(list(text_column))
        stacked_labels = torch.stack(label_column)
        return tokenized, stacked_labels

    return collate
def embedding_collate(batch):
    """Collate ``(embedding, label)`` tensor pairs into two stacked tensors.

    Args:
        batch: sequence of ``(embedding_tensor, label_tensor)`` pairs.

    Returns:
        tuple: ``(stacked_embeddings, stacked_labels)``.
    """
    embedding_column, label_column = zip(*batch)
    stacked_embeddings = torch.stack(embedding_column)
    stacked_labels = torch.stack(label_column)
    return stacked_embeddings, stacked_labels
def make_dataloader(dataset, batch_size, shuffle, collate_fn, num_workers, pin_memory):
    """Construct a ``DataLoader`` with the project's standard options.

    Args:
        dataset: dataset to iterate over.
        batch_size: number of samples per batch.
        shuffle: whether to reshuffle every epoch.
        collate_fn: function merging samples into a batch.
        num_workers: number of loader worker processes.
        pin_memory: forwarded to ``DataLoader(pin_memory=...)``.

    Returns:
        DataLoader: configured loader; ``persistent_workers=True`` is passed
        only when ``num_workers > 0``.
    """
    # The persistent_workers flag is only supplied when workers are requested.
    worker_options = {"persistent_workers": True} if num_workers > 0 else {}
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **worker_options,
    )
def copy_module_state(module):
    """Take an independent CPU snapshot of a module's ``state_dict``.

    ``Tensor.cpu()`` returns the *same* tensor when the data already lives on
    the CPU, so without an explicit ``clone()`` a snapshot of a CPU-resident
    module would alias its live parameters and be silently overwritten by
    later optimizer steps. Cloning makes the snapshot independent on every
    device.

    Args:
        module: the ``torch.nn.Module`` to snapshot.

    Returns:
        dict: state-dict keys mapped to detached CPU tensors that share no
        storage with ``module``.
    """
    return {
        key: value.detach().cpu().clone()
        for key, value in module.state_dict().items()
    }
def save_best_checkpoint(
    output_path, encoder, head, best_states, coord_mean, coord_std, args
):
    """Write the best-so-far weights to disk, leaving the live weights intact.

    The best weights are loaded into ``encoder`` / ``head`` temporarily so
    that ``save_model`` can serialize them, and the in-progress weights are
    put back afterwards.

    ``Module.state_dict()`` returns references to the live parameters, and
    ``load_state_dict`` copies into those parameters in place. The
    in-progress weights are therefore cloned before the best state is loaded;
    otherwise the "backup" would itself be overwritten and training would
    continue from the best weights instead of the current ones.

    Args:
        output_path: directory passed through to ``save_model``.
        encoder: sentence encoder module.
        head: coordinate regression head.
        best_states: dict with ``"encoder"`` and ``"head"`` state dicts.
        coord_mean: coordinate mean, passed through to ``save_model``.
        coord_std: coordinate std, passed through to ``save_model``.
        args: parsed CLI options, passed through to ``save_model``.
    """
    live_encoder_state = {
        key: value.detach().clone() for key, value in encoder.state_dict().items()
    }
    live_head_state = {
        key: value.detach().clone() for key, value in head.state_dict().items()
    }
    try:
        encoder.load_state_dict(best_states["encoder"])
        head.load_state_dict(best_states["head"])
        save_model(output_path, encoder, head, coord_mean, coord_std, args)
    finally:
        # Restore the in-progress weights even if loading or saving failed.
        encoder.load_state_dict(live_encoder_state)
        head.load_state_dict(live_head_state)
def should_save_checkpoint(epoch, total_epochs, save_every_epochs, pending_save):
    """Decide whether the best checkpoint should be written after this epoch.

    Args:
        epoch: 1-based index of the epoch that just finished.
        total_epochs: total number of epochs in the run.
        save_every_epochs: periodic save interval in epochs. A value of zero
            or less disables periodic saves (only the final epoch saves)
            instead of raising ``ZeroDivisionError`` from the modulo.
        pending_save: True when a new best state has not been written yet.

    Returns:
        bool: True when there is an unsaved best state and this is either the
        final epoch or a multiple of ``save_every_epochs``.
    """
    if not pending_save:
        return False
    if epoch == total_epochs:
        return True
    if save_every_epochs <= 0:
        return False
    return epoch % save_every_epochs == 0
@torch.no_grad()
def encode_texts(encoder, texts, batch_size, device):
    """Embed ``texts`` with ``encoder`` in batches, without gradients.

    The encoder is switched to eval mode before encoding.

    Args:
        encoder: object exposing ``eval()``, ``tokenize(list_of_str)`` and a
            call that returns a dict containing ``"sentence_embedding"``.
        texts: sequence of strings to embed.
        batch_size: number of texts per forward pass.
        device: device the tokenized features are moved to.

    Returns:
        numpy.ndarray: the per-batch embeddings stacked vertically.
    """
    encoder.eval()
    chunks = []
    for offset in range(0, len(texts), batch_size):
        tokenized = encoder.tokenize(texts[offset : offset + batch_size])
        on_device = {}
        for name, tensor in tokenized.items():
            on_device[name] = tensor.to(device)
        sentence_vectors = encoder(on_device)["sentence_embedding"]
        chunks.append(sentence_vectors.cpu().numpy())
    return np.vstack(chunks)
def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
    """Run one training epoch of the regression head on cached embeddings.

    Args:
        head: regression head, put into train mode.
        dataloader: yields ``(embeddings, labels)`` batches.
        optimizer: optimizer stepping the head's parameters.
        loss_fn: callable ``loss_fn(predictions, labels)``.
        device: device the batches are moved to.

    Returns:
        float: sum of per-batch loss weighted by batch size, divided by
        ``len(dataloader.dataset)``.
    """
    head.train()
    use_pinned = dataloader.pin_memory
    weighted_loss_sum = 0.0
    for batch_embeddings, batch_labels in dataloader:
        batch_embeddings, batch_labels = move_tensors_to_device(
            [batch_embeddings, batch_labels], device, pin_memory=use_pinned
        )
        optimizer.zero_grad(set_to_none=True)
        batch_loss = loss_fn(head(batch_embeddings), batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-counted.
        weighted_loss_sum += batch_loss.item() * batch_labels.size(0)
    return weighted_loss_sum / len(dataloader.dataset)
@torch.no_grad()
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate the regression head on cached embeddings.

    Args:
        head: regression head, put into eval mode.
        dataloader: yields ``(embeddings, labels)`` batches of normalized
            coordinates.
        loss_fn: callable ``loss_fn(predictions, labels)``.
        coord_mean: mean used to undo coordinate normalization.
        coord_std: std used to undo coordinate normalization.
        device: device the batches are moved to.

    Returns:
        tuple: ``(mean_loss, mean_error_km)`` where the loss is the
        size-weighted batch loss over ``len(dataloader.dataset)`` and the
        error is the mean ``haversine_km`` distance after de-normalization.
    """
    head.eval()
    use_pinned = dataloader.pin_memory
    weighted_loss_sum = 0.0
    prediction_batches = []
    label_batches = []
    for batch_embeddings, batch_labels in dataloader:
        batch_embeddings, batch_labels = move_tensors_to_device(
            [batch_embeddings, batch_labels], device, pin_memory=use_pinned
        )
        batch_predictions = head(batch_embeddings)
        batch_loss = loss_fn(batch_predictions, batch_labels)
        weighted_loss_sum += batch_loss.item() * batch_labels.size(0)
        prediction_batches.append(batch_predictions)
        label_batches.append(batch_labels)

    # Undo the normalization so distances are computed on real coordinates.
    normalized_predictions = torch.cat(prediction_batches).float().cpu().numpy()
    normalized_labels = torch.cat(label_batches).float().cpu().numpy()
    pred_coords = normalized_predictions * coord_std + coord_mean
    true_coords = normalized_labels * coord_std + coord_mean
    mean_error_km = float(np.mean(haversine_km(pred_coords, true_coords)))
    return weighted_loss_sum / len(dataloader.dataset), mean_error_km
def train_epoch(
    encoder,
    head,
    dataloader,
    optimizer,
    loss_fn,
    device,
    encoder_trainable,
):
    """Run one end-to-end training epoch over tokenized text batches.

    Args:
        encoder: sentence encoder; put in train mode only when
            ``encoder_trainable`` is truthy, otherwise in eval mode.
        head: regression head, put into train mode.
        dataloader: yields ``(features, labels)`` batches.
        optimizer: optimizer stepping the trainable parameters.
        loss_fn: callable ``loss_fn(predictions, labels)``.
        device: device the batches are moved to.
        encoder_trainable: when falsy the encoder forward pass runs under
            ``torch.no_grad()``.

    Returns:
        float: sum of per-batch loss weighted by batch size, divided by
        ``len(dataloader.dataset)``.
    """
    if not encoder_trainable:
        encoder.eval()
    else:
        encoder.train()
    head.train()

    use_pinned = dataloader.pin_memory
    weighted_loss_sum = 0.0
    for batch_features, batch_labels in dataloader:
        batch_features, batch_labels = move_batch_to_device(
            batch_features, batch_labels, device, pin_memory=use_pinned
        )
        optimizer.zero_grad(set_to_none=True)
        if not encoder_trainable:
            # No graph is needed through a frozen encoder.
            with torch.no_grad():
                sentence_vectors = encoder(batch_features)["sentence_embedding"]
        else:
            sentence_vectors = encoder(batch_features)["sentence_embedding"]
        batch_loss = loss_fn(head(sentence_vectors), batch_labels)
        batch_loss.backward()
        optimizer.step()
        weighted_loss_sum += batch_loss.item() * batch_labels.size(0)
    return weighted_loss_sum / len(dataloader.dataset)
@torch.no_grad()
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate encoder + head on tokenized text batches.

    Args:
        encoder: sentence encoder, put into eval mode.
        head: regression head, put into eval mode.
        dataloader: yields ``(features, labels)`` batches of normalized
            coordinates.
        loss_fn: callable ``loss_fn(predictions, labels)``.
        coord_mean: mean used to undo coordinate normalization.
        coord_std: std used to undo coordinate normalization.
        device: device the batches are moved to.

    Returns:
        tuple: ``(mean_loss, mean_error_km)`` where the loss is the
        size-weighted batch loss over ``len(dataloader.dataset)`` and the
        error is the mean ``haversine_km`` distance after de-normalization.
    """
    encoder.eval()
    head.eval()
    use_pinned = dataloader.pin_memory
    weighted_loss_sum = 0.0
    prediction_batches = []
    label_batches = []
    for batch_features, batch_labels in dataloader:
        batch_features, batch_labels = move_batch_to_device(
            batch_features, batch_labels, device, pin_memory=use_pinned
        )
        sentence_vectors = encoder(batch_features)["sentence_embedding"]
        batch_predictions = head(sentence_vectors)
        batch_loss = loss_fn(batch_predictions, batch_labels)
        weighted_loss_sum += batch_loss.item() * batch_labels.size(0)
        prediction_batches.append(batch_predictions)
        label_batches.append(batch_labels)

    # Undo the normalization so distances are computed on real coordinates.
    normalized_predictions = torch.cat(prediction_batches).float().cpu().numpy()
    normalized_labels = torch.cat(label_batches).float().cpu().numpy()
    pred_coords = normalized_predictions * coord_std + coord_mean
    true_coords = normalized_labels * coord_std + coord_mean
    mean_error_km = float(np.mean(haversine_km(pred_coords, true_coords)))
    return weighted_loss_sum / len(dataloader.dataset), mean_error_km
def haversine_km(pred_coords, true_coords):
    """Great-circle distance in kilometres between paired coordinates.

    Args:
        pred_coords: array of shape ``(n, 2)`` holding ``[latitude, longitude]``
            in degrees.
        true_coords: array of the same shape holding the reference points.

    Returns:
        numpy.ndarray: ``n`` distances in km, using a mean Earth radius of
        6371.0088 km.
    """
    lat1 = np.radians(pred_coords[:, 0])
    lon1 = np.radians(pred_coords[:, 1])
    lat2 = np.radians(true_coords[:, 0])
    lon2 = np.radians(true_coords[:, 1])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    # Rounding (notably in float32) can push `a` marginally above 1 for
    # near-antipodal points; sqrt/arcsin would then yield NaN and poison any
    # mean taken over the result. Clamp to the mathematically valid range.
    a = np.clip(a, 0.0, 1.0)
    return 2 * 6371.0088 * np.arcsin(np.sqrt(a))
def save_model(output_path, encoder, head, coord_mean, coord_std, args):
    """Persist the encoder, the head weights and the coordinate metadata.

    Writes into ``output_path`` (created if missing): the encoder via
    ``encoder.save``, ``coordinate_head.pt`` with the head's state dict, and
    ``coordinate_config.json`` with the normalization statistics and head
    hyper-parameters.

    Args:
        output_path: destination directory.
        encoder: object exposing ``save(path)``.
        head: module whose ``state_dict()`` is serialized.
        coord_mean: array-like with ``tolist()``; coordinate mean.
        coord_std: array-like with ``tolist()``; coordinate std.
        args: options providing ``hidden_dim``, ``dropout`` and ``model_name``.
    """
    head_file = os.path.join(output_path, "coordinate_head.pt")
    config_file = os.path.join(output_path, "coordinate_config.json")

    os.makedirs(output_path, exist_ok=True)
    encoder.save(output_path)
    torch.save(head.state_dict(), head_file)

    metadata = {}
    metadata["coord_mean"] = coord_mean.tolist()
    metadata["coord_std"] = coord_std.tolist()
    metadata["hidden_dim"] = args.hidden_dim
    metadata["dropout"] = args.dropout
    metadata["model_name"] = args.model_name
    with open(config_file, "w") as handle:
        json.dump(metadata, handle, indent=2)
def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
    """Persist the untrained encoder/head plus the coordinate metadata.

    Writes into ``output_path`` (created if missing): the encoder under the
    ``initial_encoder`` sub-directory, ``initial_coordinate_head.pt`` with the
    head's state dict, and ``coordinate_config.json`` with the normalization
    statistics and head hyper-parameters.

    Args:
        output_path: destination directory.
        encoder: object exposing ``save(path)``.
        head: module whose ``state_dict()`` is serialized.
        coord_mean: array-like with ``tolist()``; coordinate mean.
        coord_std: array-like with ``tolist()``; coordinate std.
        args: options providing ``hidden_dim``, ``dropout`` and ``model_name``.
    """
    encoder_dir = os.path.join(output_path, "initial_encoder")
    head_file = os.path.join(output_path, "initial_coordinate_head.pt")
    config_file = os.path.join(output_path, "coordinate_config.json")

    os.makedirs(output_path, exist_ok=True)
    encoder.save(encoder_dir)
    torch.save(head.state_dict(), head_file)

    metadata = {}
    metadata["coord_mean"] = coord_mean.tolist()
    metadata["coord_std"] = coord_std.tolist()
    metadata["hidden_dim"] = args.hidden_dim
    metadata["dropout"] = args.dropout
    metadata["model_name"] = args.model_name
    with open(config_file, "w") as handle:
        json.dump(metadata, handle, indent=2)
def freeze_encoder_parts(encoder, args):
    """Disable gradients on the parts of the encoder selected by ``args``.

    Args:
        encoder: indexable module whose first element exposes ``auto_model``
            (the underlying transformer).
        args: options providing ``freeze_encoder``,
            ``freeze_transformer_layers`` and ``freeze_attention``.

    ``freeze_encoder`` freezes every parameter and takes precedence over the
    other two switches. Otherwise the first ``freeze_transformer_layers``
    entries of ``transformer.encoder.layer`` and/or every parameter whose
    name contains ``.attention.`` (or starts with ``attention.``) are frozen.
    """

    def _freeze(parameters):
        for parameter in parameters:
            parameter.requires_grad = False

    if args.freeze_encoder:
        _freeze(encoder.parameters())
        return

    transformer = encoder[0].auto_model

    frozen_layer_count = args.freeze_transformer_layers
    if frozen_layer_count > 0:
        for layer in transformer.encoder.layer[:frozen_layer_count]:
            _freeze(layer.parameters())

    if args.freeze_attention:
        _freeze(
            parameter
            for name, parameter in transformer.named_parameters()
            if ".attention." in name or name.startswith("attention.")
        )
def count_trainable_parameters(module):
    """Count a module's trainable and total scalar parameters.

    Args:
        module: ``torch.nn.Module`` to inspect.

    Returns:
        tuple: ``(trainable, total)`` element counts, where ``trainable``
        covers only parameters with ``requires_grad`` set.
    """
    trainable = 0
    total = 0
    for parameter in module.parameters():
        size = parameter.numel()
        total += size
        if parameter.requires_grad:
            trainable += size
    return trainable, total
def make_optimizer(encoder, head, args):
    """Create an AdamW optimizer with separate encoder and head learning rates.

    Args:
        encoder: sentence encoder; only parameters with ``requires_grad`` are
            optimized, at ``args.learning_rate``. The group is omitted when
            the encoder has no trainable parameters.
        head: regression head; all parameters are optimized at
            ``args.head_learning_rate`` and form the last parameter group.
        args: options providing ``learning_rate``, ``head_learning_rate`` and
            ``weight_decay``.

    Returns:
        torch.optim.AdamW: optimizer over the assembled parameter groups.
    """
    trainable_encoder_params = [
        parameter for parameter in encoder.parameters() if parameter.requires_grad
    ]
    head_group = {"params": head.parameters(), "lr": args.head_learning_rate}
    if trainable_encoder_params:
        encoder_group = {"params": trainable_encoder_params, "lr": args.learning_rate}
        groups = [encoder_group, head_group]
    else:
        groups = [head_group]
    return torch.optim.AdamW(groups, weight_decay=args.weight_decay)
def main():
    """Train the sign-text -> (latitude, longitude) regressor end to end.

    Parses the CLI options, loads and normalizes the data, builds the
    sentence encoder and the regression head, then runs the epoch loop while
    tracking the lowest validation loss and periodically writing that best
    state to ``args.output_path``.
    """
    args = parse_args()
    set_seed(args.seed)
    device = get_device(args.device)
    # Pinned host memory is only requested when the target device is CUDA.
    pin_memory = device.type == "cuda"
    print(f"Using device: {device}")
    data = pd.read_csv(args.data_file)
    # NOTE(review): rows are dropped here without resetting the index, while
    # `texts` and the coordinate arrays below are indexed positionally.
    # Confirm that split_indices() returns positions, not index labels.
    data = data.dropna(subset=["text", "latitude", "longitude"])
    texts = data["text"].astype(str).tolist()
    coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
    # Normalization statistics are computed over all rows, before the split.
    normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)
    train_indices, val_indices = split_indices(
        data, test_size=args.test_size, seed=args.seed
    )
    train_dataset = SignCoordinateDataset(
        [texts[i] for i in train_indices], normalized_coordinates[train_indices]
    )
    val_dataset = SignCoordinateDataset(
        [texts[i] for i in val_indices], normalized_coordinates[val_indices]
    )
    encoder = SentenceTransformer(args.model_name, device=str(device))
    encoder.to(device)
    embedding_dim = encoder.get_sentence_embedding_dimension()
    head = CoordinateRegressor(
        embedding_dim=embedding_dim,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
    ).to(device)
    # Snapshot the untrained encoder/head before any freezing or training.
    save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
    freeze_encoder_parts(encoder, args)
    encoder_trainable, encoder_total = count_trainable_parameters(encoder)
    head_trainable, head_total = count_trainable_parameters(head)
    print(
        f"Trainable encoder params: {encoder_trainable:,}/{encoder_total:,}; "
        f"head params: {head_trainable:,}/{head_total:,}"
    )
    if encoder_trainable == 0:
        # Fully frozen encoder: embed every text once up front and train the
        # head on the cached embeddings instead of re-encoding each epoch.
        print("Caching frozen encoder embeddings...")
        all_embeddings = encode_texts(encoder, texts, args.batch_size, device)
        train_dataset = EmbeddingCoordinateDataset(
            all_embeddings[train_indices], normalized_coordinates[train_indices]
        )
        val_dataset = EmbeddingCoordinateDataset(
            all_embeddings[val_indices], normalized_coordinates[val_indices]
        )
        train_loader = make_dataloader(
            train_dataset,
            args.batch_size,
            shuffle=True,
            collate_fn=embedding_collate,
            num_workers=args.num_workers,
            pin_memory=pin_memory,
        )
        val_loader = make_dataloader(
            val_dataset,
            args.batch_size,
            shuffle=False,
            collate_fn=embedding_collate,
            num_workers=args.num_workers,
            pin_memory=pin_memory,
        )
    else:
        # Trainable encoder: batches carry raw text, tokenized in the collate
        # function built from the encoder's tokenizer.
        text_collate = make_text_collate(encoder.tokenize)
        train_loader = make_dataloader(
            train_dataset,
            args.batch_size,
            shuffle=True,
            collate_fn=text_collate,
            num_workers=args.num_workers,
            pin_memory=pin_memory,
        )
        val_loader = make_dataloader(
            val_dataset,
            args.batch_size,
            shuffle=False,
            collate_fn=text_collate,
            num_workers=args.num_workers,
            pin_memory=pin_memory,
        )
    optimizer = make_optimizer(encoder, head, args)
    loss_fn = nn.HuberLoss()
    # Halve the learning rate after 5 epochs without validation improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=5, factor=0.5, min_lr=1e-7
    )
    best_val_loss = float("inf")
    best_states = None
    # True while a new best state is held in memory but not yet on disk.
    pending_save = False
    print(
        f"Training on {len(train_dataset):,} rows; "
        f"validating on {len(val_dataset):,} rows; "
        f"batch_size={args.batch_size}; num_workers={args.num_workers}"
    )
    for epoch in range(1, args.epochs + 1):
        if encoder_trainable == 0:
            train_loss = train_head_epoch(
                head, train_loader, optimizer, loss_fn, device
            )
            val_loss, val_error_km = evaluate_head(
                head, val_loader, loss_fn, coord_mean, coord_std, device
            )
        else:
            train_loss = train_epoch(
                encoder,
                head,
                train_loader,
                optimizer,
                loss_fn,
                device,
                encoder_trainable > 0,
            )
            val_loss, val_error_km = evaluate(
                encoder,
                head,
                val_loader,
                loss_fn,
                coord_mean,
                coord_std,
                device,
            )
        scheduler.step(val_loss)
        # Report the last parameter group's rate (the head group, as built by
        # make_optimizer).
        current_lr = optimizer.param_groups[-1]["lr"]
        print(
            f"epoch={epoch} train_loss={train_loss:.6f} "
            f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f} "
            f"lr={current_lr:.2e}"
        )
        if val_loss < best_val_loss:
            # New best: keep a copy of both modules' weights in memory.
            best_val_loss = val_loss
            best_states = {
                "encoder": copy_module_state(encoder),
                "head": copy_module_state(head),
            }
            pending_save = True
        if should_save_checkpoint(
            epoch, args.epochs, args.save_every_epochs, pending_save
        ):
            save_best_checkpoint(
                args.output_path,
                encoder,
                head,
                best_states,
                coord_mean,
                coord_std,
                args,
            )
            pending_save = False
            print(
                f"Saved best model to {args.output_path} "
                f"(val_loss={best_val_loss:.6f})"
            )
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()