From 05a66f114c4ece6f5a5e9ee862b5964e7b153742 Mon Sep 17 00:00:00 2001 From: mm Date: Sun, 31 Dec 2023 07:04:54 +0000 Subject: [PATCH] prep experiment, fix bugs --- dataloader.py | 1 + search.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/dataloader.py b/dataloader.py index 6ca33af..0f12bc2 100644 --- a/dataloader.py +++ b/dataloader.py @@ -28,6 +28,7 @@ def create_dataloader(N: int = 50, **kwargs): def create_gray_supplement(N: int = 50): linear_space = torch.linspace(0, 1, N) gray_tensor = linear_space.unsqueeze(1).repeat(1, 3) + gray_tensor = preprocess_data(gray_tensor) return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))] diff --git a/search.py b/search.py index bd956f4..f73f8fc 100644 --- a/search.py +++ b/search.py @@ -1,8 +1,11 @@ +import subprocess +import sys from random import sample +import numpy as np from lightning_sdk import Machine, Studio -NUM_JOBS = 21 +NUM_JOBS = 64 # reference to the current studio # if you run outside of Lightning, you can pass the Studio name @@ -15,7 +18,7 @@ job_plugin = studio.installed_plugins["jobs"] # do a sweep over learning rates # Define the ranges or sets of values for each hyperparameter -alpha_values = [0.1, 0.25, 0.5, 0.7, 0.9] +alpha_values = list(np.round(np.linspace(0, 2, 21), 4)) learning_rate_values = [1e-3, 1e-4, 1e-5] batch_size_values = [128] max_epochs_values = [5000] @@ -31,7 +34,7 @@ all_params = [ # perform random search with a limit -search_params = sample(all_params, NUM_JOBS) +search_params = sample(all_params, min(NUM_JOBS, len(all_params))) for idx, params in enumerate(search_params): a, lr, bs, me = params @@ -39,4 +42,9 @@ for idx, params in enumerate(search_params): job_name = f"color2_{bs}_{a}_{lr:2.2e}" # job_plugin.run(cmd, machine=Machine.T4, name=job_name) print(f"Running {params}: {cmd}") - os.system(cmd) + try: + # Run the command and wait for it to complete + subprocess.run(cmd, shell=True, check=True) + except KeyboardInterrupt: + print("Interrupted by user") + sys.exit(1)