first train/eval
This commit is contained in:
parent
6c5d71e2d9
commit
e8419354f4
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
|||||||
|
3.10
|
||||||
52
Makefile
52
Makefile
@ -1,46 +1,28 @@
|
|||||||
all: install data train eval
|
all: data train eval
|
||||||
|
|
||||||
city_distances.csv: generate_data.py
|
training.csv: prepare_training_data.py training_data_raw.csv
|
||||||
@echo "Generating distance data..."
|
@echo "Preparing bootstrapped sign text bags..."
|
||||||
@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
|
@bash -c 'source .venv/bin/activate && python prepare_training_data.py'
|
||||||
@echo "Calculating range of generated data..."
|
|
||||||
@cat city_distances.csv | tail -n +2 | sort -t',' -k3n | head -n1
|
|
||||||
@cat city_distances.csv | tail -n +2 | sort -t',' -k3nr | head -n1
|
|
||||||
|
|
||||||
data: city_distances.csv
|
data: training.csv
|
||||||
|
|
||||||
train: check train.py
|
train: train.py training.csv
|
||||||
@echo "Training embeddings..."
|
@echo "Training coordinate regressor..."
|
||||||
@bash -c 'time python train.py'
|
@bash -c 'source .venv/bin/activate && python train.py'
|
||||||
|
|
||||||
eval: check eval.py
|
eval: eval.py training.csv
|
||||||
@echo "Evaluating trained model..."
|
@echo "Evaluating coordinate regressor..."
|
||||||
@bash -c 'time python eval.py'
|
@bash -c 'source .venv/bin/activate && python eval.py'
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
@echo "Auto-linting files and performing final style checks..."
|
@echo "Auto-linting files and performing final style checks..."
|
||||||
@isort --profile=black .
|
@bash -c 'source .venv/bin/activate && isort --profile=black *.py'
|
||||||
@black .
|
@bash -c 'source .venv/bin/activate && black *.py'
|
||||||
@flake8 --max-line-length=88 --ignore E203 .
|
@bash -c 'source .venv/bin/activate && flake8 --max-line-length=88 --ignore E203 *.py'
|
||||||
|
|
||||||
check: lint
|
|
||||||
@echo "Checking for unstaged or untracked changes..."
|
|
||||||
@git diff-index --quiet HEAD -- || { echo "Unstaged or untracked changes detected!"; exit 1; }
|
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
@echo "Removing outputs/ and checkpoints/ directories"
|
@echo "Removing generated outputs"
|
||||||
@rm -rf output/
|
@rm -rf output/
|
||||||
@rm -rf checkpoints/
|
@rm -f training.csv predictions.csv
|
||||||
|
|
||||||
compress: plots/progress_12474_sm.png
|
.PHONY: data train eval lint clean all
|
||||||
|
|
||||||
plots/progress_12474_sm.png: plots/progress_12474.png
|
|
||||||
@convert -resize 33% plots/progress_12474.png progress_sample.png
|
|
||||||
|
|
||||||
install: .requirements_installed
|
|
||||||
|
|
||||||
.requirements_installed: requirements.txt
|
|
||||||
pip install -r requirements.txt
|
|
||||||
@echo "installed requirements" > .requirements_installed
|
|
||||||
|
|
||||||
.PHONY: data train eval lint check clean all
|
|
||||||
|
|||||||
120
README.md
120
README.md
@ -1,89 +1,83 @@
|
|||||||
# CityBert
|
# Sign Coordinate Regressor
|
||||||
|
|
||||||
CityBert is a machine learning project that fine-tunes a neural network model to understand the similarity between cities based on their geodesic distances.
|
This project fine-tunes a sentence-transformer model to predict a latitude and
|
||||||
|
longitude pair from text observed on signs at an intersection.
|
||||||
|
|
||||||
The project generates a dataset of US cities and their pair-wise geodesic distances, which are then used to train the model.
|
The raw dataset is expected at `training_data_raw.csv`. It may include the
|
||||||
|
pandas-exported index column; the preprocessing script only uses:
|
||||||
|
|
||||||
The project can be extended to include other distance metrics or additional data, such as airport codes, city aliases, or time zones.
|
- `intersection`
|
||||||
|
- `text_on_sign_exact`
|
||||||
|
- `latitude`
|
||||||
|
- `longitude`
|
||||||
|
|
||||||
> Note that this model only considers geographic distances and does not take into account other factors such as political borders or transportation infrastructure.
|
## Workflow
|
||||||
These factors contribute to a sense of "distance as it pertains to travel difficulty," which is not directly reflected by this model.
|
|
||||||
|
|
||||||
## But Why?
|
Prepare bootstrapped training rows:
|
||||||
|
|
||||||
### Demonstrate Flexibility
|
```bash
|
||||||
This project showcases how pre-trained language models can be fine-tuned to understand geographic relationships between cities.
|
source .venv/bin/activate
|
||||||
|
python prepare_training_data.py --seed 1992 --bag-size 5 --samples-per-intersection 50
|
||||||
|
```
|
||||||
|
|
||||||
### Contribute to the Community
|
This writes `training.csv` with rows shaped like:
|
||||||
Out-of-the-box neural network models struggle to grasp the spatial relationships, cultural connections, and underlying patterns between cities that humans intuitively understand. Training a specialized model bridges this gap, allowing the network to capture complex relationships and better comprehend geographic data.
|
|
||||||
|
|
||||||
By training a model on pairs of city names and geodesic distances, we enhance its ability to infer city similarity based on names alone. This is beneficial in applications like search engines, recommendation systems, or other natural language processing tasks involving geographic context.
|
```text
|
||||||
|
intersection,sample_id,text,latitude,longitude,unique_sign_count,raw_sign_count
|
||||||
|
```
|
||||||
|
|
||||||
This model will be published for public use, and the code can be adapted for other specialized use-cases.
|
Each row is a deterministic bootstrap sample of sign texts from one
|
||||||
|
intersection, joined into a single text field. This trains on "some signs seen
|
||||||
|
at this coordinate" instead of one sign or every sign at that coordinate.
|
||||||
|
|
||||||
### Explore Tradeoffs
|
Train the model:
|
||||||
Using a neural network to understand geographic relationships provides a more robust and flexible representation compared to traditional latitude/longitude lookups. It can capture complex patterns and relationships in natural language tasks that may be difficult to model using traditional methods.
|
|
||||||
|
|
||||||
Although there's an initial computational overhead in training the model, benefits include handling various location-based queries and better handling of aliases and alternate city names.
|
```bash
|
||||||
|
python train.py
|
||||||
|
```
|
||||||
|
|
||||||
This tradeoff allows more efficient and context-aware processing of location-based information, making it valuable in specific applications. In scenarios requiring high precision or quick solutions, traditional methods may still be more suitable.
|
Evaluate and write predictions:
|
||||||
|
|
||||||
Ultimately, the efficiency of a neural network compared to traditional methods depends on the specific problem and desired trade-offs between accuracy, efficiency, and flexibility.
|
```bash
|
||||||
|
python eval.py
|
||||||
|
```
|
||||||
|
|
||||||
### Applicability to Other Tasks
|
This writes `predictions.csv` and a map-style diagnostic plot:
|
||||||
The approach demonstrated can be extended to other metrics or features beyond geographic distance. By adapting dataset generation and fine-tuning processes, models can be trained to learn various relationships and similarities between different entities.
|
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
## Overview of Project Files
|
It also writes a coordinate calibration plot:
|
||||||
|
|
||||||
- `generate_data.py`: Generates a dataset of US cities and their pairwise geodesic distances.
|

