# UnetPointDetection / predict.py
import csv
import glob
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from train_val import heatmap_to_xy, write_heatmaps_as_image
from unet import ResNet18UNet


def _load_model(ckpt_path, device):
    """Build a single-channel ResNet18UNet and load weights from ``ckpt_path``.

    The checkpoint is expected to be a dict holding the model state dict under
    the ``"model"`` key. The model is returned on ``device`` in eval mode.
    """
    model = ResNet18UNet(out_channels=1, pretrained=False).to(device)
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model"])
    model.eval()
    return model


def _preprocess(img_bgr, img_size, device):
    """Turn a BGR image from ``cv2.imread`` into a ``(1, 3, H, W)`` RGB tensor.

    The image is resized to ``img_size`` given as ``(W, H)`` (OpenCV order),
    converted BGR -> RGB and divided by 255.
    """
    resized = cv2.resize(img_bgr, img_size, interpolation=cv2.INTER_AREA)
    rgb = resized[:, :, ::-1].astype(np.float32) / 255.0  # BGR -> RGB
    chw = np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)))
    return torch.from_numpy(chw).unsqueeze(0).to(device)


def _peak_confidence(hm, xy):
    """Return the heatmap value at point ``xy = (x, y)``.

    ``hm`` is indexed as ``hm[0, 0, row, col]``. Coordinates are truncated to
    ints and clamped to the map, so a point on or past the border cannot index
    out of range or silently wrap around through a negative index.
    """
    h, w = hm.shape[-2:]
    col = min(max(int(xy[0]), 0), w - 1)
    row = min(max(int(xy[1]), 0), h - 1)
    return float(hm[0, 0, row, col].item())


def _output_path(out_dir, name, tag):
    """Build the output path for image ``name`` and output kind ``tag``.

    The substring "frame" in ``name`` is replaced by ``tag``. Names without
    "frame" get ``tag`` as a prefix instead, so the different output kinds
    never collide on the same file.
    """
    if "frame" in name:
        return os.path.join(out_dir, name.replace("frame", tag))
    return os.path.join(out_dir, f"{tag}_{name}")


@torch.no_grad()
def infer_folder(
    images_dir,
    ckpt_path,
    out_dir="output",
    img_size=(384, 384),
    max_images=100,
    pattern="*.jpg",
):
    """Predict a single point per image for every image in ``images_dir``.

    For each image this writes to ``out_dir``:
      * a heatmap visualisation (name with "frame" -> "hmap"),
      * a copy of the input with the predicted point drawn in red
        (name with "frame" -> "pred").
    It also writes ``pred.csv`` with columns ``name, x, y, conf``. Here
    ``(x, y)`` is in original-image pixels and ``conf`` is the sigmoid
    heatmap value at the predicted point.

    Args:
        images_dir: Folder to search for images.
        ckpt_path: Checkpoint file whose ``"model"`` entry is the state dict.
        out_dir: Output folder; created if it does not exist.
        img_size: Network input size as ``(W, H)``.
        max_images: Process at most this many images (sorted by path);
            ``None`` processes all of them.
        pattern: Glob pattern matched inside ``images_dir``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("using device:", device)

    # cv2.imwrite fails silently (returns False) if the folder is missing.
    os.makedirs(out_dir, exist_ok=True)
    model = _load_model(ckpt_path, device)

    in_w, in_h = img_size
    paths = sorted(glob.glob(os.path.join(images_dir, pattern)))
    if max_images is not None:
        paths = paths[:max_images]
    if not paths:
        print(f"no images matching {pattern!r} in {images_dir}")

    rows = [("name", "x", "y", "conf")]
    for p in paths:
        img0 = cv2.imread(p, cv2.IMREAD_COLOR)
        if img0 is None:  # imread signals unreadable/corrupt files with None
            print(f"skipping unreadable image {p}")
            continue
        h0, w0 = img0.shape[:2]
        name = os.path.basename(p)

        hm = torch.sigmoid(model(_preprocess(img0, img_size, device)))
        # Make sure the heatmap lives in the resized-input pixel grid.
        if hm.shape[-2:] != (in_h, in_w):
            hm = F.interpolate(
                hm,
                size=(in_h, in_w),
                mode="bilinear",
                align_corners=False,
            )
        # NOTE(review): heatmap_to_xy is assumed to return (x, y) per batch
        # item in resized-input pixels — confirm against train_val.
        xy = heatmap_to_xy(hm)[0].cpu().numpy()
        conf = _peak_confidence(hm, xy)

        write_heatmaps_as_image(
            [hm], [f"pred conf={conf:.2f}"], _output_path(out_dir, name, "hmap")
        )

        # Scale back from the resized grid to original-image coordinates.
        sx = w0 / in_w
        sy = h0 / in_h
        x0 = float(xy[0] * sx)
        y0 = float(xy[1] * sy)

        cv2.circle(img0, (int(x0), int(y0)), 5, (0, 0, 255), -1)
        pred_path = _output_path(out_dir, name, "pred")
        if not cv2.imwrite(pred_path, img0):
            print(f"failed to write {pred_path}")
        rows.append((name, x0, y0, conf))
        print(f"processed {p}, conf={conf:.2f}")

    out_csv = os.path.join(out_dir, "pred.csv")
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows)

    print("saved", out_csv)


def main(argv=None):
    """Run folder inference, optionally configured from the command line.

    Args:
        argv: Argument list to parse (e.g. ``sys.argv[1:]``). ``None`` uses the
            built-in defaults without reading ``sys.argv``, so a programmatic
            ``main()`` call behaves exactly as before.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Predict one point per image with a ResNet18UNet checkpoint."
    )
    parser.add_argument(
        "--images-dir", default="../dataset/images", help="folder of input images"
    )
    parser.add_argument(
        "--ckpt", default="checkpoints/best.pt", help="checkpoint file to load"
    )
    parser.add_argument(
        "--out-dir", default="predict_out", help="folder for predictions"
    )
    parser.add_argument(
        "--img-size",
        type=int,
        nargs=2,
        default=(384, 384),
        metavar=("W", "H"),
        help="network input size (width height)",
    )
    args = parser.parse_args([] if argv is None else argv)

    os.makedirs(args.out_dir, exist_ok=True)
    infer_folder(
        args.images_dir,
        args.ckpt,
        out_dir=args.out_dir,
        img_size=tuple(args.img_size),
    )


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])