import os
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
def load_cvat_points(xml_path: str, label_name: str = "tip"):
    """Load one point per image for a given label from a CVAT XML export.

    Every direct ``<image>`` child of the XML root is scanned. Within an image,
    the first ``<points>`` shape whose ``label`` attribute equals ``label_name``
    is used. If that shape stores several points (``"x1,y1;x2,y2;..."``), only
    the first point is kept. Images without a matching shape are omitted.

    Args:
        xml_path: Path to the CVAT annotation XML file.
        label_name: Label of the point shape to extract.

    Returns:
        dict[str, tuple[float, float]] mapping each image's ``name`` attribute
        to its ``(x, y)`` coordinates, e.g.
        ``{"frame_000004.jpg": (143.46, 451.91)}``.

    Raises:
        KeyError: if an ``<image>`` has no ``name`` attribute or a matching
            shape has no ``points`` attribute.
        ValueError: if a coordinate cannot be parsed as a float.
    """
    root = ET.parse(xml_path).getroot()
    ann = {}
    for img in root.findall("image"):
        name = img.attrib["name"]
        # An image may contain several <points> shapes, so scan all of them.
        for p in img.findall("points"):
            if p.attrib.get("label") != label_name:
                continue
            # CVAT separates points with ";" and x/y with ",". Splitting on ","
            # alone would turn "x1,y1;x2,y2" into an unparsable "y1;x2", so
            # isolate the first point before splitting its coordinates.
            first_point = p.attrib["points"].split(";", 1)[0]
            xy = first_point.split(",")
            ann[name] = (float(xy[0]), float(xy[1]))
            break
    return ann
def make_gaussian_heatmap(h, w, x, y, sigma=3.0):
    """Render an un-normalized 2-D Gaussian centred at (x, y).

    Args:
        h: Output height in pixels.
        w: Output width in pixels.
        x: Centre column in pixel coordinates (may be fractional).
        y: Centre row in pixel coordinates (may be fractional).
        sigma: Standard deviation in pixels; must be positive.

    Returns:
        float32 array of shape (h, w) with values in [0, 1]; the value is
        exactly 1.0 at pixel (row=y, col=x) when the centre falls on a pixel.

    Raises:
        ValueError: if ``sigma`` is not positive (a zero sigma would otherwise
            produce inf/NaN values through division by zero).
    """
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")
    cols = np.arange(w)
    rows = np.arange(h)[:, None]
    # Broadcasting a (w,) row against an (h, 1) column yields the same (h, w)
    # values as meshgrid without materializing two full index grids.
    dist2 = (cols - x) ** 2 + (rows - y) ** 2
    return np.exp(-dist2 / (2 * sigma**2)).astype(np.float32)
class NeedleTipDataset(Dataset):
    """Needle-tip keypoint dataset built from a CVAT XML export.

    Each sample is an image resized to ``img_size`` together with a Gaussian
    heatmap centred on the annotated tip, expressed in the resized pixel grid.

    Args:
        images_dir: Directory the CVAT image names are joined onto.
        xml_path: Path to the CVAT annotation XML file.
        img_size: Output size as ``(width, height)``, the order ``cv2.resize``
            expects.
        sigma: Standard deviation of the target Gaussian, in resized pixels.
        label_name: CVAT label of the point shape that marks the tip.
    """

    def __init__(
        self, images_dir, xml_path, img_size=(384, 384), sigma=3.0, label_name="tip"
    ):
        self.images_dir = images_dir
        self.img_w = img_size[0]
        self.img_h = img_size[1]
        self.sigma = sigma
        self.ann = load_cvat_points(xml_path, label_name=label_name)
        # Only images that carry a tip annotation are used for training
        # (a starting point; unannotated frames are skipped for now).
        # Sorting keeps sample indices deterministic.
        self.names = sorted(self.ann)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Load, resize and encode the sample at ``idx``.

        Returns:
            dict with
              - ``"image"``: float32 tensor (3, H, W), RGB scaled to [0, 1];
              - ``"heatmap"``: float32 tensor (1, H, W), Gaussian target;
              - ``"name"``: the image name from the annotation file;
              - ``"scale"``: float32 tensor ``[sx, sy]`` = resized / original
                size, so original coordinates are ``(x / sx, y / sy)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        name = self.names[idx]
        path = os.path.join(self.images_dir, name)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(path)
        h0, w0 = img.shape[:2]
        x0, y0 = self.ann[name]  # tip coordinates in the original image

        # Resize, then map the tip into the resized pixel grid.
        img = cv2.resize(img, (self.img_w, self.img_h), interpolation=cv2.INTER_AREA)
        sx = self.img_w / w0
        sy = self.img_h / h0
        x, y = x0 * sx, y0 * sy

        hm = make_gaussian_heatmap(self.img_h, self.img_w, x, y, sigma=self.sigma)  # (H, W)

        # BGR (OpenCV order) -> RGB, scale to [0, 1], HWC -> contiguous CHW.
        img = img[:, :, ::-1].astype(np.float32) / 255.0
        img = np.ascontiguousarray(img.transpose(2, 0, 1))  # (C, H, W)
        return {
            "image": torch.from_numpy(img),  # float32, (3, H, W)
            "heatmap": torch.from_numpy(hm)[None, ...],  # float32, (1, H, W)
            "name": name,
            "scale": torch.tensor([sx, sy], dtype=torch.float32),
        }