diff --git a/utils.py b/utils.py index 3522a76..e0f630c 100644 --- a/utils.py +++ b/utils.py @@ -2,7 +2,7 @@ import matplotlib.colors as mcolors import torch -def preprocess_data(data, skip: bool = False): +def preprocess_data(data, skip: bool = True): # Assuming 'data' is a tensor of shape [n_samples, 3] if not skip: # Compute argmin and argmax for each row