|
@ -2,7 +2,7 @@ import matplotlib.colors as mcolors |
|
|
import torch |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_data(data, skip: bool = False): |
|
|
def preprocess_data(data, skip: bool = True): |
|
|
# Assuming 'data' is a tensor of shape [n_samples, 3] |
|
|
# Assuming 'data' is a tensor of shape [n_samples, 3] |
|
|
if not skip: |
|
|
if not skip: |
|
|
# Compute argmin and argmax for each row |
|
|
# Compute argmin and argmax for each row |
|
|