Compare commits


1 commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| mm | 35f05d40df | details in readme | 2023-05-04 10:41:52 +00:00 |
3 changed files with 7 additions and 12 deletions

Makefile

```diff
@@ -1,20 +1,16 @@
-city_distances.csv: check generate_data.py
+city_distances.csv: lint generate_data.py
 	bash -c 'time python generate_data.py'
 
-train: check train.py
-	bash -c 'time python train.py'
-
-eval: check eval.py
-	bash -c 'time python eval.py'
-
 lint:
 	isort --profile=black .
 	black .
 	flake8 --max-line-length=88 .
 
-check: lint
-	@echo "Checking for unstaged or untracked changes..."
-	@git diff-index --quiet HEAD -- || { echo "Unstaged or untracked changes detected!"; exit 1; }
+train: lint train.py
+	bash -c 'time python train.py'
+
+eval: lint eval.py
+	bash -c 'time python eval.py'
 
 clean:
 	rm -rf output/
```

README

```diff
@@ -12,7 +12,6 @@ However, for use-cases that involve different measures of distances (perhaps jus
 A particularly useful addition to the dataset here:
 - airports: they (more/less) have unique codes, and this semantic understanding would be helpful for search engines.
 - aliases for cities: the dataset used for city data (lat/lon) contains a pretty exhaustive list of aliases for the cities. It would be good to generate examples of these with a distance of 0 and train the model on this knowledge.
-- time-zones: encode difference in hours (relative to worst-possible-case) as labels associated with the time-zone formatted-strings.
 
 # notes
 - see `Makefile` for instructions.
```
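The alias idea kept in the README — pair every name that refers to the same city and label the pair with distance 0 — can be sketched in a few lines. The function name and row shape here are hypothetical, not taken from the repo's actual data pipeline:

```python
from itertools import combinations

def alias_rows(canonical: str, aliases: list[str]) -> list[tuple[str, str, float]]:
    """Generate (name_a, name_b, distance) training rows for every pair of
    names referring to the same city; the distance label is 0.0 by definition."""
    names = [canonical, *aliases]
    return [(a, b, 0.0) for a, b in combinations(names, 2)]

rows = alias_rows("New York City", ["NYC", "New York"])
# Three names yield all 3 unordered pairs, each labelled with distance 0.
```

These zero-distance rows would be concatenated with the geographic pairs in `city_distances.csv` before training.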

eval.py

```diff
@@ -66,7 +66,7 @@ if __name__ == "__main__":
     model_name = "sentence-transformers/all-MiniLM-L6-v2"
     base_model = SentenceTransformer(model_name, device="cuda")
-    data = pd.read_csv("city_distances_full.csv")
+    data = pd.read_csv("city_distances_sample.csv")
     # data_sample = data.sample(1_000)
     checkpoint_dir = "checkpoints_absmax_split"  # no slash
     for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):
```
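One subtlety in the checkpoint loop above: `sorted(glob.glob(...))` orders paths lexicographically, so numbered checkpoints can be visited out of step order unless the numbers are zero-padded. A hedged sketch of a numeric sort key (the `ckpt-N` names are invented for illustration; the repo's actual checkpoint naming is not shown here):

```python
import re

def numeric_sort(paths: list[str]) -> list[str]:
    """Sort checkpoint paths by their trailing step number.

    Plain sorted() compares strings character by character, so "ckpt-10"
    lands before "ckpt-2"; extracting the number restores step order.
    """
    def step(path: str) -> int:
        match = re.search(r"(\d+)$", path)
        return int(match.group(1)) if match else -1
    return sorted(paths, key=step)

paths = ["ckpt-10", "ckpt-2", "ckpt-1"]
# lexicographic: ['ckpt-1', 'ckpt-10', 'ckpt-2']
# numeric:       ['ckpt-1', 'ckpt-2', 'ckpt-10']
```

If the eval metrics are plotted against training progress, the lexicographic order would silently scramble the x-axis.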