import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset


def load_cvat_points(xml_path: str, label_name: str = "tip"):
    """
    Load one (x, y) point per image from a CVAT XML annotation file.

    For every ``<image>`` element, the first ``<points>`` child whose ``label``
    attribute equals *label_name* is used. CVAT stores the vertices of a
    points shape as ``"x1,y1;x2,y2;..."``; only the first ``x,y`` pair is kept.
    Images without a matching ``<points>`` child are omitted from the result.

    Args:
        xml_path: path to the CVAT XML file.
        label_name: label of the points shape to extract.

    Returns:
        dict[str, tuple[float, float]] mapping the image ``name`` attribute to
        the point's coordinates as stored in the XML,
        e.g. {"frame_000004.jpg": (143.46, 451.91)}
    """
    root = ET.parse(xml_path).getroot()

    ann = {}
    for img in root.findall("image"):
        name = img.attrib["name"]
        # An image may carry several <points> shapes, so search all of them.
        for p in img.findall("points"):
            if p.attrib.get("label") != label_name:
                continue
            # Multi-vertex shapes are "x1,y1;x2,y2;..." -- splitting on ","
            # alone would produce a field like "y1;x2". Take the first pair.
            first_pair = p.attrib["points"].split(";", 1)[0]
            xy = first_pair.split(",")
            ann[name] = (float(xy[0]), float(xy[1]))
            break
    return ann


def make_gaussian_heatmap(h, w, x, y, sigma=3.0):
    """
    Render an isotropic 2D Gaussian on an (h, w) grid.

    The cell at row r, column c is treated as pixel coordinate (c, r), so its
    value is exp(-((c - x)^2 + (r - y)^2) / (2 * sigma^2)). The value reaches
    exactly 1.0 only when (x, y) lies on an integer grid point.

    Returns:
        np.ndarray of shape (h, w), dtype float32.
    """
    cols = np.arange(w)[np.newaxis, :]  # (1, w): x coordinate of each column
    rows = np.arange(h)[:, np.newaxis]  # (h, 1): y coordinate of each row
    sq_dist = (cols - x) ** 2 + (rows - y) ** 2  # broadcasts to (h, w)
    denom = 2 * sigma**2
    return np.exp(-sq_dist / denom).astype(np.float32)


class NeedleTipDataset(Dataset):
    """
    Map-style dataset pairing each annotated image with a Gaussian heatmap
    centered on its needle-tip point.

    Only images that have a point labelled ``label_name`` in the CVAT XML are
    included; samples are ordered by image name.
    """

    def __init__(
        self, images_dir, xml_path, img_size=(384, 384), sigma=3.0, label_name="tip"
    ):
        """
        Args:
            images_dir: directory joined with each annotated image name.
            xml_path: CVAT XML annotation file passed to ``load_cvat_points``.
            img_size: output size as (width, height), the same order that
                ``cv2.resize`` expects for its ``dsize`` argument.
            sigma: standard deviation of the target Gaussian, in output pixels.
            label_name: label of the CVAT points shape used as the target.
        """
        self.images_dir = images_dir
        self.img_w = img_size[0]
        self.img_h = img_size[1]
        self.sigma = sigma

        self.ann = load_cvat_points(xml_path, label_name=label_name)

        # Only images that carry a tip annotation are used for training
        # (a starting point; unannotated frames are skipped for now).
        self.names = sorted(self.ann)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.names)

    def __getitem__(self, idx):
        """
        Load, resize and normalize one image and build its target heatmap.

        Returns:
            dict with
              "image":   float32 tensor (3, H, W), RGB, values in [0, 1]
              "heatmap": float32 tensor (1, H, W), Gaussian centered on the
                         tip location scaled to the output size
              "name":    image name (key from the XML)
              "scale":   float32 tensor [sx, sy] = output size / original size;
                         the tip was mapped as (x0 * sx, y0 * sy)

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        name = self.names[idx]
        path = os.path.join(self.images_dir, name)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(path)

        h0, w0 = img.shape[:2]
        x0, y0 = self.ann[name]  # coordinates in the original image

        # Resize to the network input size and map the point with the same
        # per-axis scale factors.
        # NOTE(review): x0 * sx ignores any half-pixel offset between the
        # annotation's coordinate convention and the heatmap's integer pixel
        # grid; confirm against the prediction-decoding code before changing.
        img = cv2.resize(img, (self.img_w, self.img_h), interpolation=cv2.INTER_AREA)
        sx = self.img_w / w0
        sy = self.img_h / h0
        x = x0 * sx
        y = y0 * sy

        hm = make_gaussian_heatmap(self.img_h, self.img_w, x, y, sigma=self.sigma)  # (H, W)

        # BGR -> RGB and uint8 -> float32 in a single allocation. The reversed
        # channel view has negative strides, which torch.from_numpy rejects,
        # so it must be materialized. Then scale to [0, 1] in place.
        img = np.ascontiguousarray(img[:, :, ::-1], dtype=np.float32)
        img /= 255.0
        img = np.transpose(img, (2, 0, 1))  # (C, H, W)

        return {
            "image": torch.from_numpy(img),  # float32, (3,H,W)
            "heatmap": torch.from_numpy(hm)[None, ...],  # float32, (1,H,W)
            "name": name,
            "scale": torch.tensor([sx, sy], dtype=torch.float32),
        }