|
||||||
- `train.py`: Trains the neural network model using the generated dataset.
|
|
||||||
- `eval.py`: Evaluates the trained model by comparing the similarity between city vectors before and after training.
|
|
||||||
- `Makefile`: Automates the execution of various tasks, such as generating data, training, and evaluation.
|
|
||||||
- `README.md`: Provides a description of the project, instructions on how to use it, and expected results.
|
|
||||||
- `requirements.txt`: Defines requirements used for creating the results.
|
|
||||||
|
|
||||||
|
Or run the full pipeline:
|
||||||
|
|
||||||
## How to Use
|
```bash
|
||||||
|
make
|
||||||
|
```
|
||||||
|
|
||||||
1. Install the required dependencies by running `pip install -r requirements.txt`.
|
## Useful Options
|
||||||
2. Run `make data` to generate the dataset of city distances. This will create `city_distances.csv` by default only with US cities (for now).
|
|
||||||
3. Run `make train` to train the neural network model.
|
|
||||||
4. Run `make eval` to evaluate the trained model and generate evaluation plots.
|
|
||||||
|
|
||||||
**You can also just run `make` (i.e., `make all`) which will run through all of those steps.**
|
`prepare_training_data.py`:
|
||||||
|
|
||||||
|
- `--seed`: deterministic bootstrap seed.
|
||||||
|
- `--bag-size`: number of sign texts per sampled bag.
|
||||||
|
- `--samples-per-intersection`: number of bags generated per intersection.
|
||||||
|
|
||||||
## What to Expect
|
`train.py`:
|
||||||
|
|
||||||
After training, the model should be able to understand the similarity between cities based on their geodesic distances.
|
- `--model-name`: sentence-transformers base model.
|
||||||
You can inspect the evaluation plots generated by the `eval.py` script to see the improvement in similarity scores before and after training.
|
- `--epochs`: training epochs.
|
||||||
|
- `--batch-size`: batch size.
|
||||||
|
- `--device`: explicit device such as `cpu`, `cuda`, or `mps`.
|
||||||
|
|
||||||
After even just one epoch, we can see the model has learned to correlate our desired quantities:
|
## Outputs
|
||||||
|
|
||||||

