rename data file

This commit is contained in:
mm 2023-05-05 07:11:38 +00:00
parent 294d4bb1cd
commit 5c33b5135f
5 changed files with 6 additions and 6 deletions

View File

@@ -1,10 +1,10 @@
all: install data train eval all: install data train eval
city_distances_full.csv: check generate_data.py city_distances.csv: check generate_data.py
@echo "Generating distance data..." @echo "Generating distance data..."
@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200' @bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200'
data: city_distances_full.csv data: city_distances.csv
train: check train.py train: check train.py
@echo "Training embeddings..." @echo "Training embeddings..."

View File

@@ -47,7 +47,7 @@ The approach demonstrated can be extended to other metrics or features beyond ge
## How to Use ## How to Use
1. Install the required dependencies by running `pip install -r requirements.txt`. 1. Install the required dependencies by running `pip install -r requirements.txt`.
2. Run `make city_distances.csv` to generate the dataset of city distances. 2. Run `make data` to generate the dataset of city distances.
3. Run `make train` to train the neural network model. 3. Run `make train` to train the neural network model.
4. Run `make eval` to evaluate the trained model and generate evaluation plots. 4. Run `make eval` to evaluate the trained model and generate evaluation plots.

View File

@@ -66,7 +66,7 @@ if __name__ == "__main__":
model_name = "sentence-transformers/all-MiniLM-L6-v2" model_name = "sentence-transformers/all-MiniLM-L6-v2"
base_model = SentenceTransformer(model_name, device="cuda") base_model = SentenceTransformer(model_name, device="cuda")
data = pd.read_csv("city_distances_full.csv") data = pd.read_csv("city_distances.csv")
# data_sample = data.sample(1_000) # data_sample = data.sample(1_000)
checkpoint_dir = "checkpoints_absmax_split" # no slash checkpoint_dir = "checkpoints_absmax_split" # no slash
for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")): for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")):

View File

@@ -32,7 +32,7 @@ parser.add_argument(
"--output-file", "--output-file",
help="Specify the name of the output file (file.csv)", help="Specify the name of the output file (file.csv)",
type=str, type=str,
default="city_distances_full.csv", default="city_distances.csv",
) )
parser.add_argument( parser.add_argument(
"--shuffle", "--shuffle",

View File

@@ -32,7 +32,7 @@ model = SentenceTransformer(model_name, device="cuda")
# (fake.city(), fake.city(), np.random.rand()) # (fake.city(), fake.city(), np.random.rand())
# for _ in range(num_examples) # for _ in range(num_examples)
# ] # ]
data = pd.read_csv("city_distances_full.csv") data = pd.read_csv("city_distances.csv")
MAX_DISTANCE = 20_037.5 # global max distance MAX_DISTANCE = 20_037.5 # global max distance
# MAX_DISTANCE = data["distance"].max() # about 5k # MAX_DISTANCE = data["distance"].max() # about 5k