Browse Source

align plotting code with defaults

plotting-unify
Michael Pilosov, PhD 9 months ago
parent
commit
6321157df1
  1. 9
      check.py
  2. 5
      newsearch.py
  3. 4
      scripts/sortcolor.py

9
check.py

@ -61,7 +61,13 @@ def create_circle(
def plot_preds( def plot_preds(
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 300, figsize=(6, 6) preds,
rgb_values,
fname: str,
roll: bool = False,
inner_radius=1 / 3,
dpi: int = 300,
figsize=(6, 6),
): ):
if isinstance(preds, torch.Tensor): if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy() preds = preds.detach().cpu().numpy()
@ -111,7 +117,6 @@ def plot_preds(
ax.set_ylim(-radius, radius) ax.set_ylim(-radius, radius)
# Overlay white circle # Overlay white circle
inner_radius = 1 / 3
circle = patches.Circle( circle = patches.Circle(
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2 (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
) )

5
newsearch.py

@ -8,7 +8,7 @@ from lightning_sdk import Machine, Studio # noqa: F401
# consistency of randomly sampled experiments. # consistency of randomly sampled experiments.
seed(19920921) seed(19920921)
NUM_JOBS = 50 NUM_JOBS = 33
# reference to the current studio # reference to the current studio
# if you run outside of Lightning, you can pass the Studio name # if you run outside of Lightning, you can pass the Studio name
@ -38,7 +38,7 @@ seeds = list(range(21, 1992))
optimizers = [ optimizers = [
# "Adagrad", # "Adagrad",
"Adam", "Adam",
# "SGD", "SGD",
# "AdamW", # "AdamW",
# "LBFGS", # "LBFGS",
# "RAdam", # "RAdam",
@ -81,6 +81,7 @@ python newmain.py fit \
--model.bias true \ --model.bias true \
--model.loop true \ --model.loop true \
--model.transform tanh \ --model.transform tanh \
--model.dropout 0 \
--trainer.min_epochs 10 \ --trainer.min_epochs 10 \
--trainer.max_epochs {me} \ --trainer.max_epochs {me} \
--trainer.log_every_n_steps 3 \ --trainer.log_every_n_steps 3 \

4
scripts/sortcolor.py

@ -215,9 +215,9 @@ def plot_preds(
rgb_values, rgb_values,
fname: str, fname: str,
roll: bool = False, roll: bool = False,
dpi: int = 150,
inner_radius: float = 1 / 3, inner_radius: float = 1 / 3,
figsize=(3, 3), dpi: int = 300,
figsize=(6, 6),
): ):
sorted_inds = np.argsort(preds.ravel()) sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :3] colors = rgb_values[sorted_inds, :3]

Loading…
Cancel
Save