|
@ -61,7 +61,13 @@ def create_circle( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds( |
|
|
def plot_preds( |
|
|
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 300, figsize=(6, 6) |
|
|
preds, |
|
|
|
|
|
rgb_values, |
|
|
|
|
|
fname: str, |
|
|
|
|
|
roll: bool = False, |
|
|
|
|
|
inner_radius=1 / 3, |
|
|
|
|
|
dpi: int = 300, |
|
|
|
|
|
figsize=(6, 6), |
|
|
): |
|
|
): |
|
|
if isinstance(preds, torch.Tensor): |
|
|
if isinstance(preds, torch.Tensor): |
|
|
preds = preds.detach().cpu().numpy() |
|
|
preds = preds.detach().cpu().numpy() |
|
@ -111,7 +117,6 @@ def plot_preds( |
|
|
ax.set_ylim(-radius, radius) |
|
|
ax.set_ylim(-radius, radius) |
|
|
|
|
|
|
|
|
# Overlay white circle |
|
|
# Overlay white circle |
|
|
inner_radius = 1 / 3 |
|
|
|
|
|
circle = patches.Circle( |
|
|
circle = patches.Circle( |
|
|
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2 |
|
|
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2 |
|
|
) |
|
|
) |
|
|