From 6321157df184f42e2e45970bb1bcbfc9d598ea24 Mon Sep 17 00:00:00 2001 From: "Michael Pilosov, PhD" Date: Sat, 2 Mar 2024 23:49:41 +0000 Subject: [PATCH] align plotting code with defaults --- check.py | 9 +++++++-- newsearch.py | 5 +++-- scripts/sortcolor.py | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/check.py b/check.py index ff8276b..bd253f3 100644 --- a/check.py +++ b/check.py @@ -61,7 +61,13 @@ def create_circle( def plot_preds( - preds, rgb_values, fname: str, roll: bool = False, dpi: int = 300, figsize=(6, 6) + preds, + rgb_values, + fname: str, + roll: bool = False, + inner_radius=1 / 3, + dpi: int = 300, + figsize=(6, 6), ): if isinstance(preds, torch.Tensor): preds = preds.detach().cpu().numpy() @@ -111,7 +117,6 @@ def plot_preds( ax.set_ylim(-radius, radius) # Overlay white circle - inner_radius = 1 / 3 circle = patches.Circle( (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2 ) diff --git a/newsearch.py b/newsearch.py index 5c90546..757ec5a 100644 --- a/newsearch.py +++ b/newsearch.py @@ -8,7 +8,7 @@ from lightning_sdk import Machine, Studio # noqa: F401 # consistency of randomly sampled experiments. seed(19920921) -NUM_JOBS = 50 +NUM_JOBS = 33 # reference to the current studio # if you run outside of Lightning, you can pass the Studio name @@ -38,7 +38,7 @@ seeds = list(range(21, 1992)) optimizers = [ # "Adagrad", "Adam", - # "SGD", + "SGD", # "AdamW", # "LBFGS", # "RAdam", @@ -81,6 +81,7 @@ python newmain.py fit \ --model.bias true \ --model.loop true \ --model.transform tanh \ +--model.dropout 0 \ --trainer.min_epochs 10 \ --trainer.max_epochs {me} \ --trainer.log_every_n_steps 3 \ diff --git a/scripts/sortcolor.py b/scripts/sortcolor.py index 5aa4085..5d2ebfa 100644 --- a/scripts/sortcolor.py +++ b/scripts/sortcolor.py @@ -215,9 +215,9 @@ def plot_preds( rgb_values, fname: str, roll: bool = False, - dpi: int = 150, inner_radius: float = 1 / 3, - figsize=(3, 3), + dpi: int = 300, + figsize=(6, 6), ): sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds, :3]