import csv
import glob
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from train_val import heatmap_to_xy, write_heatmaps_as_image
from unet import ResNet18UNet
def _output_name(src_path, tag):
    """Derive an output file name for *src_path* tagged with *tag*.

    Keeps the existing convention of replacing ``"frame"`` in the base name
    with *tag* (``frame_0001.jpg`` -> ``hmap_0001.jpg``). If the base name does
    not contain ``"frame"``, *tag* is prefixed instead, so the heatmap and the
    prediction image never collide with each other (or with the input file).
    """
    base = os.path.basename(src_path)
    if "frame" in base:
        return base.replace("frame", tag)
    return f"{tag}_{base}"


@torch.no_grad()
def infer_folder(
    images_dir,
    ckpt_path,
    out_dir="output",
    img_size=(384, 384),
    *,
    max_images=100,
    pattern="*.jpg",
):
    """Run single-channel heatmap inference over a folder of images.

    Each image matching *pattern* in *images_dir* (sorted by path) is resized
    to *img_size*, converted BGR->RGB and scaled to [0, 1], then passed through
    a ``ResNet18UNet``. Its sigmoid heatmap is decoded with ``heatmap_to_xy``.
    For each image this writes a heatmap visualisation (``hmap``-tagged name)
    and a copy of the input with the predicted point drawn as a red dot
    (``pred``-tagged name). All predictions are written to
    ``<out_dir>/pred.csv`` with the header ``name, x, y, conf``.

    Args:
        images_dir: Folder searched (non-recursively) for images.
        ckpt_path: Checkpoint file, either a dict holding the weights under
            ``"model"`` or a bare state dict.
        out_dir: Output folder; created if it does not exist.
        img_size: ``(width, height)`` of the network input (the order
            ``cv2.resize`` expects).
        max_images: Process at most this many images; ``None`` means all.
            The default of 100 matches the previous hard-coded limit.
        pattern: Glob pattern selecting image files inside *images_dir*.
    """
    os.makedirs(out_dir, exist_ok=True)
    in_w, in_h = img_size

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("using device:", device)
    model = ResNet18UNet(out_channels=1, pretrained=False).to(device)
    state = torch.load(ckpt_path, map_location=device)
    # Training checkpoints store the weights under "model"; a bare state dict
    # (e.g. exported weights) is accepted as well.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    model.load_state_dict(state)
    model.eval()

    paths = sorted(glob.glob(os.path.join(images_dir, pattern)))
    if max_images is not None:
        paths = paths[:max_images]
    if not paths:
        print(f"warning: no images matching {pattern!r} in {images_dir}")

    rows = [("name", "x", "y", "conf")]
    for p in paths:
        img0 = cv2.imread(p, cv2.IMREAD_COLOR)
        if img0 is None:
            # cv2.imread returns None (no exception) for unreadable files.
            print(f"warning: could not read {p}, skipping")
            continue
        h0, w0 = img0.shape[:2]

        # BGR uint8 -> RGB float32 in [0, 1], as a 1xCxHxW batch.
        img = cv2.resize(img0, img_size, interpolation=cv2.INTER_AREA)
        img = img[:, :, ::-1].astype(np.float32) / 255.0
        x = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0).to(device)

        hm = torch.sigmoid(model(x))
        if hm.shape[-2:] != (in_h, in_w):
            # Upsample to the input resolution so xy is in resized-image pixels.
            hm = F.interpolate(
                hm,
                size=(in_h, in_w),
                mode="bilinear",
                align_corners=False,
            )

        xy = heatmap_to_xy(hm)[0].cpu().numpy()  # in resized coords
        # Clamp so a point on/past the border cannot index out of range
        # (or wrap around on negative values).
        xi = min(max(int(xy[0]), 0), hm.shape[-1] - 1)
        yi = min(max(int(xy[1]), 0), hm.shape[-2] - 1)
        conf = float(hm[0, 0, yi, xi].item())

        hmap_path = os.path.join(out_dir, _output_name(p, "hmap"))
        write_heatmaps_as_image([hm], [f"pred conf={conf:.2f}"], hmap_path)

        # Back to original image coordinates.
        x0 = float(xy[0] * (w0 / in_w))
        y0 = float(xy[1] * (h0 / in_h))
        cv2.circle(img0, (int(x0), int(y0)), 5, (0, 0, 255), -1)
        pred_path = os.path.join(out_dir, _output_name(p, "pred"))
        if not cv2.imwrite(pred_path, img0):
            # cv2.imwrite reports failure via its return value only.
            print(f"warning: failed to write {pred_path}")

        rows.append((os.path.basename(p), x0, y0, conf))
        print(f"processed {p}, conf={conf:.2f}")

    out_csv = os.path.join(out_dir, "pred.csv")
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows)
    print("saved", out_csv)
def main(argv=None):
    """Command-line entry point: run ``infer_folder`` over an image folder.

    Every option defaults to the value previously hard-coded here, so running
    the script with no arguments behaves exactly as before.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``; useful for
            programmatic use.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Heatmap keypoint inference over a folder of images."
    )
    parser.add_argument(
        "--images-dir", default="../dataset/images", help="folder with input images"
    )
    parser.add_argument(
        "--ckpt", default="checkpoints/best.pt", help="model checkpoint path"
    )
    parser.add_argument(
        "--out-dir", default="predict_out", help="output folder (created if missing)"
    )
    parser.add_argument(
        "--img-size",
        type=int,
        nargs=2,
        default=(384, 384),
        metavar=("W", "H"),
        help="network input width and height",
    )
    args = parser.parse_args(argv)

    os.makedirs(args.out_dir, exist_ok=True)
    infer_folder(
        args.images_dir,
        args.ckpt,
        out_dir=args.out_dir,
        img_size=tuple(args.img_size),
    )
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # never as a side effect of importing this module.
    main()