Newer
Older
UnetPointDetection / README.md

Unet Point Detection

ResNet18 + U-Net 風デコーダで、画像中の目標位置をヒートマップ回帰するプロジェクトです。
学習時は CVAT の annotations.xml から座標を読み取り、推論時は画像ごとの (x, y, conf)pred.csv に出力します。

特徴

  • モデル: ResNet18UNetunet.py
  • 教師信号: ガウシアンヒートマップ(dataset.py
  • 損失: MSELosstrain_val.py
  • 評価: GT と予測の argmax 間ユークリッド距離(pixel)
  • 推論出力: 可視化画像 + pred.csv

想定ディレクトリ

このコードは、以下のように学習データが 1 つ上の階層にある前提です(train.py / predict.py のデフォルト値)。

...
├─ dataset/
│  ├─ images/
│  │  ├─ frame_000001.jpg
│  │  └─ ...
│  └─ annotations.xml
└─ code/
	 ├─ train.py
	 ├─ predict.py
	 └─ ...

アノテーション形式(CVAT)

  • annotations.xml 内の各 <image> 要素を使用
  • <points label="tip" points="x,y"> を読み取り
  • label="tip" が付いた画像のみ学習対象

ラベル名を変える場合は NeedleTipDataset(..., label_name="...") を変更してください。

環境構築(uv)

1) uv のインストール

winget install --id=astral-sh.uv -e

2) 仮想環境作成 + 依存解決

pyproject.toml は CUDA 13.0 向けの PyTorch インデックスを参照します。

uv sync

uv sync.venv が自動作成され、依存がインストールされます。

GPU/ドライバ条件に合わない場合は、環境に合わせて torch / torchvision のバージョンを調整してください。

3) 動作確認

uv run python torch_check.py

学習

uv run python train.py

デフォルト設定:

  • 入力サイズ: 384x384
  • sigma=8.0
  • epoch: 100
  • batch size: 16
  • optimizer: AdamW(lr=3e-4, weight_decay=1e-4)
  • train/val split: 80/20(seed=42)

生成物:

  • checkpoints/best.pt(最小 val err(px) 更新時に保存)
  • val_log/val_map_epXXX.png(予測分布の可視化)

推論

uv run python predict.py

デフォルト入力/出力:

  • 入力画像: ../dataset/images/*.jpg
  • 重み: checkpoints/best.pt
  • 出力先: predict_out/

出力ファイル:

  • predict_out/pred_*.jpg : 元画像上に予測点を描画
  • predict_out/hmap_*.jpg : 予測ヒートマップ可視化
  • predict_out/pred.csv : name,x,y,conf

pred.csv(x, y)元画像座標系へスケールバックされた値です。

主要ファイル

  • dataset.py : CVAT XML 読み込み・ヒートマップ生成・Dataset
  • unet.py : ResNet18 エンコーダ + U-Net デコーダ
  • train_val.py : 学習/検証ループ・可視化・座標復元
  • train.py : 学習エントリポイント
  • predict.py : 画像フォルダ推論と CSV 出力

よくある問題

  • FileNotFoundError が出る
    • ../dataset/images../dataset/annotations.xml の配置を確認してください。
  • CUDA が使えない
    • uv run python torch_check.py で確認し、PyTorch の CUDA バージョンを環境に合わせてください。
  • 学習が遅い/落ちる
    • train.pybatch_sizenum_workers を下げてください。

メモ

  • モデル出力と GT サイズが一致しない場合、コード内で補間してから損失計算しています。
  • predict.py はデフォルトで先頭 100 枚(paths[0:100])を処理します。