import os
import torch
from PIL import Image
from torchvision import transforms
# 学習時に使った RegressionModel と同じものを import
# (train.py 内では "model.py" を読んでいる想定)
from .model import RegressionModel
class EARSNetPredictor:
def __init__(
self,
weight_path: str,
resnet_depth: str = "18",
pretrained: bool = True,
device: str = None,
):
"""
学習時と同じ構造・重みをもつモデルをロードし、
224×224スケールでの推論を行えるようにするクラス。
Args:
weight_path (str): 学習済みモデルの pth ファイルパス (例: "best_model.pth")
resnet_depth (str): "18","34","50","101","152" など
pretrained (bool): True なら ImageNet 事前学習ウェイトベース
device (str): 'cuda' or 'cpu' (指定しない場合、自動判定)
"""
self.device = (
torch.device(device)
if device
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# 学習時と同じ ResNet depth & 前処理設定でモデルを用意
self.model = RegressionModel(
resnet_depth=resnet_depth,
pretrained=pretrained,
).to(self.device)
# 学習済みウェイトをロード
if not os.path.isfile(weight_path):
raise FileNotFoundError(f"Weight file not found: {weight_path}")
self.model.load_state_dict(torch.load(weight_path, map_location=self.device))
self.model.eval()
# 学習時に使った画像サイズ・正規化パラメータ
self.input_size = (224, 224)
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
# 今回は "224×224 の座標" をそのまま返す簡易実装のため、
# 元画像→224×224 へリサイズしたスケーリング係数は保持しない
def _preprocess(self, image: Image.Image):
"""
画像を224x224にリサイズ → テンソル化 → 正規化
"""
# 画像を224x224にリサイズ
resized_image = image.resize(self.input_size, Image.BILINEAR)
# Tensor化 & 正規化
x = transforms.ToTensor()(resized_image)
x = transforms.Normalize(mean=self.mean, std=self.std)(x)
return x # スケール係数は返さない
def predict(self, image_path: str):
"""
1枚の画像ファイルパスに対し、(224×224座標系) での (x, y) を推論して返す。
Returns:
(pred_x_224, pred_y_224): 224x224座標系での推定結果
"""
if not os.path.isfile(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
# 画像を開く
image = Image.open(image_path).convert("RGB")
# 前処理 (224x224にリサイズ & 正規化)
input_tensor = self._preprocess(image)
# 推論
input_tensor = input_tensor.unsqueeze(0).to(self.device) # (B=1, C, H, W)
with torch.no_grad():
pred = self.model(input_tensor) # shape: (1,2)
pred_x_224 = pred[0][0].item()
pred_y_224 = pred[0][1].item()
# 「224×224の座標」をそのまま返す
return pred_x_224, pred_y_224