ResNet18 + U-Net デコーダで、画像中の目標位置をヒートマップ回帰するプロジェクト
| .gitignore | 12 days ago | ||
| .python-version | 12 days ago | ||
| README.md | 12 days ago | ||
| dataset.py | 12 days ago | ||
| predict.py | 12 days ago | ||
| pyproject.toml | 12 days ago | ||
| torch_check.py | 12 days ago | ||
| train.py | 12 days ago | ||
| train_val.py | 12 days ago | ||
| unet.py | 12 days ago | ||
| uv.lock | 12 days ago | ||
ResNet18 + U-Net 風デコーダで、画像中の目標位置をヒートマップ回帰するプロジェクトです。
学習時は CVAT の annotations.xml から座標を読み取り、推論時は画像ごとの (x, y, conf) を pred.csv に出力します。
ResNet18UNet(unet.py)dataset.py)MSELoss(train_val.py)pred.csvこのコードは、以下のように学習データが 1 つ上の階層にある前提です(train.py / predict.py のデフォルト値)。
... ├─ dataset/ │ ├─ images/ │ │ ├─ frame_000001.jpg │ │ └─ ... │ └─ annotations.xml └─ code/ ├─ train.py ├─ predict.py └─ ...
annotations.xml 内の各 <image> 要素を使用<points label="tip" points="x,y"> を読み取りlabel="tip" が付いた画像のみ学習対象ラベル名を変える場合は
NeedleTipDataset(..., label_name="...")を変更してください。
winget install --id=astral-sh.uv -e
pyproject.toml は CUDA 13.0 向けの PyTorch インデックスを参照します。
uv sync
uv syncで.venvが自動作成され、依存がインストールされます。
GPU/ドライバ条件に合わない場合は、環境に合わせて torch / torchvision のバージョンを調整してください。
uv run python torch_check.py
uv run python train.py
デフォルト設定:
384x384sigma=8.010016AdamW(lr=3e-4, weight_decay=1e-4)80/20(seed=42)生成物:
checkpoints/best.pt(最小 val err(px) 更新時に保存)val_log/val_map_epXXX.png(予測分布の可視化)uv run python predict.py
デフォルト入力/出力:
../dataset/images/*.jpgcheckpoints/best.ptpredict_out/出力ファイル:
predict_out/pred_*.jpg : 元画像上に予測点を描画predict_out/hmap_*.jpg : 予測ヒートマップ可視化predict_out/pred.csv : name,x,y,confpred.csv の (x, y) は元画像座標系へスケールバックされた値です。
dataset.py : CVAT XML 読み込み・ヒートマップ生成・Datasetunet.py : ResNet18 エンコーダ + U-Net デコーダtrain_val.py : 学習/検証ループ・可視化・座標復元train.py : 学習エントリポイントpredict.py : 画像フォルダ推論と CSV 出力FileNotFoundError が出る
../dataset/images と ../dataset/annotations.xml の配置を確認してください。uv run python torch_check.py で確認し、PyTorch の CUDA バージョンを環境に合わせてください。train.py の batch_size と num_workers を下げてください。predict.py はデフォルトで先頭 100 枚(paths[0:100])を処理します。