from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast


def heatmap_to_xy(hm: torch.Tensor) -> torch.Tensor:
    """Locate the per-sample maximum of a batch of heatmaps.

    Args:
        hm: Tensor of shape (B, 1, H, W), e.g. heatmaps after sigmoid.
            Contiguous and non-contiguous memory layouts are both accepted.
            NOTE: the argmax is taken over the flattened C*H*W values, so with
            C > 1 the returned y can exceed H - 1; callers are expected to pass
            single-channel heatmaps.

    Returns:
        Float tensor of shape (B, 2) holding ``[x, y]`` (column, row) of the
        argmax in heatmap pixel coordinates. When several values are maximal,
        ``torch.argmax`` picks the first one in row-major order.
    """
    B, _, H, W = hm.shape
    # reshape instead of view: view() raises on non-contiguous tensors
    # (e.g. transposed or channels_last inputs); reshape copies only if needed.
    flat = hm.reshape(B, -1)
    idx = torch.argmax(flat, dim=1)  # (B,) flat index into H*W
    # Indices are non-negative, so integer floor division yields the row.
    y = (idx // W).float()
    x = (idx % W).float()
    return torch.stack([x, y], dim=1)


# for debug: save heatmaps as images
def write_heatmaps_as_image(heatmaps: List[torch.Tensor], titles: List[str], path: str):
    """Save the first sample of each heatmap side by side as one debug image.

    Each tensor is indexed as ``hm[0, 0]``, i.e. the first sample and first
    channel of a (B, C, H, W) tensor. Panels are drawn in grayscale with a fixed
    [0, 1] colour range, so values outside that range (e.g. raw logits) are
    clipped in the rendering only.

    Args:
        heatmaps: Tensors to draw, one panel each. Tensors on any device, any
            floating dtype (including bfloat16) and with requires_grad are accepted.
        titles: One title per tensor in ``heatmaps``.
        path: Output file path passed to ``Figure.savefig``.

    Raises:
        ValueError: If ``heatmaps`` is empty or ``titles`` has a different length.
    """
    n = len(heatmaps)
    if n == 0:
        raise ValueError("heatmaps must contain at least one tensor")
    if len(titles) != n:
        raise ValueError(f"expected {n} titles, got {len(titles)}")

    # squeeze=False always returns a 2D array of axes, even for a single panel.
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), tight_layout=True, squeeze=False)
    try:
        for ax, hm, title in zip(axes[0], heatmaps, titles):
            # detach(): .numpy() refuses tensors that require grad.
            # float(): numpy has no bfloat16, so convert before exporting.
            img = hm[0, 0].detach().float().cpu().numpy()
            im = ax.imshow(img, cmap="gray", vmin=0.0, vmax=1.0)
            ax.set_title(title)
            ax.axis("off")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.savefig(path)
    finally:
        # Always release the figure; pyplot keeps open figures alive otherwise.
        plt.close(fig)


def train_one_epoch(model, loader, optimizer, device, scaler=None):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Mixed precision (autocast + gradient scaling) is used only when ``device``
    is a CUDA device; on other devices training runs in full precision.

    Args:
        model: Module mapping ``batch["image"]`` to a heatmap prediction.
        loader: Iterable of dicts with ``"image"`` and ``"heatmap"`` tensors.
            It does not need to implement ``__len__``.
        optimizer: Optimizer over the model's parameters.
        device: Target device (``torch.device``, string such as ``"cuda:0"``,
            or device index).
        scaler: Optional ``GradScaler`` to reuse across epochs so the learned
            loss scale is not reset every epoch. If ``None``, a new scaler is
            created for this call.

    Returns:
        Mean of the per-batch MSE losses, or 0.0 for an empty loader.
    """
    model.train()
    # torch.cuda.amp only applies to CUDA; disable it explicitly elsewhere
    # rather than relying on its "CUDA not available" runtime fallback.
    use_amp = torch.device(device).type == "cuda"
    if scaler is None:
        scaler = GradScaler(enabled=use_amp)
    loss_fn = nn.MSELoss()

    total = 0.0
    n_batches = 0
    for batch in loader:
        img = batch["image"].to(device)
        gt = batch["heatmap"].to(device)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=use_amp):
            pred = model(img)
            # If the prediction resolution differs from the GT, resize it to the GT size.
            if pred.shape[-2:] != gt.shape[-2:]:
                pred = nn.functional.interpolate(
                    pred, size=gt.shape[-2:], mode="bilinear", align_corners=False
                )

            # NOTE(review): the loss is on the raw model output, while validate()
            # applies sigmoid before decoding; confirm this is intended.
            loss = loss_fn(pred, gt)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        n_batches += 1
    # Count batches instead of len(loader): iterable-style loaders have no __len__.
    return total / max(1, n_batches)


@torch.no_grad()
def validate(model, loader, device, epoch, log_dir):
    """Evaluate ``model`` on ``loader`` without gradients.

    Side effect: creates ``log_dir`` if needed and, for the first batch only,
    writes ``val_map_ep{epoch:03d}.png`` showing the raw prediction, its
    sigmoid heatmap, and the ground truth.

    Args:
        model: Module mapping ``batch["image"]`` to a heatmap prediction.
        loader: Iterable of dicts with ``"image"`` and ``"heatmap"`` tensors.
            It does not need to implement ``__len__``.
        device: Device the batches are moved to.
        epoch: Integer epoch number, used only in the debug image file name.
        log_dir: Directory for the debug image.

    Returns:
        ``(mean_loss, mean_err)``. ``mean_loss`` is the average of per-batch
        MSE values, not weighted by batch size. ``mean_err`` is the average
        per-sample Euclidean distance between the argmax of the predicted and
        GT heatmaps, in pixels of the GT heatmap resolution. Both are 0.0
        for an empty loader.
    """
    model.eval()
    loss_fn = nn.MSELoss()
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)

    total_loss = 0.0
    total_err = 0.0
    n_batches = 0
    n_samples = 0

    for batch in loader:
        img = batch["image"].to(device)
        gt = batch["heatmap"].to(device)

        pred = model(img)
        # Same resizing as in training: compare at the GT heatmap resolution.
        if pred.shape[-2:] != gt.shape[-2:]:
            pred = nn.functional.interpolate(
                pred, size=gt.shape[-2:], mode="bilinear", align_corners=False
            )

        loss = loss_fn(pred, gt)
        total_loss += loss.item()

        hm = torch.sigmoid(pred)
        xy_pred = heatmap_to_xy(hm)  # heatmap coords

        # Dump a debug visualization for the first batch only.
        if n_batches == 0:
            out_path = log_path / f"val_map_ep{epoch:03d}.png"
            write_heatmaps_as_image(
                [pred, hm, gt], ["pred", "heatmap", "gt"], path=out_path
            )

        # The GT point could be used directly, but here the argmax of the GT heatmap is used.
        xy_gt = heatmap_to_xy(gt)

        # Sum of per-sample distances, in pixels of the GT heatmap resolution.
        total_err += torch.norm(xy_pred - xy_gt, dim=1).sum().item()
        n_samples += img.size(0)
        n_batches += 1

    # Count batches instead of len(loader): iterable-style loaders have no __len__.
    return total_loss / max(1, n_batches), total_err / max(1, n_samples)
