from pathlib import Path
from typing import List
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
def heatmap_to_xy(hm: torch.Tensor) -> torch.Tensor:
    """Return the (x, y) location of the maximum of each heatmap in a batch.

    Args:
        hm: tensor of shape (B, 1, H, W), e.g. after sigmoid. It does not need
            to be contiguous. Assumes a single channel: with C > 1 the flat
            index spans channels, and the decoded y would exceed H.

    Returns:
        Float tensor of shape (B, 2) holding ``[x, y]`` (column, row) in
        heatmap pixel coordinates. On ties, ``torch.argmax`` picks the first
        maximal value in row-major order, so an all-equal map yields (0, 0).
    """
    B, _, H, W = hm.shape
    # reshape (not view): view() raises on non-contiguous inputs such as
    # permuted network outputs; reshape copies only when it has to.
    idx = torch.argmax(hm.reshape(B, -1), dim=1)  # (B,), flat index = y * W + x
    y = (idx // W).float()
    x = (idx % W).float()
    return torch.stack([x, y], dim=1)
# for debug: save heatmaps as images
def write_heatmaps_as_image(heatmaps: List[torch.Tensor], titles: List[str], path: str):
    """Save heatmaps side by side as one grayscale figure (debug helper).

    Only ``hm[0, 0]`` (first sample, first channel) of each tensor is drawn.
    The color scale is fixed to [0, 1] (``vmin``/``vmax``), so values outside
    that range saturate.

    Args:
        heatmaps: tensors indexable as ``hm[0, 0]`` -> 2-D map. They may live
            on any device, require grad, or be half/bfloat16 precision.
        titles: one title per heatmap; ``titles[i]`` labels ``heatmaps[i]``.
        path: output location; anything ``Figure.savefig`` accepts
            (str or path-like).

    Does nothing when ``heatmaps`` is empty.
    """
    n = len(heatmaps)
    if n == 0:
        # Nothing to draw; plt.subplots(1, 0) would raise.
        return
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), tight_layout=True)
    try:
        if n == 1:
            # subplots returns a bare Axes for a single column.
            axes = [axes]
        for i, hm in enumerate(heatmaps):
            # detach(): .numpy() refuses tensors that require grad.
            # float(): numpy has no bfloat16.
            img = hm[0, 0].detach().float().cpu().numpy()
            ax = axes[i]
            im = ax.imshow(img, cmap="gray", vmin=0.0, vmax=1.0)
            ax.set_title(titles[i])
            ax.axis("off")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.savefig(path)
    finally:
        # Always release the figure, even if drawing or saving fails,
        # so repeated calls don't accumulate open figures.
        plt.close(fig)
def train_one_epoch(model, loader, optimizer, device, scaler=None):
    """Run one training epoch with mixed precision; return the mean batch loss.

    Args:
        model: network mapping ``batch["image"]`` to a heatmap tensor.
        loader: iterable of dicts with ``"image"`` and ``"heatmap"`` tensors.
            It does not need to define ``__len__`` (generators work).
        optimizer: optimizer over ``model``'s parameters.
        device: target device (``str`` or ``torch.device``). Mixed precision
            is enabled only when its type is ``"cuda"``.
        scaler: optional ``GradScaler`` to reuse across epochs, so the learned
            loss scale persists between epochs. When ``None``, a new scaler is
            created for this epoch.

    Returns:
        Mean of the per-batch MSE losses (a float), or 0.0 if ``loader``
        yields no batches.
    """
    model.train()
    # The cuda GradScaler cannot scale CPU tensors, and cuda autocast does
    # nothing on CPU, so only enable AMP for CUDA devices.
    use_amp = torch.device(device).type == "cuda"
    if scaler is None:
        scaler = GradScaler(enabled=use_amp)
    loss_fn = nn.MSELoss()
    total = 0.0
    n_batches = 0
    for batch in loader:
        img = batch["image"].to(device)
        gt = batch["heatmap"].to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=use_amp):
            pred = model(img)
            # If the output resolution differs from the GT, resize it to match the GT.
            if pred.shape[-2:] != gt.shape[-2:]:
                pred = nn.functional.interpolate(
                    pred, size=gt.shape[-2:], mode="bilinear", align_corners=False
                )
            # NOTE(review): MSE is applied to the raw model output. Confirm
            # whether the model head already ends in a sigmoid, or whether the
            # output is meant to be logits (as a sigmoid-decoding validator
            # would assume).
            loss = loss_fn(pred, gt)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total += loss.item()
        n_batches += 1
    # Count batches ourselves instead of len(loader), so len-less iterables work.
    return total / max(1, n_batches)
@torch.no_grad()
def validate(model, loader, device, epoch, log_dir):
    """Evaluate ``model`` on ``loader``; return mean loss and mean keypoint error.

    For the first batch, a debug figure (raw prediction, its sigmoid, and the
    GT heatmap) is saved to ``log_dir / f"val_map_ep{epoch:03d}.png"``.
    ``log_dir`` is created if it does not exist.

    Args:
        model: network mapping ``batch["image"]`` to a heatmap tensor. Its
            output is passed through ``torch.sigmoid`` before decoding.
        loader: iterable of dicts with ``"image"`` and ``"heatmap"`` tensors.
            It does not need to define ``__len__``.
        device: device the batches are moved to.
        epoch: integer epoch number, used only in the debug-figure filename.
        log_dir: directory (str or path-like) for the debug figure.

    Returns:
        ``(mean_loss, mean_err)``:
        - ``mean_loss`` is the MSE loss averaged over batches.
        - ``mean_err`` is the Euclidean distance between the argmax of
          ``sigmoid(pred)`` and the argmax of the GT heatmap, averaged over
          samples. It is measured in heatmap pixels (the GT-heatmap
          resolution), not original-image pixels.
        Both are 0.0 when ``loader`` yields no batches.
    """
    model.eval()
    loss_fn = nn.MSELoss()
    out_dir = Path(log_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    total_loss = 0.0
    total_err = 0.0
    n_batches = 0
    n_samples = 0
    for batch in loader:
        img = batch["image"].to(device)
        gt = batch["heatmap"].to(device)
        pred = model(img)
        # If the output resolution differs from the GT, resize it to match the GT.
        if pred.shape[-2:] != gt.shape[-2:]:
            pred = nn.functional.interpolate(
                pred, size=gt.shape[-2:], mode="bilinear", align_corners=False
            )
        total_loss += loss_fn(pred, gt).item()
        hm = torch.sigmoid(pred)
        xy_pred = heatmap_to_xy(hm)  # heatmap coords
        if n_batches == 0:
            # Save the debug figure for the first batch only.
            write_heatmaps_as_image(
                [pred, hm, gt],
                ["pred", "heatmap", "gt"],
                path=out_dir / f"val_map_ep{epoch:03d}.png",
            )
        # The annotated GT keypoint could be used directly; here we take the
        # argmax of the GT heatmap instead.
        # NOTE(review): an all-zero GT heatmap decodes to (0, 0). Confirm that
        # samples without a visible keypoint never reach this point.
        xy_gt = heatmap_to_xy(gt)
        # Per-sample L2 distance, measured in the resized (heatmap) pixel space.
        total_err += torch.norm(xy_pred - xy_gt, dim=1).sum().item()
        n_batches += 1
        n_samples += img.size(0)
    # Count batches ourselves instead of len(loader), so len-less iterables work.
    return total_loss / max(1, n_batches), total_err / max(1, n_samples)