prettier readme + full eval + images
parent 313c850631
commit 336531421d
Makefile (19 lines changed)
@@ -1,7 +1,11 @@
city_distances.csv: check generate_data.py
all: install data train eval

city_distances_full.csv: check generate_data.py
	@echo "Generating distance data..."
	@bash -c 'time python generate_data.py'

data: city_distances_full.csv

train: check train.py
	@echo "Training embeddings..."
	@bash -c 'time python train.py'
@@ -24,3 +28,16 @@ clean:
	@echo "Removing output/ and checkpoints/ directories"
	@rm -rf output/
	@rm -rf checkpoints/

compress: plots/progress_35845_sm.png plots/progress_680065_sm.png

plots/progress_35845_sm.png: plots/progress_35845.png
	@convert -resize 33% plots/progress_35845.png plots/progress_35845_sm.png

plots/progress_680065_sm.png: plots/progress_680065.png
	@convert -resize 33% plots/progress_680065.png plots/progress_680065_sm.png

install:
	pip install -r requirements.txt

.PHONY: data train eval lint check clean all
README.md (74 lines changed)
@@ -1,21 +1,67 @@
# citybert
# CityBert

1. Generates dataset of cities (US only for now) and their pair-wise geodesic distances.
2. Uses that dataset to fine-tune a neural-net to understand that cities closer to one another are more similar.
3. Distances become `labels` through the formula `1 - distance/MAX_DISTANCE`, where `MAX_DISTANCE=20_037.5 # km` represents half of the Earth's circumference.
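
For concreteness, here is that labeling formula as runnable code (a minimal sketch; the New York to Los Angeles distance below is approximate):

```python
MAX_DISTANCE = 20_037.5  # km, half of Earth's circumference

def distance_to_label(distance_km: float) -> float:
    """Map a geodesic distance to a similarity label in [0, 1]."""
    return 1 - distance_km / MAX_DISTANCE

print(distance_to_label(3_940))     # New York -> Los Angeles, roughly 0.80
print(distance_to_label(20_037.5))  # antipodal cities -> 0.0
```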
CityBert is a machine learning project that fine-tunes a neural network model to understand the similarity between cities based on their geodesic distances.

There are other factors that can make cities that are "close together" on the globe "far apart" in reality, due to political borders.
Factors like this are not considered in this model; it considers geography only.
The project generates a dataset of US cities and their pair-wise geodesic distances, which are then used to train the model.

However, for use-cases that involve different measures of distance (perhaps just time-zones, or something that considers the reality of travel), the general principles proven here should be applicable (pick a metric, generate data, train).
The project can be extended to include other distance metrics or additional data, such as airport codes, city aliases, or time zones.

Particularly useful additions to the dataset here:
- airports: they (more or less) have unique codes, and this semantic understanding would be helpful for search engines.
- aliases for cities: the dataset used for city data (lat/lon) contains a pretty exhaustive list of aliases for the cities. It would be good to generate examples of these with a distance of 0 and train the model on this knowledge.
- time-zones: encode the difference in hours (relative to the worst possible case) as labels associated with the time-zone formatted strings.
> Note that this model only considers geographic distances and does not take into account other factors such as political borders or transportation infrastructure.
These factors contribute to a sense of "distance as it pertains to travel difficulty," which is not directly reflected by this model.

# notes
- see `Makefile` for instructions.

## Overview of Project Files

- `generate_data.py`: Generates a dataset of US cities and their pairwise geodesic distances (a sketch of this step appears below).
- `train.py`: Trains the neural network model using the generated dataset.
- `eval.py`: Evaluates the trained model by comparing the similarity between city vectors before and after training.
- `Makefile`: Automates the execution of various tasks, such as generating data, training, and evaluation.
- `README.md`: Provides a description of the project, instructions on how to use it, and expected results.
- `requirements.txt`: Lists the dependencies required to reproduce the results.
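
As a rough illustration of the data-generation step, here is a minimal sketch. It assumes the `geonamescache` package (mentioned under Future Improvements) and `geopy` for geodesic distances; `geopy`, the column names, and the CSV layout are assumptions, and the real script is parallelized across cores, so it certainly differs:

```python
# Sketch of pairwise geodesic-distance generation. geonamescache is
# referenced in this README; geopy and the CSV columns are assumptions.
from itertools import combinations

import geonamescache
from geopy.distance import geodesic

MAX_DISTANCE = 20_037.5  # km

gc = geonamescache.GeonamesCache()
us_cities = [
    (c["name"], (float(c["latitude"]), float(c["longitude"])))
    for c in gc.get_cities().values()
    if c["countrycode"] == "US"
]

with open("city_distances_full.csv", "w") as f:
    f.write("city_a,city_b,distance,label\n")
    for (name_a, pos_a), (name_b, pos_b) in combinations(us_cities, 2):
        km = geodesic(pos_a, pos_b).km
        # Note: names containing commas would need csv.writer quoting.
        f.write(f"{name_a},{name_b},{km:.1f},{1 - km / MAX_DISTANCE:.4f}\n")
```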


## How to Use

1. Install the required dependencies by running `pip install -r requirements.txt`.
2. Run `make data` to generate the dataset of city distances.
3. Run `make train` to train the neural network model.
4. Run `make eval` to evaluate the trained model and generate evaluation plots.

**You can also just run `make` (i.e., `make all`), which will run through all of those steps.**
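
To give a sense of what the training step looks like, here is a minimal fine-tuning sketch using `sentence-transformers` (inferred from `eval.py`'s imports; the base model name, CSV column names, and checkpointing cadence are assumptions, not the actual `train.py`):

```python
# Minimal fine-tuning sketch. Column names (city_a, city_b, label) and
# the base model are assumptions; batch size 16 matches the Notes below.
import pandas as pd
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses

data = pd.read_csv("city_distances_full.csv")
examples = [
    InputExample(texts=[row.city_a, row.city_b], label=float(row.label))
    for row in data.itertuples()
]

model = SentenceTransformer("bert-base-uncased", device="cuda")
loader = DataLoader(examples, shuffle=True, batch_size=16)
loss = losses.CosineSimilarityLoss(model)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=10,
    checkpoint_path="checkpoints",
    checkpoint_save_steps=10_000,  # assumed cadence
)
```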


## What to Expect

After training, the model should be able to understand the similarity between cities based on their geodesic distances.
You can inspect the evaluation plots generated by the `eval.py` script to see the improvement in similarity scores before and after training.
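
The comparison `eval.py` makes is roughly of this shape (a sketch; the checkpoint path and city pair are hypothetical):

```python
# Sketch of the before/after similarity comparison.
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

base = SentenceTransformer("bert-base-uncased", device="cuda")
tuned = SentenceTransformer("checkpoints/latest", device="cuda")  # hypothetical path

pair = ["New York", "Newark"]
for name, model in [("before", base), ("after", tuned)]:
    va, vb = model.encode(pair)
    print(name, float(cos_sim(va, vb)))
```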

After five epochs, the model no longer treats the terms as unrelated:
![Evaluation plot](./plots/progress_35845_sm.png)

After ten epochs, we can see the model has learned to correlate our desired quantities:
![Evaluation plot](./plots/progress_680065_sm.png)


*The above plots are examples showing the relationship between geodesic distance and the similarity between the embedded vectors (1 = more similar), for 10,000 randomly selected pairs of US cities (re-sampled for each image).*

*Note the (vertical) "gap" we see in the image, corresponding to the size of the continental United States (~5,000 km).*

---

## Future Improvements

There are several potential improvements and extensions to the current model:

1. **Incorporate airport codes**: Train the model to understand the unique codes of airports, which could be useful for search engines and other applications.
2. **Add city aliases**: Enhance the dataset with city aliases, so the model can recognize different names for the same city. The `geonamescache` package already includes these.
3. **Include time zones**: Train the model to understand time zone differences between cities, which could be helpful for various time-sensitive use cases. The `geonamescache` package already includes this data, but how to calculate the hours between them is an open question (one possible encoding is sketched after this list).
4. **Expand to other distance metrics**: Adapt the model to consider other measures of distance, such as transportation infrastructure or travel time.
5. **Train on sentences**: Improve the model's performance on sentences by adding training and validation examples that involve city names in the context of sentences. Generative AI could be used to create template sentences (mad-libs style), producing random and diverse training examples.
6. **Global city support**: Extend the model to support cities outside the US and cover a broader range of geographic locations.
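
For improvement (3), one possible hour-difference encoding mirrors the `1 - distance/MAX_DISTANCE` idea (a sketch; the 12-hour worst case and the use of IANA zone names are assumptions):

```python
# Sketch of a time-zone label: 1.0 = same UTC offset, 0.0 = offsets at
# least 12 hours apart. The 12-hour worst case is an assumption.
from datetime import datetime
from zoneinfo import ZoneInfo

WORST_CASE_HOURS = 12.0

def tz_label(tz_a: str, tz_b: str) -> float:
    now = datetime.now(tz=ZoneInfo("UTC"))
    hours_a = now.astimezone(ZoneInfo(tz_a)).utcoffset().total_seconds() / 3600
    hours_b = now.astimezone(ZoneInfo(tz_b)).utcoffset().total_seconds() / 3600
    gap = abs(hours_a - hours_b)
    return max(0.0, 1 - gap / WORST_CASE_HOURS)  # clamp extreme gaps to 0

print(tz_label("America/New_York", "America/Los_Angeles"))  # 1 - 3/12 = 0.75
```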


# Notes
- Generating the data took about 13 minutes (for 3269 US cities) on 8 cores (Intel 9700K), yielding 2,720,278 records (combinations of cities).
- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 train/test split. Batch size is 16, so there were 136,014 steps per epoch.
- **TODO**: Need to add training / validation examples that involve city names in the context of sentences. _It is unclear how the model performs on sentences, as it was trained only on word-pairs._
- Evaluation on the above hardware took about 15 minutes for 20 epochs at 10k samples each.
- **WARNING**: _It is unclear how the model performs on sentences, as it was trained and evaluated only on word-pairs._ See improvement (5) above.
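
The steps-per-epoch figure follows directly from those numbers (a quick check):

```python
records = 2_720_278   # city-pair records
train_fraction = 0.8  # 80/20 train/test split
batch_size = 16
print(round(records * train_fraction / batch_size))  # 136014
```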
eval.py (4 lines changed)
@@ -58,7 +58,7 @@ def make_plot(data):
    )
    ax.set_xlabel("distance between cities (km)")
    ax.set_ylabel("similarity between vectors\n(cosine)")
    fig.legend(loc="upper right")
    ax.legend(loc="center right")
    return fig

@@ -69,7 +69,7 @@ if __name__ == "__main__":
    data = pd.read_csv("city_distances_full.csv")
    # data_sample = data.sample(1_000)
    checkpoint_dir = "checkpoints_absmax_split"  # no slash
    for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*"))[14::]:
    for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
        print(f"Evaluating {checkpoint}")
        data_sample = data.sample(1_000)
        trained_model = SentenceTransformer(checkpoint, device="cuda")
BIN plots/progress_35845_sm.png (new file, 255 KiB; binary file not shown)
BIN plots/progress_680065_sm.png (new file, 224 KiB; binary file not shown)