|
- `training.csv`: prepared bootstrapped dataset.
|
||||||
|
- `output/`: saved sentence-transformer encoder, coordinate head, and coordinate
|
||||||
*The above plot is an example showing the relationship between geodesic distance and the similarity between the embedded vectors (1 = more similar), for 10,000 randomly selected pairs of US cities (re-sampled for each image).*
|
normalization metadata.
|
||||||
|
- `predictions.csv`: evaluation rows with predicted coordinates and `error_km`.
|
||||||
*Note the (vertical) "gap" we see in the image, corresponding to the size of the continental United States (~5,000 km)*
|
- `plots/prediction_map.png`: actual vs predicted coordinates with line segments
|
||||||
|
showing the prediction error.
|
||||||
---
|
- `plots/predicted_vs_actual.png`: predicted vs actual latitude and longitude
|
||||||
|
scatter plots.
|
||||||
## Future Improvements
|
|
||||||
|
|
||||||
There are several potential improvements and extensions to the current model:
|
|
||||||
|
|
||||||
1. **Incorporate airport codes**: Train the model to understand the unique codes of airports, which could be useful for search engines and other applications.
|
|
||||||
2. **Add city aliases**: Enhance the dataset with city aliases, so the model can recognize different names for the same city. The `geonamescache` package already includes these.
|
|
||||||
3. **Include time zones**: Train the model to understand time zone differences between cities, which could be helpful for various time-sensitive use cases. The `geonamescache` package already includes this data, but how to calculate the hours between them is an open question.
|
|
||||||
4. **Expand to other distance metrics**: Adapt the model to consider other measures of distance, such as transportation infrastructure or travel time.
|
|
||||||
5. **Train on sentences**: Improve the model's performance on sentences by adding training and validation examples that involve city names in the context of sentences. Can use generative AI to create template sentences (mad-libs style) to create random and diverse training examples.
|
|
||||||
6. **Global city support**: Extend the model to support cities outside the US and cover a broader range of geographic locations.
|
|
||||||
|
|
||||||
|
|
||||||
# Notes
|
|
||||||
- Generating the data took about 10 minutes (for 3269 US cities, of which there were 2826 unique names), in parallel on 8-cores (Intel 9700K), yielding 3,991,725 (combinations of cities) with size 150MB.
|
|
||||||
- For cities with the same name, the one with the larger population is selected (had to make some sort of choice...).
|
|
||||||
- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split and batch size 16. At batch size 16 times larger, each epoch took about 5-6 minutes.
|
|
||||||
- Evaluation (generating plots) on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
|
|
||||||
- **WARNING**: _It is unclear how the model performs on sentences, as it was trained and evaluated only on word-pairs._ See improvement (5) above.
|
|
||||||
|
|||||||
302
eval.py
302
eval.py
@ -1,79 +1,257 @@
|
|||||||
import glob
|
import argparse
|
||||||
import logging
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
os.environ.setdefault("MPLCONFIGDIR", "/private/tmp/matplotlib")
|
||||||
import pandas as pd
|
os.environ.setdefault("XDG_CACHE_HOME", "/private/tmp")
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
from sentence_transformers import LoggingHandler, SentenceTransformer
|
|
||||||
|
|
||||||
# from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
|
import matplotlib # noqa: E402
|
||||||
# from sklearn.model_selection import train_test_split
|
import numpy as np # noqa: E402
|
||||||
|
import pandas as pd # noqa: E402
|
||||||
|
import torch # noqa: E402
|
||||||
|
from sentence_transformers import SentenceTransformer # noqa: E402
|
||||||
|
|
||||||
if not os.path.exists("./plots"):
|
from train import CoordinateRegressor, haversine_km # noqa: E402
|
||||||
os.mkdir("./plots")
|
|
||||||
|
|
||||||
# Configure logging
|
matplotlib.use("Agg")
|
||||||
logging.basicConfig(
|
from matplotlib import pyplot as plt # noqa: E402
|
||||||
format="%(asctime)s - %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
level=logging.INFO,
|
def parse_args():
|
||||||
handlers=[LoggingHandler()],
|
parser = argparse.ArgumentParser(description="Evaluate sign coordinate model.")
|
||||||
|
parser.add_argument("--data-file", default="training.csv")
|
||||||
|
parser.add_argument("--model-path", default="output")
|
||||||
|
parser.add_argument("--output-file", default="predictions.csv")
|
||||||
|
parser.add_argument("--plot-file", default="plots/prediction_map.png")
|
||||||
|
parser.add_argument(
|
||||||
|
"--scatter-plot-file",
|
||||||
|
default="plots/predicted_vs_actual.png",
|
||||||
|
help="2x1 scatter plot of predicted vs actual latitude and longitude.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--device", default=None)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=64)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def get_device(requested_device):
|
||||||
|
if requested_device:
|
||||||
|
return torch.device(requested_device)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.device("cuda")
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return torch.device("mps")
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def load_head(model_path, filename, encoder, config, device):
|
||||||
|
head = CoordinateRegressor(
|
||||||
|
embedding_dim=encoder.get_sentence_embedding_dimension(),
|
||||||
|
hidden_dim=config["hidden_dim"],
|
||||||
|
dropout=config["dropout"],
|
||||||
|
).to(device)
|
||||||
|
head.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
os.path.join(model_path, filename),
|
||||||
|
map_location=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
head.eval()
|
||||||
|
return head
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path, device):
|
||||||
|
with open(os.path.join(model_path, "coordinate_config.json")) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
trained_encoder = SentenceTransformer(model_path, device=str(device))
|
||||||
|
trained_head = load_head(
|
||||||
|
model_path, "coordinate_head.pt", trained_encoder, config, device
|
||||||
|
)
|
||||||
|
trained_encoder.eval()
|
||||||
|
|
||||||
|
initial_encoder_path = os.path.join(model_path, "initial_encoder")
|
||||||
|
if os.path.exists(initial_encoder_path):
|
||||||
|
base_encoder = SentenceTransformer(initial_encoder_path, device=str(device))
|
||||||
|
else:
|
||||||
|
base_encoder = SentenceTransformer(config["model_name"], device=str(device))
|
||||||
|
initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt")
|
||||||
|
initial_head = None
|
||||||
|
if os.path.exists(initial_head_path):
|
||||||
|
initial_head = load_head(
|
||||||
|
model_path, "initial_coordinate_head.pt", base_encoder, config, device
|
||||||
|
)
|
||||||
|
base_encoder.eval()
|
||||||
|
|
||||||
|
return (
|
||||||
|
base_encoder,
|
||||||
|
initial_head,
|
||||||
|
trained_encoder,
|
||||||
|
trained_head,
|
||||||
|
np.array(config["coord_mean"]),
|
||||||
|
np.array(config["coord_std"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model, city_from, city_to):
|
@torch.no_grad()
|
||||||
city_to = model.encode(city_to)
|
def predict(encoder, head, texts, coord_mean, coord_std, device, batch_size):
|
||||||
city_from = model.encode(city_from)
|
predictions = []
|
||||||
return np.dot(city_to, city_from) / (
|
for start in range(0, len(texts), batch_size):
|
||||||
np.linalg.norm(city_to) * np.linalg.norm(city_from)
|
batch = texts[start : start + batch_size]
|
||||||
|
features = encoder.tokenize(batch)
|
||||||
|
features = {key: value.to(device) for key, value in features.items()}
|
||||||
|
embeddings = encoder(features)["sentence_embedding"]
|
||||||
|
batch_predictions = head(embeddings).cpu().numpy()
|
||||||
|
predictions.append(batch_predictions)
|
||||||
|
|
||||||
|
normalized = np.vstack(predictions)
|
||||||
|
return normalized * coord_std + coord_mean
|
||||||
|
|
||||||
|
|
||||||
|
def make_prediction_plot(results, plot_file):
|
||||||
|
os.makedirs(os.path.dirname(plot_file), exist_ok=True)
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 8))
|
||||||
|
|
||||||
|
for _, row in results.iterrows():
|
||||||
|
ax.plot(
|
||||||
|
[row["longitude"], row["predicted_longitude"]],
|
||||||
|
[row["latitude"], row["predicted_latitude"]],
|
||||||
|
color="0.75",
|
||||||
|
linewidth=0.4,
|
||||||
|
alpha=0.35,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def calculate_similarity(data, base_model, trained_model):
|
|
||||||
# MAX_DISTANCE = 20_037.5
|
|
||||||
# data["distance"] /= MAX_DISTANCE
|
|
||||||
data["similarity_before"] = data.apply(
|
|
||||||
lambda x: evaluate(base_model, x["city_from"], x["city_to"]), axis=1
|
|
||||||
)
|
|
||||||
|
|
||||||
data["similarity_after"] = data.apply(
|
|
||||||
lambda x: evaluate(trained_model, x["city_from"], x["city_to"]), axis=1
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def make_plot(data):
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
|
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
data["distance"],
|
results["longitude"],
|
||||||
data["similarity_before"],
|
results["latitude"],
|
||||||
color="r",
|
s=18,
|
||||||
alpha=0.1,
|
color="#1f77b4",
|
||||||
label="before",
|
alpha=0.75,
|
||||||
|
label="actual",
|
||||||
)
|
)
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
data["distance"], data["similarity_after"], color="b", alpha=0.1, label="after"
|
results["predicted_longitude"],
|
||||||
|
results["predicted_latitude"],
|
||||||
|
s=12,
|
||||||
|
color="#d62728",
|
||||||
|
alpha=0.45,
|
||||||
|
label="predicted",
|
||||||
)
|
)
|
||||||
ax.set_xlabel("distance between cities (km)")
|
ax.set_xlabel("longitude")
|
||||||
ax.set_ylabel("similarity between vectors\n(cosine)")
|
ax.set_ylabel("latitude")
|
||||||
ax.legend(loc="center right")
|
ax.set_title("Sign Text Coordinate Predictions")
|
||||||
return fig
|
ax.legend()
|
||||||
|
ax.set_aspect("equal", adjustable="box")
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(plot_file, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def make_predicted_vs_actual_plot(results, plot_file):
|
||||||
|
os.makedirs(os.path.dirname(plot_file), exist_ok=True)
|
||||||
|
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(7, 9))
|
||||||
|
|
||||||
|
plot_specs = [
|
||||||
|
("latitude", "initial_predicted_latitude", "predicted_latitude", "Latitude"),
|
||||||
|
(
|
||||||
|
"longitude",
|
||||||
|
"initial_predicted_longitude",
|
||||||
|
"predicted_longitude",
|
||||||
|
"Longitude",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for ax, (actual_col, initial_col, predicted_col, title) in zip(axes, plot_specs):
|
||||||
|
actual = results[actual_col]
|
||||||
|
initial = results[initial_col] if initial_col in results else None
|
||||||
|
predicted = results[predicted_col]
|
||||||
|
values = [actual, predicted]
|
||||||
|
if initial is not None:
|
||||||
|
values.append(initial)
|
||||||
|
lower = min(series.min() for series in values)
|
||||||
|
upper = max(series.max() for series in values)
|
||||||
|
|
||||||
|
if initial is not None:
|
||||||
|
ax.scatter(
|
||||||
|
actual,
|
||||||
|
initial,
|
||||||
|
s=14,
|
||||||
|
alpha=0.45,
|
||||||
|
color="#d62728",
|
||||||
|
label="before training",
|
||||||
|
)
|
||||||
|
ax.scatter(
|
||||||
|
actual,
|
||||||
|
predicted,
|
||||||
|
s=14,
|
||||||
|
alpha=0.55,
|
||||||
|
color="#1f77b4",
|
||||||
|
label="after training",
|
||||||
|
)
|
||||||
|
ax.plot([lower, upper], [lower, upper], color="0.25", linewidth=1)
|
||||||
|
ax.set_xlabel(f"actual {actual_col}")
|
||||||
|
ax.set_ylabel(f"predicted {actual_col}")
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.set_xlim(lower, upper)
|
||||||
|
ax.set_ylim(lower, upper)
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(plot_file, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
device = get_device(args.device)
|
||||||
|
(
|
||||||
|
base_encoder,
|
||||||
|
initial_head,
|
||||||
|
trained_encoder,
|
||||||
|
trained_head,
|
||||||
|
coord_mean,
|
||||||
|
coord_std,
|
||||||
|
) = load_model(args.model_path, device)
|
||||||
|
|
||||||
|
data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"])
|
||||||
|
texts = data["text"].astype(str).tolist()
|
||||||
|
predicted = predict(
|
||||||
|
trained_encoder,
|
||||||
|
trained_head,
|
||||||
|
texts,
|
||||||
|
coord_mean,
|
||||||
|
coord_std,
|
||||||
|
device,
|
||||||
|
args.batch_size,
|
||||||
|
)
|
||||||
|
true_coords = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
|
||||||
|
errors = haversine_km(predicted, true_coords)
|
||||||
|
|
||||||
|
results = data.copy()
|
||||||
|
results["predicted_latitude"] = predicted[:, 0]
|
||||||
|
results["predicted_longitude"] = predicted[:, 1]
|
||||||
|
if initial_head is not None:
|
||||||
|
initial_predicted = predict(
|
||||||
|
base_encoder,
|
||||||
|
initial_head,
|
||||||
|
texts,
|
||||||
|
coord_mean,
|
||||||
|
coord_std,
|
||||||
|
device,
|
||||||
|
args.batch_size,
|
||||||
|
)
|
||||||
|
results["initial_predicted_latitude"] = initial_predicted[:, 0]
|
||||||
|
results["initial_predicted_longitude"] = initial_predicted[:, 1]
|
||||||
|
results["error_km"] = errors
|
||||||
|
results.to_csv(args.output_file, index=False)
|
||||||
|
make_prediction_plot(results, args.plot_file)
|
||||||
|
make_predicted_vs_actual_plot(results, args.scatter_plot_file)
|
||||||
|
|
||||||
|
print(f"Wrote predictions to {args.output_file}")
|
||||||
|
print(f"Wrote plot to {args.plot_file}")
|
||||||
|
print(f"Wrote scatter plot to {args.scatter_plot_file}")
|
||||||
|
print(f"mean_error_km={errors.mean():.3f}")
|
||||||
|
print(f"median_error_km={np.median(errors):.3f}")
|
||||||
|
print(f"p90_error_km={np.percentile(errors, 90):.3f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
main()
|
||||||
base_model = SentenceTransformer(model_name, device="cuda")
|
|
||||||
|
|
||||||
data = pd.read_csv("city_distances.csv")
|
|
||||||
# data_sample = data.sample(1_000)
|
|
||||||
checkpoint_dir = "checkpoints" # no slash
|
|
||||||
for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
|
|
||||||
print(f"Evaluating {checkpoint}")
|
|
||||||
data_sample = data.sample(1_000)
|
|
||||||
trained_model = SentenceTransformer(checkpoint, device="cuda")
|
|
||||||
|
|
||||||
data_sample = calculate_similarity(data_sample, base_model, trained_model)
|
|
||||||
fig = make_plot(data_sample)
|
|
||||||
fig.savefig(f"./plots/progress_{checkpoint.split('/')[1]}.png", dpi=600)
|
|
||||||
|
|||||||
115
prepare_training_data.py
Normal file
115
prepare_training_data.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
INPUT_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Create bootstrapped sign-text bags for coordinate training."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-i",
|
||||||
|
"--input-file",
|
||||||
|
default="training_data_raw.csv",
|
||||||
|
help="Raw pandas-exported CSV.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-o",
|
||||||
|
"--output-file",
|
||||||
|
default="training.csv",
|
||||||
|
help="Prepared training CSV.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=1992,
|
||||||
|
help="Random seed for deterministic bootstrapping.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--bag-size",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Number of sign texts to sample for each training row.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--samples-per-intersection",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Bootstrap samples to create for each intersection.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--separator",
|
||||||
|
default=" | ",
|
||||||
|
help="Separator used when joining sampled sign texts.",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_raw_data(path):
|
||||||
|
data = pd.read_csv(path)
|
||||||
|
missing = sorted(set(INPUT_COLUMNS) - set(data.columns))
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing required columns: {missing}")
|
||||||
|
|
||||||
|
data = data[INPUT_COLUMNS].copy()
|
||||||
|
data = data.dropna(subset=INPUT_COLUMNS)
|
||||||
|
data["text_on_sign_exact"] = data["text_on_sign_exact"].astype(str).str.strip()
|
||||||
|
data = data[data["text_on_sign_exact"] != ""]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separator):
|
||||||
|
if bag_size < 1:
|
||||||
|
raise ValueError("--bag-size must be at least 1")
|
||||||
|
if samples_per_intersection < 1:
|
||||||
|
raise ValueError("--samples-per-intersection must be at least 1")
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
grouped = data.groupby("intersection", sort=True)
|
||||||
|
for intersection, group in grouped:
|
||||||
|
texts = group["text_on_sign_exact"].tolist()
|
||||||
|
latitude = group["latitude"].mean()
|
||||||
|
longitude = group["longitude"].mean()
|
||||||
|
|
||||||
|
for sample_id in range(samples_per_intersection):
|
||||||
|
sampled = group.sample(
|
||||||
|
n=bag_size,
|
||||||
|
replace=True,
|
||||||
|
random_state=seed + len(rows),
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"intersection": intersection,
|
||||||
|
"sample_id": sample_id,
|
||||||
|
"text": separator.join(sampled["text_on_sign_exact"].tolist()),
|
||||||
|
"latitude": latitude,
|
||||||
|
"longitude": longitude,
|
||||||
|
"unique_sign_count": len(set(texts)),
|
||||||
|
"raw_sign_count": len(texts),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return pd.DataFrame(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
raw_data = load_raw_data(args.input_file)
|
||||||
|
training_data = make_bootstrap_data(
|
||||||
|
raw_data,
|
||||||
|
bag_size=args.bag_size,
|
||||||
|
samples_per_intersection=args.samples_per_intersection,
|
||||||
|
seed=args.seed,
|
||||||
|
separator=args.separator,
|
||||||
|
)
|
||||||
|
training_data.to_csv(args.output_file, index=False)
|
||||||
|
print(
|
||||||
|
f"Wrote {len(training_data):,} rows from "
|
||||||
|
f"{raw_data['intersection'].nunique():,} intersections to {args.output_file}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -10,3 +10,5 @@ sentence-transformers==2.2.2
|
|||||||
torch==2.0.0
|
torch==2.0.0
|
||||||
torchaudio==2.0.1
|
torchaudio==2.0.1
|
||||||
torchvision==0.15.1
|
torchvision==0.15.1
|
||||||
|
transformers==4.30.2
|
||||||
|
huggingface-hub==0.14.1
|
||||||
|
|||||||
317
train.py
317
train.py
@ -1,18 +1,17 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sentence_transformers import (
|
import torch
|
||||||
InputExample,
|
from sentence_transformers import LoggingHandler, SentenceTransformer
|
||||||
LoggingHandler,
|
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
||||||
SentenceTransformer,
|
from torch import nn
|
||||||
losses,
|
from torch.utils.data import DataLoader, Dataset
|
||||||
)
|
|
||||||
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(message)s",
|
format="%(asctime)s - %(message)s",
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
@ -20,77 +19,255 @@ logging.basicConfig(
|
|||||||
handlers=[LoggingHandler()],
|
handlers=[LoggingHandler()],
|
||||||
)
|
)
|
||||||
|
|
||||||
model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
|
||||||
model = SentenceTransformer(model_name, device="cuda")
|
|
||||||
# num_examples = 10_000
|
|
||||||
|
|
||||||
# Perform train-test split
|
class SignCoordinateDataset(Dataset):
|
||||||
# Example fake data with right types (for testing)
|
def __init__(self, texts, coordinates):
|
||||||
# import faker
|
self.texts = list(texts)
|
||||||
# fake = Faker()
|
self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
|
||||||
# train_data = [
|
|
||||||
# (fake.city(), fake.city(), np.random.rand())
|
|
||||||
# for _ in range(num_examples)
|
|
||||||
# ]
|
|
||||||
data = pd.read_csv("city_distances.csv")
|
|
||||||
MAX_DISTANCE = 20_037.5 # global max distance
|
|
||||||
# MAX_DISTANCE = data["distance"].max() # about 5k
|
|
||||||
|
|
||||||
print(f"{MAX_DISTANCE=}")
|
def __len__(self):
|
||||||
train_data = [
|
return len(self.texts)
|
||||||
(row["city_from"], row["city_to"], 1 - row["distance"] / MAX_DISTANCE)
|
|
||||||
for _, row in data.iterrows()
|
|
||||||
]
|
|
||||||
|
|
||||||
np.random.seed(1992)
|
def __getitem__(self, index):
|
||||||
np.random.shuffle(train_data)
|
return self.texts[index], self.coordinates[index]
|
||||||
train_examples = examples = [
|
|
||||||
InputExample(texts=[city_from, city_to], label=dist)
|
|
||||||
for city_from, city_to, dist in train_data
|
|
||||||
]
|
|
||||||
|
|
||||||
train_examples, val_examples = train_test_split(
|
|
||||||
examples, test_size=0.2, random_state=21
|
class CoordinateRegressor(nn.Module):
|
||||||
|
def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
nn.Linear(embedding_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim // 2, 2),
|
||||||
)
|
)
|
||||||
# validation examples can be something like templated sentences
|
|
||||||
# that maintain the same distance as the cities (same context)
|
|
||||||
# should probably add training examples like that too if needed
|
|
||||||
BATCH_SIZE = 16 * 16
|
|
||||||
num_examples = len(train_examples)
|
|
||||||
steps_per_epoch = num_examples // BATCH_SIZE
|
|
||||||
|
|
||||||
print(f"\nHead of training data (size: {num_examples}):")
|
def forward(self, embeddings):
|
||||||
print(train_data[:10], "\n")
|
return self.layers(embeddings)
|
||||||
|
|
||||||
# Create DataLoaders for train and validation datasets
|
|
||||||
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
|
|
||||||
|
|
||||||
print("TRAINING")
|
def parse_args():
|
||||||
# Configure the training arguments
|
parser = argparse.ArgumentParser(description="Train sign text to lat/lon model.")
|
||||||
training_args = {
|
parser.add_argument("--data-file", default="training.csv")
|
||||||
"output_path": "./output",
|
parser.add_argument("--output-path", default="output")
|
||||||
# "evaluation_steps": steps_per_epoch, # already evaluates at the end of each epoch
|
parser.add_argument(
|
||||||
"epochs": 10,
|
"--model-name", default="sentence-transformers/all-MiniLM-L6-v2"
|
||||||
"warmup_steps": 500,
|
)
|
||||||
"optimizer_params": {"lr": 2e-5},
|
parser.add_argument("--device", default=None)
|
||||||
# "weight_decay": 0, # not sure if this helps but works fine without setting it.
|
parser.add_argument("--seed", type=int, default=1992)
|
||||||
"scheduler": "WarmupLinear",
|
parser.add_argument("--epochs", type=int, default=10)
|
||||||
"save_best_model": True,
|
parser.add_argument("--batch-size", type=int, default=32)
|
||||||
"checkpoint_path": "./checkpoints",
|
parser.add_argument("--learning-rate", type=float, default=2e-5)
|
||||||
"checkpoint_save_steps": steps_per_epoch,
|
parser.add_argument("--head-learning-rate", type=float, default=1e-3)
|
||||||
"checkpoint_save_total_limit": 100,
|
parser.add_argument("--weight-decay", type=float, default=0.01)
|
||||||
|
parser.add_argument("--test-size", type=float, default=0.2)
|
||||||
|
parser.add_argument("--hidden-dim", type=int, default=256)
|
||||||
|
parser.add_argument("--dropout", type=float, default=0.1)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def get_device(requested_device):
|
||||||
|
if requested_device:
|
||||||
|
return torch.device(requested_device)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.device("cuda")
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return torch.device("mps")
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_coordinates(coordinates):
|
||||||
|
mean = coordinates.mean(axis=0)
|
||||||
|
std = coordinates.std(axis=0)
|
||||||
|
std[std == 0] = 1.0
|
||||||
|
return (coordinates - mean) / std, mean, std
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(model, device):
|
||||||
|
def collate(batch):
|
||||||
|
texts, labels = zip(*batch)
|
||||||
|
features = model.tokenize(list(texts))
|
||||||
|
features = {key: value.to(device) for key, value in features.items()}
|
||||||
|
return features, torch.stack(labels).to(device)
|
||||||
|
|
||||||
|
return collate
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(encoder, head, dataloader, optimizer, loss_fn):
|
||||||
|
encoder.train()
|
||||||
|
head.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
|
||||||
|
for features, labels in dataloader:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
embeddings = encoder(features)["sentence_embedding"]
|
||||||
|
predictions = head(embeddings)
|
||||||
|
loss = loss_fn(predictions, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
|
return total_loss / len(dataloader.dataset)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std):
|
||||||
|
encoder.eval()
|
||||||
|
head.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
errors_km = []
|
||||||
|
|
||||||
|
for features, labels in dataloader:
|
||||||
|
embeddings = encoder(features)["sentence_embedding"]
|
||||||
|
predictions = head(embeddings)
|
||||||
|
loss = loss_fn(predictions, labels)
|
||||||
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
|
pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
|
||||||
|
true_coords = labels.cpu().numpy() * coord_std + coord_mean
|
||||||
|
errors_km.extend(haversine_km(pred_coords, true_coords))
|
||||||
|
|
||||||
|
return total_loss / len(dataloader.dataset), float(np.mean(errors_km))
|
||||||
|
|
||||||
|
|
||||||
|
def haversine_km(pred_coords, true_coords):
|
||||||
|
lat1 = np.radians(pred_coords[:, 0])
|
||||||
|
lon1 = np.radians(pred_coords[:, 1])
|
||||||
|
lat2 = np.radians(true_coords[:, 0])
|
||||||
|
lon2 = np.radians(true_coords[:, 1])
|
||||||
|
dlat = lat2 - lat1
|
||||||
|
dlon = lon2 - lon1
|
||||||
|
a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
|
||||||
|
return 2 * 6371.0088 * np.arcsin(np.sqrt(a))
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(output_path, encoder, head, coord_mean, coord_std, args):
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
encoder.save(output_path)
|
||||||
|
torch.save(head.state_dict(), os.path.join(output_path, "coordinate_head.pt"))
|
||||||
|
metadata = {
|
||||||
|
"coord_mean": coord_mean.tolist(),
|
||||||
|
"coord_std": coord_std.tolist(),
|
||||||
|
"hidden_dim": args.hidden_dim,
|
||||||
|
"dropout": args.dropout,
|
||||||
|
"model_name": args.model_name,
|
||||||
}
|
}
|
||||||
print(f"TRAINING ARGUMENTS:\n {training_args}")
|
with open(os.path.join(output_path, "coordinate_config.json"), "w") as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
train_loss = losses.CosineSimilarityLoss(model)
|
|
||||||
|
|
||||||
# Create an evaluator for validation dataset
|
def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
|
||||||
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
|
os.makedirs(output_path, exist_ok=True)
|
||||||
val_examples, write_csv=True
|
encoder.save(os.path.join(output_path, "initial_encoder"))
|
||||||
|
torch.save(
|
||||||
|
head.state_dict(),
|
||||||
|
os.path.join(output_path, "initial_coordinate_head.pt"),
|
||||||
|
)
|
||||||
|
metadata = {
|
||||||
|
"coord_mean": coord_mean.tolist(),
|
||||||
|
"coord_std": coord_std.tolist(),
|
||||||
|
"hidden_dim": args.hidden_dim,
|
||||||
|
"dropout": args.dropout,
|
||||||
|
"model_name": args.model_name,
|
||||||
|
}
|
||||||
|
with open(os.path.join(output_path, "coordinate_config.json"), "w") as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
set_seed(args.seed)
|
||||||
|
device = get_device(args.device)
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
data = pd.read_csv(args.data_file)
|
||||||
|
data = data.dropna(subset=["text", "latitude", "longitude"])
|
||||||
|
texts = data["text"].astype(str).tolist()
|
||||||
|
coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
|
||||||
|
normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)
|
||||||
|
|
||||||
|
indices = np.arange(len(data))
|
||||||
|
if "intersection" in data.columns:
|
||||||
|
splitter = GroupShuffleSplit(
|
||||||
|
n_splits=1, test_size=args.test_size, random_state=args.seed
|
||||||
|
)
|
||||||
|
train_indices, val_indices = next(
|
||||||
|
splitter.split(indices, groups=data["intersection"])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_indices, val_indices = train_test_split(
|
||||||
|
indices, test_size=args.test_size, random_state=args.seed
|
||||||
)
|
)
|
||||||
|
|
||||||
model.fit(
|
train_dataset = SignCoordinateDataset(
|
||||||
train_objectives=[(train_dataloader, train_loss)],
|
[texts[i] for i in train_indices], normalized_coordinates[train_indices]
|
||||||
evaluator=evaluator,
|
|
||||||
**training_args,
|
|
||||||
)
|
)
|
||||||
|
val_dataset = SignCoordinateDataset(
|
||||||
|
[texts[i] for i in val_indices], normalized_coordinates[val_indices]
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder = SentenceTransformer(args.model_name, device=str(device))
|
||||||
|
embedding_dim = encoder.get_sentence_embedding_dimension()
|
||||||
|
head = CoordinateRegressor(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
hidden_dim=args.hidden_dim,
|
||||||
|
dropout=args.dropout,
|
||||||
|
).to(device)
|
||||||
|
save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=collate_fn(encoder, device),
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=collate_fn(encoder, device),
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
[
|
||||||
|
{"params": encoder.parameters(), "lr": args.learning_rate},
|
||||||
|
{"params": head.parameters(), "lr": args.head_learning_rate},
|
||||||
|
],
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
)
|
||||||
|
loss_fn = nn.MSELoss()
|
||||||
|
best_val_loss = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Training on {len(train_dataset):,} rows; "
|
||||||
|
f"validating on {len(val_dataset):,} rows"
|
||||||
|
)
|
||||||
|
for epoch in range(1, args.epochs + 1):
|
||||||
|
train_loss = train_epoch(encoder, head, train_loader, optimizer, loss_fn)
|
||||||
|
val_loss, val_error_km = evaluate(
|
||||||
|
encoder, head, val_loader, loss_fn, coord_mean, coord_std
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"epoch={epoch} train_loss={train_loss:.6f} "
|
||||||
|
f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
save_model(args.output_path, encoder, head, coord_mean, coord_std, args)
|
||||||
|
print(f"Saved best model to {args.output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user