import os
import torch
from torch.utils.data import DataLoader, random_split
from dataset import NeedleTipDataset
from train_val import train_one_epoch, validate
from unet import ResNet18UNet
def main(
    *,
    images_dir: str = "../dataset/images",
    xml_path: str = "../dataset/annotations.xml",
    val_log_dir: str = "val_log",
    checkpoint_dir: str = "checkpoints",
    epochs: int = 100,
    batch_size: int = 16,
    num_workers: int = 4,
    lr: float = 3e-4,
    weight_decay: float = 1e-4,
    train_fraction: float = 0.8,
    seed: int = 42,
) -> None:
    """Train a ResNet18UNet on NeedleTipDataset and keep the best checkpoint.

    The dataset is split into train/validation subsets with a seeded
    generator, the model is optimized with AdamW for ``epochs`` epochs, and
    the model ``state_dict`` is written to ``<checkpoint_dir>/best.pt`` every
    time the validation error returned by ``validate`` reaches a new minimum.

    All parameters are keyword-only and default to the values this script has
    always used, so calling ``main()`` with no arguments is unchanged.

    Args:
        images_dir: Directory passed to ``NeedleTipDataset`` as the image root.
        xml_path: Annotation XML file passed to ``NeedleTipDataset``.
        val_log_dir: Directory handed to ``validate``; created if missing.
        checkpoint_dir: Directory where ``best.pt`` is written; created if
            missing.
        epochs: Number of training epochs (epochs are numbered from 1).
        batch_size: Batch size for both the training and validation loaders.
        num_workers: Worker process count for both loaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        train_fraction: Fraction of the samples assigned to the training
            split; the remainder is used for validation.
        seed: Seed of the generator used by ``random_split``, which keeps the
            split reproducible across runs.

    Raises:
        ValueError: If the dataset is too small (or ``train_fraction`` too
            extreme) to give both splits at least one sample.
    """
    os.makedirs(val_log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("using device:", device)

    ds = NeedleTipDataset(images_dir, xml_path, img_size=(384, 384), sigma=8.0)
    n = len(ds)
    n_train = int(n * train_fraction)
    n_val = n - n_train
    # Fail early with an actionable message instead of an opaque sampler or
    # validation error once the loaders are built on an empty split.
    if n_train <= 0 or n_val <= 0:
        raise ValueError(
            f"cannot split {n} samples into non-empty train/val subsets "
            f"(train={n_train}, val={n_val})"
        )

    train_ds, val_ds = random_split(
        ds, [n_train, n_val], generator=torch.Generator().manual_seed(seed)
    )

    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    pin_memory = device == "cuda"
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    model = ResNet18UNet(out_channels=1, pretrained=True).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    best_path = os.path.join(checkpoint_dir, "best.pt")
    # Lowest validation error seen so far; infinity so the first epoch always
    # produces a checkpoint (unless the error is NaN).
    best = float("inf")
    for epoch in range(1, epochs + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer, device)
        va_loss, va_err = validate(model, val_loader, device, epoch, val_log_dir)
        print(
            f"epoch {epoch:03d} | train loss {tr_loss:.6f} | val loss {va_loss:.6f} | val err(px) {va_err:.2f}"
        )
        # Model selection is driven by the validation error, not the loss.
        if va_err < best:
            best = va_err
            torch.save({"model": model.state_dict()}, best_path)
            print(" saved best.pt")
# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse main) does not kick off a run.
if __name__ == "__main__":
    main()