rename data file
This commit is contained in:
		
							parent
							
								
									294d4bb1cd
								
							
						
					
					
						commit
						5c33b5135f
					
				
							
								
								
									
										4
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								Makefile
									
									
									
									
									
								
							| @ -1,10 +1,10 @@ | ||||
| all: install data train eval | ||||
| 
 | ||||
| city_distances_full.csv: check generate_data.py | ||||
| city_distances.csv: check generate_data.py | ||||
| 	@echo "Generating distance data..." | ||||
| 	@bash -c 'time python generate_data.py --country US --workers 8 --chunk-size 4200' | ||||
| 
 | ||||
| data: city_distances_full.csv | ||||
| data: city_distances.csv | ||||
| 
 | ||||
| train: check train.py | ||||
| 	@echo "Training embeddings..." | ||||
|  | ||||
| @ -47,7 +47,7 @@ The approach demonstrated can be extended to other metrics or features beyond ge | ||||
| ## How to Use | ||||
| 
 | ||||
| 1. Install the required dependencies by running `pip install -r requirements.txt`. | ||||
| 2. Run `make city_distances.csv` to generate the dataset of city distances. | ||||
| 2. Run `make data` to generate the dataset of city distances. | ||||
| 3. Run `make train` to train the neural network model. | ||||
| 4. Run `make eval` to evaluate the trained model and generate evaluation plots. | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										2
									
								
								eval.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								eval.py
									
									
									
									
									
								
							| @ -66,7 +66,7 @@ if __name__ == "__main__": | ||||
|     model_name = "sentence-transformers/all-MiniLM-L6-v2" | ||||
|     base_model = SentenceTransformer(model_name, device="cuda") | ||||
| 
 | ||||
|     data = pd.read_csv("city_distances_full.csv") | ||||
|     data = pd.read_csv("city_distances.csv") | ||||
|     # data_sample = data.sample(1_000) | ||||
|     checkpoint_dir = "checkpoints_absmax_split"  # no slash | ||||
|     for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")): | ||||
|  | ||||
| @ -32,7 +32,7 @@ parser.add_argument( | ||||
|     "--output-file", | ||||
|     help="Specify the name of the output file (file.csv)", | ||||
|     type=str, | ||||
|     default="city_distances_full.csv", | ||||
|     default="city_distances.csv", | ||||
| ) | ||||
| parser.add_argument( | ||||
|     "--shuffle", | ||||
|  | ||||
							
								
								
									
										2
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								train.py
									
									
									
									
									
								
							| @ -32,7 +32,7 @@ model = SentenceTransformer(model_name, device="cuda") | ||||
| #     (fake.city(), fake.city(), np.random.rand()) | ||||
| #     for _ in range(num_examples) | ||||
| # ] | ||||
| data = pd.read_csv("city_distances_full.csv") | ||||
| data = pd.read_csv("city_distances.csv") | ||||
| MAX_DISTANCE = 20_037.5  # global max distance | ||||
| # MAX_DISTANCE = data["distance"].max()  # about 5k | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user