rename data file
This commit is contained in:
		
							parent
							
								
									294d4bb1cd
								
							
						
					
					
						commit
						5c33b5135f
					
				
							
								
								
									
										4
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								Makefile
									
									
									
									
									
								
							| @ -1,10 +1,10 @@ | |||||||
| all: install data train eval | all: install data train eval | ||||||
| 
 | 
 | ||||||
| city_distances_full.csv: check generate_data.py | city_distances.csv: check generate_data.py | ||||||
| 	@echo "Generating distance data..." | 	@echo "Generating distance data..." | ||||||
| 	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200' | 	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200' | ||||||
| 
 | 
 | ||||||
| data: city_distances_full.csv | data: city_distances.csv | ||||||
| 
 | 
 | ||||||
| train: check train.py | train: check train.py | ||||||
| 	@echo "Training embeddings..." | 	@echo "Training embeddings..." | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ The approach demonstrated can be extended to other metrics or features beyond ge | |||||||
| ## How to Use | ## How to Use | ||||||
| 
 | 
 | ||||||
| 1. Install the required dependencies by running `pip install -r requirements.txt`. | 1. Install the required dependencies by running `pip install -r requirements.txt`. | ||||||
| 2. Run `make city_distances.csv` to generate the dataset of city distances. | 2. Run `make data` to generate the dataset of city distances. | ||||||
| 3. Run `make train` to train the neural network model. | 3. Run `make train` to train the neural network model. | ||||||
| 4. Run `make eval` to evaluate the trained model and generate evaluation plots. | 4. Run `make eval` to evaluate the trained model and generate evaluation plots. | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								eval.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								eval.py
									
									
									
									
									
								
							| @ -66,7 +66,7 @@ if __name__ == "__main__": | |||||||
|     model_name = "sentence-transformers/all-MiniLM-L6-v2" |     model_name = "sentence-transformers/all-MiniLM-L6-v2" | ||||||
|     base_model = SentenceTransformer(model_name, device="cuda") |     base_model = SentenceTransformer(model_name, device="cuda") | ||||||
| 
 | 
 | ||||||
|     data = pd.read_csv("city_distances_full.csv") |     data = pd.read_csv("city_distances.csv") | ||||||
|     # data_sample = data.sample(1_000) |     # data_sample = data.sample(1_000) | ||||||
|     checkpoint_dir = "checkpoints_absmax_split"  # no slash |     checkpoint_dir = "checkpoints_absmax_split"  # no slash | ||||||
|     for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")): |     for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")): | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ parser.add_argument( | |||||||
|     "--output-file", |     "--output-file", | ||||||
|     help="Specify the name of the output file (file.csv)", |     help="Specify the name of the output file (file.csv)", | ||||||
|     type=str, |     type=str, | ||||||
|     default="city_distances_full.csv", |     default="city_distances.csv", | ||||||
| ) | ) | ||||||
| parser.add_argument( | parser.add_argument( | ||||||
|     "--shuffle", |     "--shuffle", | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								train.py
									
									
									
									
									
								
							| @ -32,7 +32,7 @@ model = SentenceTransformer(model_name, device="cuda") | |||||||
| #     (fake.city(), fake.city(), np.random.rand()) | #     (fake.city(), fake.city(), np.random.rand()) | ||||||
| #     for _ in range(num_examples) | #     for _ in range(num_examples) | ||||||
| # ] | # ] | ||||||
| data = pd.read_csv("city_distances_full.csv") | data = pd.read_csv("city_distances.csv") | ||||||
| MAX_DISTANCE = 20_037.5  # global max distance | MAX_DISTANCE = 20_037.5  # global max distance | ||||||
| # MAX_DISTANCE = data["distance"].max()  # about 5k | # MAX_DISTANCE = data["distance"].max()  # about 5k | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user