From 336531421d2ac4860277555d2b966a1eca9c5b04 Mon Sep 17 00:00:00 2001 From: mm Date: Thu, 4 May 2023 22:41:31 +0000 Subject: [PATCH] prettier readme + full eval + images --- Makefile | 21 +++++++++- README.md | 76 ++++++++++++++++++++++++++++------- eval.py | 4 +- plots/progress_35845_sm.png | Bin 0 -> 261105 bytes plots/progress_680065_sm.png | Bin 0 -> 229625 bytes 5 files changed, 82 insertions(+), 19 deletions(-) create mode 100644 plots/progress_35845_sm.png create mode 100644 plots/progress_680065_sm.png diff --git a/Makefile b/Makefile index d7a8027..a748bc1 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,11 @@ -city_distances.csv: check generate_data.py +all: install data train eval + +city_distances_full.csv: check generate_data.py @echo "Generating distance data..." @bash -c 'time python generate_data.py' +data: city_distances_full.csv + train: check train.py @echo "Training embeddings..." @bash -c 'time python train.py' @@ -23,4 +27,17 @@ check: lint clean: @echo "Removing outputs/ and checkpoints/ directories" @rm -rf output/ - @rm -rf checkpoints/ \ No newline at end of file + @rm -rf checkpoints/ + +compress: plots/progress_35845_sm.png plots/progress_680065_sm.png + +plots/progress_35845_sm.png: plots/progress_35845.png + @convert -resize 33% plots/progress_35845.png plots/progress_35845_sm.png + +plots/progress_680065_sm.png: plots/progress_680065.png + @convert -resize 33% plots/progress_680065.png plots/progress_680065_sm.png + +install: + pip install -r requirements.txt + +.PHONY: data train eval lint check clean all \ No newline at end of file diff --git a/README.md b/README.md index e53b4a7..aaffb52 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,67 @@ -# citybert +# CityBert -1. Generates dataset of cities (US only for now) and their pair-wise geodesic distances. -2. Uses that dataset to fine-tune a neural-net to understand that cities closer to one another are more similar. -3. 
Distances become `labels` through the formula `1 - distance/MAX_DISTANCE`, where `MAX_DISTANCE=20_037.5 # km` represents half of the Earth's circumfrence. +CityBert is a machine learning project that fine-tunes a neural network model to understand the similarity between cities based on their geodesic distances. -There are other factors that can make cities that are "close together" on the globe "far apart" in reality, due to political borders. -Factors like this are not considered in this model, it is only considering geography. +The project generates a dataset of US cities and their pair-wise geodesic distances, which are then used to train the model. -However, for use-cases that involve different measures of distances (perhaps just time-zones, or something that considers the reality of travel), the general principals proven here should be applicable (pick a metric, generate data, train). +The project can be extended to include other distance metrics or additional data, such as airport codes, city aliases, or time zones. -A particularly useful addition to the dataset here: -- airports: they (more/less) have unique codes, and this semantic understanding would be helpful for search engines. -- aliases for cities: the dataset used for city data (lat/lon) contains a pretty exhaustive list of aliases for the cities. It would be good to generate examples of these with a distance of 0 and train the model on this knowledge. -- time-zones: encode difference in hours (relative to worst-possible-case) as labels associated with the time-zone formatted-strings. +> Note that this model only considers geographic distances and does not take into account other factors such as political borders or transportation infrastructure. +These factors contribute to a sense of "distance as it pertains to travel difficulty," which is not directly reflected by this model. -# notes -- see `Makefile` for instructions. 
-- Generating the data took about 13 minutes (for 3269 US cities) on 8-cores (Intel 9700K), yielding 2,720,278 records (combinations of cities). + +## Overview of Project Files + +- `generate_data.py`: Generates a dataset of US cities and their pairwise geodesic distances. +- `train.py`: Trains the neural network model using the generated dataset. +- `eval.py`: Evaluates the trained model by comparing the similarity between city vectors before and after training. +- `Makefile`: Automates the execution of various tasks, such as generating data, training, and evaluation. +- `README.md`: Provides a description of the project, instructions on how to use it, and expected results. +- `requirements.txt`: Lists the Python dependencies used to produce the results. + + +## How to Use + +1. Install the required dependencies by running `pip install -r requirements.txt`. +2. Run `make data` to generate the dataset of city distances (`city_distances_full.csv`). +3. Run `make train` to train the neural network model. +4. Run `make eval` to evaluate the trained model and generate evaluation plots. + +**You can also just run `make` (i.e., `make all`), which will run through all of those steps.** + + +## What to Expect + +After training, the model should be able to understand the similarity between cities based on their geodesic distances. +You can inspect the evaluation plots generated by the `eval.py` script to see the improvement in similarity scores before and after training. 
+ +After five epochs, the model no longer treats the terms as unrelated: +![Evaluation plot](./plots/progress_35845_sm.png) + +After ten epochs, we can see the model has learned to correlate our desired quantities: +![Evaluation plot](./plots/progress_680065_sm.png) + + +*The above plots are examples showing the relationship between geodesic distance and the similarity between the embedded vectors (1 = more similar), for 10,000 randomly selected pairs of US cities (re-sampled for each image).* + +*Note the (vertical) "gap" we see in the image, corresponding to the size of the continental United States (~5,000 km).* + +--- + +## Future Improvements + +There are several potential improvements and extensions to the current model: + +1. **Incorporate airport codes**: Train the model to understand the unique codes of airports, which could be useful for search engines and other applications. +2. **Add city aliases**: Enhance the dataset with city aliases, so the model can recognize different names for the same city. The `geonamescache` package already includes these. +3. **Include time zones**: Train the model to understand time zone differences between cities, which could be helpful for various time-sensitive use cases. The `geonamescache` package already includes this data, but how to compute the hour difference between zones is an open question. +4. **Expand to other distance metrics**: Adapt the model to consider other measures of distance, such as transportation infrastructure or travel time. +5. **Train on sentences**: Improve the model's performance on sentences by adding training and validation examples that involve city names in the context of sentences. Generative AI could be used to write template sentences (mad-libs style) that are then filled in to produce random, diverse training examples. +6. **Global city support**: Extend the model to support cities outside the US and cover a broader range of geographic locations. 
+ + +# Notes +- Generating the data took about 13 minutes (for 3269 US cities) on 8-cores (Intel 9700K), yielding 2,720,278 records (combinations of cities). - Training on an Nvidia 3090 FE takes about an hour per epoch with an 80/20 test/train split. Batch size is 16, so there were 136,014 steps per epoch -- **TODO**`**: Need to add training / validation examples that involve city names in the context of sentences. _It is unclear how the model performs on sentences, as it was trained only on word-pairs. \ No newline at end of file +- Evaluation on the above hardware took about 15 minutes for 20 epochs at 10k samples each. +- **WARNING**: _It is unclear how the model performs on sentences, as it was trained and evaluated only on word-pairs._ See improvement (5) above. diff --git a/eval.py b/eval.py index d2dc2dc..2403ddf 100644 --- a/eval.py +++ b/eval.py @@ -58,7 +58,7 @@ def make_plot(data): ) ax.set_xlabel("distance between cities (km)") ax.set_ylabel("similarity between vectors\n(cosine)") - fig.legend(loc="upper right") + ax.legend(loc="center right") return fig @@ -69,7 +69,7 @@ if __name__ == "__main__": data = pd.read_csv("city_distances_full.csv") # data_sample = data.sample(1_000) checkpoint_dir = "checkpoints_absmax_split" # no slash - for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*"))[14::]: + for checkpoint in sorted(glob.glob(f"{checkpoint_dir}/*")): print(f"Evaluating {checkpoint}") data_sample = data.sample(1_000) trained_model = SentenceTransformer(checkpoint, device="cuda") diff --git a/plots/progress_35845_sm.png b/plots/progress_35845_sm.png new file mode 100644 index 0000000000000000000000000000000000000000..ac9144be81c341b9b2070b932c6ab125150f9812 GIT binary patch literal 261105 zcmeFZg;!fa*FH)q6et#oyOvO#;+jH(6nB>b#T|-k(Lm7B7K%&p;-v&HUMMaN?hxDp z1m~vTz3=b6_iwo0$~u#k$z;x~Gkec|_TJA_>A zO;)fGyj*WjyJGiY^}r&jWUI=1aZf)G)bY;7Lf>qQ&TB}Z->(nZ?&iWzMbJfM9WT)L ze>YxxWy@!e{$Cx|BhkH5|NY*&4-4x5|BU_D8;?hU{69aoj;AK&-lG2RXhC+vA;JHS 
zwpFc3xBKttq4c)`qyN2Cxb<*N^TOI|3(T{xi5Qo}BUW7tUJSQScG)FMR8&Yub4@ zLAFc%{|v|?tS(CUUn}_kwSozmG8lpxAsviu=6p{gWTzY(>_sxQDyDhQC4GE-*I5~_ zlv=O0G6kpz)(#h2PA(cdq_?)VT(8d({2`hi<$;B9^A#i*5~<*Ii3-u5;}$UP!dCgd_ER4n8G9>OCcDr(J7 z)w8x{6?9=2>T`S0ao4#1r{nfd;jHu0`}zejMMaV}^PW}9p0@z43V9)#D?x#jhv*QH z#rzSX1IT__;`&g5UX&&OZxEk<(&1NI{@(~L7lUfMnpqG+A5ijv04sSo^nl^pm;VRs_+2 zZt6j_9oH2dG|MOS%Opqsqk#qEe)rA4-vT_`0kv)PwGHI8^{Jpm*EsJ8*`6ugp4*_F zednh>ZRYp)$NZGE=K<}uuDn(r0-?s0<6TQG##7R6(4r6!_IzQX=OQml7txb$P=m2U zOMNS4POlFGLG-*B%ls4Bg>=ND!-$B8AagqU%?Axtuu$fe&LUI57rzXIX*FeSM=nSz zg0-%VCIK=q;0y{jm##RlcZFvSph02Nn;cNl)|b;&;!rX3z=H;RA)1pee7+tPzBV$( zyKTkwllA^_dm&z)o(r`8xCPBO2AfC)UGmEy1wn%CNXZVt`)$U%LbP10J!ZUo71Zw= zq}cj!ca(H=U3C=a+qP$RQgT0)+q_fUq3+^QimBsPyk zs6IV732DPJGoglt5JR&BSfFUAzJboh-uYmGIyMo7PlAF`I-RHo{qt520tir1adSo? zR`)K1z5)Whv8CfD(yQ{rn-cOiY6Kv{v0>G^T{)l)qF%Cq z6F+?Bi&du2)t_Z2yCL(Ybt6b)kLxZZkuh7Hx@5;|8Zp1|E-`a*a`UUOxP#qoJ!bIE zan0F>0O@MsQA3a@8jeZ184k5zqYvH70>0}4znl4GhTN;KxztB~{gPHLEh7B@~O zNYDc^?Lr!TYM*q2ZoAQX=GHSeXB_l^3Jtv-09^q=CpVzGl85)-zklzsxP^ua$XwJW z^%CEW5%(fO?-BjCx3_Ad;|=cb$xgPldNSRL!%({wB-I@`a!VIJB%#fVbxeCGDF3)U zS>w{#dlJtr1nc2Ktl#`R%i$u_(0K3xBr!PpE=81TW9d%z^NTZqXhF~QV_7;7ss?sr#J zdw@(ai%6lYw3ry1{85%b&zGZiE|pn26?znNF)<_u7srUPL{9y2;^Goy3!~I;nV$`x z55%Z8ymP+3IPgWN_dxlIYkV0}p!dHT6yLF7nNoJpk-|%xsRA-Cj?U{`<_@Gr4-Pz! z_K%1e`-fC}PCcJ4a`%-K$Avk&631hkm|Ovap&m8?5=SZeTo2NYkJqHSUMAXMtrixt zER@cS+$g#P(b0cKdF56FQmmFtON4xH=7<-S0{QsH^joil^aTH`Nk;f(mbwC%qw7v| z+guBK{yCmI+9xL)jgVC(y3Bp z(gjDKmPAcQ7bsL=$UuENaq#x*6N~#lp|KDM-%TAKl<&5SFKA6T^pf~t^>VQKcJMNv z@jCzVv&Gft%Y^02gv)Tg%kWIs94O@MJ60n%B8LWn2#GhN79 zXQB85V6}Pi_8~U*bgch{fM6=3yUWQAIFDu5v$S*^j`iqcvXwVmosxBJC8m1C=K>(q z4y>9*Mn@dJWn*d)2^k)KZzdCB8d=oc?O8d0;9L?LE0tDM1W7Yj)*>#kssEab+{&Dl zF)@*Eotx)^o^~?LTUs)M+yW)hr*%+pN$T=zrHdiekmWh!xOTo)z{%;KD_>J?V8g(w zC5eng_S*IQ;xJ>CxST6EWNF`-#KOH^RqgW_sOrshY_GTWc$!ROed+;nMHWxX&a%Cp zw6-=S^a!Q$`?@V~cEtDa9qrKYCUNlTw>@KW&LmRmq2Af~KYIfW*4zcdnZvF6P^kZ? z`yG{C`C8pmCYUAlq+>&C8Uu(|j$bZ|cclqrXM0*-UX=|w4UJztONFckI=WTITc$N! 
z&MuGMU@)^3ELck@933H?B(hr>Q5V`PIXUfDy;R}l=8|UH?mSd!7Z-!lNq!*Z)~t*G z?bZ#co4dNzF0nr3hxe4tur)8gzJtT9^p|K&Yph8=>iF(MM-crwbauv?RKDL>#QI0& zd!22gLQW1aftyecj+$?mVIGoyi_tmK#d*W)o&T(LhkwY}ut}R}df{Tv*0xwtQquq- z$eQ;t)W0!Ii}}Z6dOW^Z6xv9ftaL{FGb7I}PM@>ln@;VWU#{`puU*zM-qg;_&BdX` z&BC9zM~b4{t=%Y63XCHsX-IZ>@HvLPK%+1vb*QW?>;2!o=UMnbg$@w+#Oj(;!=PN@ zB()C{zn0AjKr7MO^CjiTAN1yd`&>mWVvzaEd_-l#!)rB?5}a~OuD|rTC}nhXcC^3I z|AiyHJeW>($kPntC)ss${O@mWY_uYjrtCMa_;@ybpIJ4xO6~7ke+XV05Qmp9j##@j zLLrR}pQ~3_TN@2129f789aZa5*%3s(~F*yP90 z1Z;dnK$J1~O{D;v_wcmPNyh8Dee29je>z<9e=Yjbs2x-M_BAX1a&dWB^DNDL@F-;S z^R4Kbxi?Fk-#{A##uZ5`#Cy_fNlhk_jUn_Zt@2=n9+DQ1ybw)OiiPb^LQlNgM9TkfFQDBZRn2NY+6&a`kX=B08<1R;%r2zqt5Je3wVOdyy4G0UgSy z{;Nwy;p$YBwFGod-L_~vvcquiDBTjI~%^Ngqn(8Kr<1&-KFDol? z6BLP>iMI^-RUGT;cTBK#>rm0ww#LwUfWL_Rfj@e^8%_meXCAW?M|5`wnM(zbo%}Vi zc^$X+8)+X3>d&gI%#w7MK3ocx!Io|-IhP4$m|rteWL2Zm5dIunEHp-ls-W>oUq1Z~ zAJ!Xo3NXpLGi2S~8uE6zr1qjV912uNyEG@A#3B4>h|A5*-BS{JUBY-X_Yb(#HaF8$ zciacuZn@pto&~hoH3RzICZ~2PCcXS|FR- z2#bon7uF9*NY7gTKC{R}gher`%8t5-PEA2dfR#OqE31l04@RIFul`T+CSo8EN>z-1-Q6wqb8}M)hr;>5rzb|{Q?5aET8(kf6N=I^ zeDW|7LnB;PW2ToU#v>);9IGu!j0m}{D9XA@L%Q{~?p32|1ft{0wmqg-D0dhO8;9&r z%7-Pe?c<(@&(BD%6SMmI-bzIaMYebLrPHTgzr|;F2=LP&K`_m|b19q!%D}*@r5NIR zQ0uj0TuF87M^NtRi%pr;TL5+W6;OOs-zkW4B*V*HH$?544pl3$;jUYr?mwP=st$~5 zy+PIUT?EV?+3S!816%^zZkMK~Z@og=<72sBPM1+?ff*wQ)PWD+V|YkQOXl51W#i!0 zq;C@?)oYvxvWUY)E^KIeeEplcmdyU@{JhG)v-3)e`KGK{J_kNI@UE`o%cz!2H43be z?ChY$HLFA;u_h16oDE&+(UY|&8~4l$xd_s$UEOqRQibfqn3(6w!0OcpTo1je&6@gD zmppupxkLdkOJ)nG;a^t-%HU|HsmXJ6g$xX-($DYUP*jx9v*FE?V_4U!)XzEi(fd}D zA^~mDjiAuN?(u#?ZFPSAGDtZR`aj9pYBB0jGCw?O$y^!EWIb>lkcj8(h(_*o zw5XrbT!jmxoV&+Tu)$<6uV8{FSNncv9Q|w)z5E7>uk1{x)f%+Cc1k`p$sldE{5ZD+ z4h7R%cVvFIMkQRTmkgzp6c0u-A|onhjs>^s{Cvo11baF=WrNNN851Co&CN}i7h7)_=Me-Y{-!??wLm+dzEC&?n=zXmDQA@Ing{X9X9gwu=)MM9EgB2 z#2*S5d5Odb2Udd@U5(x{2atS(Am2K1C22&Zg>>Lb;5yUsLG=jnxcBiyR+}-%HpQvWo3wBw1Ft4f0{$|=RSu}|hhws@iD>J<08Iy%D zrpn3Dl$8A0u(bnrKu6!z-uyM`j?&%^#U0`P+4SY+X7*=;ar?t~q4QBv5F7TAf%zN6 
zQp?h3{I_xohN_cpJ4t6pWo4#g^bEzX0%_S(>{ho+A+l#0ZHbYLo>jQ_XFPH(B)_Jn zi(1a7rDEqawsFwV8opZDcq)vxVAS!cFzH#0*%^++irCn-UQ&M$p|ufnid587ddA_|6JVVd5ies)q?2WQyI$`q32Ve79x!|0T5hz#l( zb2Ft`Oem|xBG_?ci5M$NTbx8{?X~>-6;Z+uW1j%Og-QC>@Lj4zAMbTLI zuA{c2BlNqy@5B9FrfVWI5s|VXlPi}gGvtjfw;r6%|EnR>4zCct4ZEro0Dy8&1xU!j zs3l9T1A1sqU!dU?=3B?|{ol2ImIZy7{?>&-G3rdU-4%dnjxq`_3T=p*>X5hPAa7P& zCU=H7P%i$>m(c~~LdECR z5-zT*J6c+s)hEX%^y>^zC;Dm<+x2dk#*{P&Utev8bQ-0h6B*RXpJK3p$Wu!D2)kiyffkFpn zJfH$16V(6;;rHYz_I5z|(z`oL#nU?+va>oo;&9*xJXLQ7EeyvtZ%P&uSs;( zLM3Tpu!)Ayv%|G|Qq?;qe~tS+gpeTJS;yGp=Yt-4z|l8+++>Y(Ww<|Zr&cY0^6(nY z`FPa+ST42VfCxFm?5EGr0hnNbFbbvO*U)7*S>50%k~`|1@m2hRMPIarLF%_idKPyV zho!4j2eoel)3dIN32kax4@<#|MB4ZpDO&eY2SZ`nAZJd;N9jdHK>6o%`SI8=x+d4= z*;06kIUlX4iHZMF9f{wh%O6-Nwx!>_qcUz!dVO>0C`<8fALTlW^_72aR3Y8(j7EPt zk-@%fv>WMU5h{&fIwZ*>ybrWj4w{+I z0Z3fcXN7&S>ms*8X2}Ag2`#H$UE||g`-g`L;5ct!)UZhAaQZdOs;;AhF{Vd0`vV+3 z|NWrF71zm}Q#C!;*1^HS;(IUML?F(43Q_L+vfl$NdZ{LR4`oGJF_ayJ{MeY3;Y3hE z?_XuCkMaRP#-pN{zUu(>jIV|c0*~Z!01+}Fv#eHPX^d7>a?kn`VFl2)gm0|n^3