import os

import torch
from PIL import Image
from torchvision import transforms

# 学習時に使った RegressionModel と同じものを import
# （train.py 内では "model.py" を読んでいる想定）
from .model import RegressionModel


class EARSNetPredictor:
    def __init__(
        self,
        weight_path: str,
        resnet_depth: str = "18",
        pretrained: bool = True,
        device: str = None,
    ):
        """
        学習時と同じ構造・重みをもつモデルをロードし、
        224×224スケールでの推論を行えるようにするクラス。

        Args:
            weight_path (str): 学習済みモデルの pth ファイルパス (例: "best_model.pth")
            resnet_depth (str): "18","34","50","101","152" など
            pretrained (bool): True なら ImageNet 事前学習ウェイトベース
            device (str): 'cuda' or 'cpu' (指定しない場合、自動判定)
        """
        self.device = (
            torch.device(device)
            if device
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

        # 学習時と同じ ResNet depth & 前処理設定でモデルを用意
        self.model = RegressionModel(
            resnet_depth=resnet_depth,
            pretrained=pretrained,
        ).to(self.device)

        # 学習済みウェイトをロード
        if not os.path.isfile(weight_path):
            raise FileNotFoundError(f"Weight file not found: {weight_path}")
        self.model.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.model.eval()

        # 学習時に使った画像サイズ・正規化パラメータ
        self.input_size = (224, 224)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # 今回は "224×224 の座標" をそのまま返す簡易実装のため、
        # 元画像→224×224 へリサイズしたスケーリング係数は保持しない

    def _preprocess(self, image: Image.Image):
        """
        画像を224x224にリサイズ → テンソル化 → 正規化
        """
        # 画像を224x224にリサイズ
        resized_image = image.resize(self.input_size, Image.BILINEAR)

        # Tensor化 & 正規化
        x = transforms.ToTensor()(resized_image)
        x = transforms.Normalize(mean=self.mean, std=self.std)(x)

        return x  # スケール係数は返さない

    def predict(self, image_path: str):
        """
        1枚の画像ファイルパスに対し、(224×224座標系) での (x, y) を推論して返す。

        Returns:
            (pred_x_224, pred_y_224): 224x224座標系での推定結果
        """
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        # 画像を開く
        image = Image.open(image_path).convert("RGB")

        # 前処理 (224x224にリサイズ & 正規化)
        input_tensor = self._preprocess(image)

        # 推論
        input_tensor = input_tensor.unsqueeze(0).to(self.device)  # (B=1, C, H, W)
        with torch.no_grad():
            pred = self.model(input_tensor)  # shape: (1,2)

        pred_x_224 = pred[0][0].item()
        pred_y_224 = pred[0][1].item()

        # 「224×224の座標」をそのまま返す
        return pred_x_224, pred_y_224
