switch to unsupervised
This commit is contained in:
parent
b9d334e49a
commit
721993d9e5
86
hsv1.txt
Normal file
86
hsv1.txt
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# lightning.pytorch==2.1.3
|
||||||
|
seed_everything: 1387
|
||||||
|
trainer:
|
||||||
|
accelerator: auto
|
||||||
|
strategy: auto
|
||||||
|
devices: auto
|
||||||
|
num_nodes: 1
|
||||||
|
precision: null
|
||||||
|
logger: null
|
||||||
|
callbacks:
|
||||||
|
- class_path: callbacks.SaveImageCallback
|
||||||
|
init_args:
|
||||||
|
save_interval: 0
|
||||||
|
final_dir: out
|
||||||
|
fast_dev_run: false
|
||||||
|
max_epochs: 10
|
||||||
|
min_epochs: 10
|
||||||
|
max_steps: -1
|
||||||
|
min_steps: null
|
||||||
|
max_time: null
|
||||||
|
limit_train_batches: null
|
||||||
|
limit_val_batches: 50
|
||||||
|
limit_test_batches: null
|
||||||
|
limit_predict_batches: null
|
||||||
|
overfit_batches: 0.0
|
||||||
|
val_check_interval: null
|
||||||
|
check_val_every_n_epoch: 1
|
||||||
|
num_sanity_val_steps: null
|
||||||
|
log_every_n_steps: 3
|
||||||
|
enable_checkpointing: null
|
||||||
|
enable_progress_bar: null
|
||||||
|
enable_model_summary: null
|
||||||
|
accumulate_grad_batches: 1
|
||||||
|
gradient_clip_val: null
|
||||||
|
gradient_clip_algorithm: null
|
||||||
|
deterministic: null
|
||||||
|
benchmark: null
|
||||||
|
inference_mode: true
|
||||||
|
use_distributed_sampler: true
|
||||||
|
profiler: null
|
||||||
|
detect_anomaly: false
|
||||||
|
barebones: false
|
||||||
|
plugins: null
|
||||||
|
sync_batchnorm: false
|
||||||
|
reload_dataloaders_every_n_epochs: 0
|
||||||
|
default_root_dir: null
|
||||||
|
model:
|
||||||
|
transform: tanh
|
||||||
|
width: 128
|
||||||
|
depth: 4
|
||||||
|
bias: true
|
||||||
|
alpha: 0.0
|
||||||
|
data:
|
||||||
|
val_size: 10000
|
||||||
|
train_size: 10000
|
||||||
|
batch_size: 256
|
||||||
|
num_workers: 3
|
||||||
|
ckpt_path: null
|
||||||
|
optimizer:
|
||||||
|
class_path: torch.optim.Adam
|
||||||
|
init_args:
|
||||||
|
lr: 0.001
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.999
|
||||||
|
eps: 1.0e-08
|
||||||
|
weight_decay: 0.0
|
||||||
|
amsgrad: false
|
||||||
|
foreach: null
|
||||||
|
maximize: false
|
||||||
|
capturable: false
|
||||||
|
differentiable: false
|
||||||
|
fused: null
|
||||||
|
lr_scheduler:
|
||||||
|
class_path: lightning.pytorch.cli.ReduceLROnPlateau
|
||||||
|
init_args:
|
||||||
|
monitor: hp_metric
|
||||||
|
mode: min
|
||||||
|
factor: 0.05
|
||||||
|
patience: 5
|
||||||
|
threshold: 0.0001
|
||||||
|
threshold_mode: rel
|
||||||
|
cooldown: 10
|
||||||
|
min_lr: 0.0
|
||||||
|
eps: 1.0e-08
|
||||||
|
verbose: true
|
86
hsv2.txt
Normal file
86
hsv2.txt
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# lightning.pytorch==2.1.3
|
||||||
|
seed_everything: 31
|
||||||
|
trainer:
|
||||||
|
accelerator: auto
|
||||||
|
strategy: auto
|
||||||
|
devices: auto
|
||||||
|
num_nodes: 1
|
||||||
|
precision: null
|
||||||
|
logger: null
|
||||||
|
callbacks:
|
||||||
|
- class_path: callbacks.SaveImageCallback
|
||||||
|
init_args:
|
||||||
|
save_interval: 0
|
||||||
|
final_dir: out
|
||||||
|
fast_dev_run: false
|
||||||
|
max_epochs: 10
|
||||||
|
min_epochs: 10
|
||||||
|
max_steps: -1
|
||||||
|
min_steps: null
|
||||||
|
max_time: null
|
||||||
|
limit_train_batches: null
|
||||||
|
limit_val_batches: 50
|
||||||
|
limit_test_batches: null
|
||||||
|
limit_predict_batches: null
|
||||||
|
overfit_batches: 0.0
|
||||||
|
val_check_interval: null
|
||||||
|
check_val_every_n_epoch: 1
|
||||||
|
num_sanity_val_steps: null
|
||||||
|
log_every_n_steps: 3
|
||||||
|
enable_checkpointing: null
|
||||||
|
enable_progress_bar: null
|
||||||
|
enable_model_summary: null
|
||||||
|
accumulate_grad_batches: 1
|
||||||
|
gradient_clip_val: null
|
||||||
|
gradient_clip_algorithm: null
|
||||||
|
deterministic: null
|
||||||
|
benchmark: null
|
||||||
|
inference_mode: true
|
||||||
|
use_distributed_sampler: true
|
||||||
|
profiler: null
|
||||||
|
detect_anomaly: false
|
||||||
|
barebones: false
|
||||||
|
plugins: null
|
||||||
|
sync_batchnorm: false
|
||||||
|
reload_dataloaders_every_n_epochs: 0
|
||||||
|
default_root_dir: null
|
||||||
|
model:
|
||||||
|
transform: tanh
|
||||||
|
width: 256
|
||||||
|
depth: 8
|
||||||
|
bias: true
|
||||||
|
alpha: 0.0
|
||||||
|
data:
|
||||||
|
val_size: 10000
|
||||||
|
train_size: 10000
|
||||||
|
batch_size: 256
|
||||||
|
num_workers: 3
|
||||||
|
ckpt_path: null
|
||||||
|
optimizer:
|
||||||
|
class_path: torch.optim.AdamW
|
||||||
|
init_args:
|
||||||
|
lr: 0.001
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.999
|
||||||
|
eps: 1.0e-08
|
||||||
|
weight_decay: 0.01
|
||||||
|
amsgrad: false
|
||||||
|
maximize: false
|
||||||
|
foreach: null
|
||||||
|
capturable: false
|
||||||
|
differentiable: false
|
||||||
|
fused: null
|
||||||
|
lr_scheduler:
|
||||||
|
class_path: lightning.pytorch.cli.ReduceLROnPlateau
|
||||||
|
init_args:
|
||||||
|
monitor: hp_metric
|
||||||
|
mode: min
|
||||||
|
factor: 0.05
|
||||||
|
patience: 5
|
||||||
|
threshold: 0.0001
|
||||||
|
threshold_mode: rel
|
||||||
|
cooldown: 10
|
||||||
|
min_lr: 0.0
|
||||||
|
eps: 1.0e-08
|
||||||
|
verbose: true
|
24
model.py
24
model.py
@ -56,23 +56,29 @@ class ColorTransformerModel(L.LightningModule):
|
|||||||
# torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
|
# torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
|
||||||
# ).mean()
|
# ).mean()
|
||||||
distance = torch.norm(outputs - labels).mean()
|
distance = torch.norm(outputs - labels).mean()
|
||||||
loss = p_loss
|
|
||||||
|
|
||||||
self.log("train_loss", distance)
|
# Backprop with this:
|
||||||
|
loss = p_loss
|
||||||
|
# p_loss is unsupervised
|
||||||
|
# distance is supervised.
|
||||||
self.log("hp_metric", loss)
|
self.log("hp_metric", loss)
|
||||||
self.log("p_loss", p_loss)
|
|
||||||
return distance
|
# Log all losses individually
|
||||||
|
self.log("train_mse", distance)
|
||||||
|
self.log("train_pres", p_loss)
|
||||||
|
return loss
|
||||||
|
|
||||||
def validation_step(self, batch):
|
def validation_step(self, batch):
|
||||||
inputs, labels = batch # these are true HSV labels - no learning allowed.
|
inputs, labels = batch # these are true HSV labels - no learning allowed.
|
||||||
outputs = self.forward(inputs)
|
outputs = self.forward(inputs)
|
||||||
distance = torch.minimum(
|
# distance = torch.minimum(
|
||||||
torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
|
# torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
|
||||||
)
|
# )
|
||||||
|
distance = torch.norm(outputs - labels)
|
||||||
mean_loss = torch.mean(distance)
|
mean_loss = torch.mean(distance)
|
||||||
max_loss = torch.max(distance)
|
max_loss = torch.max(distance)
|
||||||
self.log("val_mean_loss", mean_loss)
|
self.log("val_mse", mean_loss)
|
||||||
self.log("val_max_loss", max_loss)
|
self.log("val_max", max_loss)
|
||||||
return mean_loss
|
return mean_loss
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user