# UnetPointDetection / train.py
import os

import torch
from torch.utils.data import DataLoader, random_split

from dataset import NeedleTipDataset
from train_val import train_one_epoch, validate
from unet import ResNet18UNet


def main():
    """Train ResNet18UNet on the needle-tip dataset, keeping the best checkpoint.

    Builds a ``NeedleTipDataset`` from the hard-coded image directory and XML
    annotation file, splits it 80/20 into train/validation subsets with a
    fixed seed, and trains with AdamW for a fixed number of epochs. After
    every epoch the model is validated; whenever the validation error (the
    value printed as ``val err(px)``) improves, a checkpoint is written to
    ``checkpoints/best.pt``.

    The checkpoint is a dict with the keys ``"model"`` (state dict),
    ``"epoch"`` and ``"val_err"``.

    Raises:
        ValueError: If the dataset is too small to produce non-empty
            training and validation subsets.
    """
    images_dir = "../dataset/images"
    xml_path = "../dataset/annotations.xml"
    val_log_dir = "val_log"
    checkpoint_dir = "checkpoints"
    train_fraction = 0.8
    batch_size = 16
    num_workers = 4
    num_epochs = 100

    os.makedirs(val_log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("using device:", device)

    ds = NeedleTipDataset(images_dir, xml_path, img_size=(384, 384), sigma=8.0)
    n = len(ds)
    n_train = int(n * train_fraction)
    n_val = n - n_train
    if n_train == 0 or n_val == 0:
        # Fail early with a readable message instead of an opaque sampler
        # error raised from inside DataLoader.
        raise ValueError(
            f"dataset has {n} sample(s); need at least 2 for an 80/20 train/val split"
        )

    # Fixed seed so the train/val partition is identical across runs.
    train_ds, val_ds = random_split(
        ds, [n_train, n_val], generator=torch.Generator().manual_seed(42)
    )

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine is useless and makes DataLoader emit a warning.
    pin_memory = device == "cuda"
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    model = ResNet18UNet(out_channels=1, pretrained=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    best_path = os.path.join(checkpoint_dir, "best.pt")
    best = float("inf")
    for epoch in range(1, num_epochs + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer, device)
        va_loss, va_err = validate(model, val_loader, device, epoch, val_log_dir)

        print(
            f"epoch {epoch:03d} | train loss {tr_loss:.6f} | val loss {va_loss:.6f} | val err(px) {va_err:.2f}"
        )

        # Model selection is driven by the validation error, not the loss.
        if va_err < best:
            best = va_err
            torch.save(
                {
                    "model": model.state_dict(),
                    # Extra metadata so the checkpoint is self-describing;
                    # the "model" key is unchanged for existing loaders.
                    "epoch": epoch,
                    "val_err": float(va_err),
                },
                best_path,
            )
            print("  saved best.pt")


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()