ResNet18 + U-Net デコーダで、画像中の目標位置をヒートマップ回帰するプロジェクト

.gitignore 針先端検出の学習と推論 基本実装 12 days ago
.python-version 針先端検出の学習と推論 基本実装 12 days ago
README.md README 一般的な表現に修正 12 days ago
dataset.py 針先端検出の学習と推論 基本実装 12 days ago
predict.py 針先端検出の学習と推論 基本実装 12 days ago
pyproject.toml 針先端検出の学習と推論 基本実装 12 days ago
torch_check.py 針先端検出の学習と推論 基本実装 12 days ago
train.py 針先端検出の学習と推論 基本実装 12 days ago
train_val.py 針先端検出の学習と推論 基本実装 12 days ago
unet.py 針先端検出の学習と推論 基本実装 12 days ago
uv.lock 針先端検出の学習と推論 基本実装 12 days ago
README.md

Unet Point Detection

ResNet18 + U-Net 風デコーダで、画像中の目標位置をヒートマップ回帰するプロジェクトです。
学習時は CVAT の annotations.xml から座標を読み取り、推論時は画像ごとの (x, y, conf)pred.csv に出力します。

特徴

  • モデル: ResNet18UNetunet.py
  • 教師信号: ガウシアンヒートマップ(dataset.py
  • 損失: MSELosstrain_val.py
  • 評価: GT と予測の argmax 間ユークリッド距離(pixel)
  • 推論出力: 可視化画像 + pred.csv

想定ディレクトリ

このコードは、以下のように学習データが 1 つ上の階層にある前提です(train.py / predict.py のデフォルト値)。

...
├─ dataset/
│  ├─ images/
│  │  ├─ frame_000001.jpg
│  │  └─ ...
│  └─ annotations.xml
└─ code/
	 ├─ train.py
	 ├─ predict.py
	 └─ ...

アノテーション形式(CVAT)

  • annotations.xml 内の各 <image> 要素を使用
  • <points label="tip" points="x,y"> を読み取り
  • label="tip" が付いた画像のみ学習対象

ラベル名を変える場合は NeedleTipDataset(..., label_name="...") を変更してください。

環境構築(uv)

1) uv のインストール

winget install --id=astral-sh.uv -e

2) 仮想環境作成 + 依存解決

pyproject.toml は CUDA 13.0 向けの PyTorch インデックスを参照します。

uv sync

uv sync.venv が自動作成され、依存がインストールされます。

GPU/ドライバ条件に合わない場合は、環境に合わせて torch / torchvision のバージョンを調整してください。

3) 動作確認

uv run python torch_check.py

学習

uv run python train.py

デフォルト設定:

  • 入力サイズ: 384x384
  • sigma=8.0
  • epoch: 100
  • batch size: 16
  • optimizer: AdamW(lr=3e-4, weight_decay=1e-4)
  • train/val split: 80/20(seed=42)

生成物:

  • checkpoints/best.pt(最小 val err(px) 更新時に保存)
  • val_log/val_map_epXXX.png(予測分布の可視化)

推論

uv run python predict.py

デフォルト入力/出力:

  • 入力画像: ../dataset/images/*.jpg
  • 重み: checkpoints/best.pt
  • 出力先: predict_out/

出力ファイル:

  • predict_out/pred_*.jpg : 元画像上に予測点を描画
  • predict_out/hmap_*.jpg : 予測ヒートマップ可視化
  • predict_out/pred.csv : name,x,y,conf

pred.csv(x, y)元画像座標系へスケールバックされた値です。

主要ファイル

  • dataset.py : CVAT XML 読み込み・ヒートマップ生成・Dataset
  • unet.py : ResNet18 エンコーダ + U-Net デコーダ
  • train_val.py : 学習/検証ループ・可視化・座標復元
  • train.py : 学習エントリポイント
  • predict.py : 画像フォルダ推論と CSV 出力

よくある問題

  • FileNotFoundError が出る
    • ../dataset/images../dataset/annotations.xml の配置を確認してください。
  • CUDA が使えない
    • uv run python torch_check.py で確認し、PyTorch の CUDA バージョンを環境に合わせてください。
  • 学習が遅い/落ちる
    • train.pybatch_sizenum_workers を下げてください。

メモ

  • モデル出力と GT サイズが一致しない場合、コード内で補間してから損失計算しています。
  • predict.py はデフォルトで先頭 100 枚(paths[0:100])を処理します。