You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

51 lines
1.4 KiB

import subprocess
import sys
from itertools import product
from random import sample

import numpy as np
from lightning_sdk import Machine, Studio

# Upper bound on how many hyperparameter combinations are sampled and run.
NUM_JOBS = 64

# Reference to the current Studio.
# If you run outside of Lightning, you can pass the Studio name.
studio = Studio()

# Install the jobs plugin and keep a handle to it. The handle is only needed
# when runs are submitted through `job_plugin.run(...)` instead of being
# executed locally.
studio.install_plugin("jobs")
job_plugin = studio.installed_plugins["jobs"]
# Hyperparameter sweep over alpha and learning rate (batch size and max
# epochs are currently fixed to a single value each).

# 21 evenly spaced alpha values from 0 to 2. Rounding to 4 decimals removes
# floating-point noise from linspace; .tolist() converts the NumPy scalars to
# plain Python floats so a parameter tuple prints as plain numbers.
alpha_values = np.round(np.linspace(0, 2, 21), 4).tolist()
learning_rate_values = [1e-3, 1e-4, 1e-5]
batch_size_values = [128]
max_epochs_values = [5000]

# All possible combinations as (alpha, lr, bs, me) tuples. product() varies
# the last iterable fastest, i.e. the same order as nested loops with alpha
# outermost and max_epochs innermost.
all_params = list(
    product(
        alpha_values,
        learning_rate_values,
        batch_size_values,
        max_epochs_values,
    )
)
# Random search with a limit: draw at most NUM_JOBS distinct combinations
# (sampling without replacement). When the grid has fewer than NUM_JOBS
# entries, every combination is run, in random order.
search_params = sample(all_params, min(NUM_JOBS, len(all_params)))

for params in search_params:
    alpha, lr, bs, me = params
    # The command is a single shell string because it relies on `cd ... &&`
    # and `~` expansion, so it must be executed through a shell.
    cmd = (
        "cd ~/colors && python main.py "
        f"--alpha {alpha} --lr {lr} --bs {bs} --max_epochs {me}"
    )

    # Name used when the run is submitted as a Lightning job instead of being
    # executed locally; uncomment the line below to switch modes.
    job_name = f"color2_{bs}_{alpha}_{lr:2.2e}"
    # job_plugin.run(cmd, machine=Machine.T4, name=job_name)

    print(f"Running {params}: {cmd}")
    try:
        # Run the command and wait for it to complete. check=True makes a
        # non-zero exit status raise CalledProcessError, which is not caught
        # here and therefore stops the sweep at the first failing run.
        subprocess.run(cmd, shell=True, check=True)
    except KeyboardInterrupt:
        # Ctrl-C aborts the whole sweep, not just the current run.
        print("Interrupted by user")
        sys.exit(1)