Compare commits
3 Commits
35f05d40df...f193018ac2
| Author | SHA1 | Date |
|---|---|---|
|  | f193018ac2 |  |
|  | 282c0466d8 |  |
|  | e9adbed41a |  |
Makefile (16 changed lines)
```diff
@@ -1,16 +1,20 @@
-city_distances.csv: lint generate_data.py
+city_distances.csv: check generate_data.py
 	bash -c 'time python generate_data.py'
 
+train: check train.py
+	bash -c 'time python train.py'
+
+eval: check eval.py
+	bash -c 'time python eval.py'
+
 lint:
 	isort --profile=black .
 	black .
 	flake8 --max-line-length=88 .
 
-train: lint train.py
-	bash -c 'time python train.py'
-
-eval: lint eval.py
-	bash -c 'time python eval.py'
+check: lint
+	@echo "Checking for unstaged or untracked changes..."
+	@git diff-index --quiet HEAD -- || { echo "Unstaged or untracked changes detected!"; exit 1; }
 
 clean:
 	rm -rf output/
```
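The new `check` target runs `lint` as a prerequisite and then aborts if the working tree differs from `HEAD`, so `city_distances.csv`, `train`, and `eval` only execute against committed, formatted code. A minimal Python sketch of the same gate (not part of this repo; tool names and flags copied from the Makefile above):

```python
# Hedged sketch, not part of this repo: the "lint, then refuse to run on a
# dirty tree" gate that the `check` target encodes, as a standalone script.
import subprocess
import sys


def check() -> None:
    # Same commands as the `lint` target.
    for cmd in (
        ["isort", "--profile=black", "."],
        ["black", "."],
        ["flake8", "--max-line-length=88", "."],
    ):
        subprocess.run(cmd, check=True)

    print("Checking for unstaged or untracked changes...")
    # `git diff-index --quiet HEAD --` exits non-zero when tracked files
    # differ from HEAD, which is the signal the Makefile relies on.
    if subprocess.run(["git", "diff-index", "--quiet", "HEAD", "--"]).returncode != 0:
        sys.exit("Unstaged or untracked changes detected!")


if __name__ == "__main__":
    check()
```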
README.md (15 changed lines)
```diff
@@ -1,8 +1,8 @@
-# city-transformers
+# citybert
 
-Generates a dataset of cities (US only for now) and their geodesic distances.
-Uses that dataset to fine-tune a neural net to understand that cities closer to one another are more similar.
-Distances become `labels` through the formula `1 - distance/MAX_DISTANCE`, where `MAX_DISTANCE=20_037.5 # km` represents half of the Earth's circumference.
+1. Generates a dataset of cities (US only for now) and their pair-wise geodesic distances.
+2. Uses that dataset to fine-tune a neural net to understand that cities closer to one another are more similar.
+3. Distances become `labels` through the formula `1 - distance/MAX_DISTANCE`, where `MAX_DISTANCE=20_037.5 # km` represents half of the Earth's circumference.
 
 There are other factors that can make cities that are "close together" on the globe "far apart" in reality, due to political borders.
 Factors like this are not considered in this model; it only considers geography.
```
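The formula in item 3 above maps identical coordinates to a label of 1 and antipodal points to roughly 0. A small illustrative sketch, assuming `geopy` for the geodesic distance (the repo's actual generator is not shown in this diff):

```python
# Illustrative only; geopy and the sample coordinates are assumptions,
# not taken from generate_data.py.
from geopy.distance import geodesic

MAX_DISTANCE = 20_037.5  # km, half of the Earth's circumference


def label(coord_a: tuple[float, float], coord_b: tuple[float, float]) -> float:
    """Similarity label: 1.0 for the same point, ~0.0 for antipodal points."""
    return 1 - geodesic(coord_a, coord_b).km / MAX_DISTANCE


# Example: New York City vs. Los Angeles (approximate lat/lon)
print(label((40.71, -74.01), (34.05, -118.24)))  # ≈ 0.80
```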
```diff
@@ -12,5 +12,10 @@ However, for use-cases that involve different measures of distances (perhaps jus
 A particularly useful addition to the dataset here:
 - airports: they (more or less) have unique codes, and this semantic understanding would be helpful for search engines.
 - aliases for cities: the dataset used for city data (lat/lon) contains a pretty exhaustive list of aliases for the cities. It would be good to generate examples of these with a distance of 0 and train the model on this knowledge.
 - time-zones: encode the difference in hours (relative to the worst possible case) as labels associated with the time-zone formatted strings.
-see `Makefile` for instructions.
+
+# notes
+- see `Makefile` for instructions.
+- Generating the data took about 13 minutes (for 3,269 US cities) on 8 cores (Intel 9700K), yielding 2,720,278 records (combinations of cities).
+- Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 train/test split. The batch size is 16, so there were 136,014 steps per epoch.
+- **TODO**: Need to add training/validation examples that involve city names in the context of sentences. _It is unclear how the model performs on sentences, as it was trained only on word-pairs._
```
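The 136,014 figure in the notes follows from the record count, the 80/20 split, and the batch size:

```python
# Quick check of the steps-per-epoch figure quoted in the notes above.
records = 2_720_278   # generated city-pair records
train_fraction = 0.8  # 80/20 train/test split
batch_size = 16

print(round(records * train_fraction / batch_size))  # 136014
```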
eval.py (2 changed lines)
```diff
@@ -66,7 +66,7 @@ if __name__ == "__main__":
     model_name = "sentence-transformers/all-MiniLM-L6-v2"
     base_model = SentenceTransformer(model_name, device="cuda")
 
-    data = pd.read_csv("city_distances_sample.csv")
+    data = pd.read_csv("city_distances_full.csv")
     # data_sample = data.sample(1_000)
     checkpoint_dir = "checkpoints_absmax_split" # no slash
     for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
```
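For context on what switching to `city_distances_full.csv` feeds into, here is a hedged sketch of how each checkpoint in that loop might be scored; the column names (`city_a`, `city_b`, `label`) and the correlation metric are assumptions for illustration, since eval.py's body is not part of this diff:

```python
# Hedged sketch only: column names and the metric are assumptions, not eval.py's code.
import glob

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util

data = pd.read_csv("city_distances_full.csv").sample(1_000, random_state=0)
for checkpoint in sorted(glob.glob("checkpoints_absmax_split/*")):
    model = SentenceTransformer(checkpoint, device="cuda")
    emb_a = model.encode(data["city_a"].tolist(), convert_to_tensor=True)
    emb_b = model.encode(data["city_b"].tolist(), convert_to_tensor=True)
    # Cosine similarity of each pair vs. the stored distance-derived label.
    cosine = util.cos_sim(emb_a, emb_b).diagonal().cpu().numpy()
    corr = np.corrcoef(cosine, data["label"].to_numpy())[0, 1]
    print(f"{checkpoint}: corr(cosine, label) = {corr:.3f}")
```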