import torch
from PIL import Image
from torchvision import transforms

from .model import RegressionModel


class StethoscopePredictor:
    """Run a trained RegressionModel on single images to predict stethoscope coordinates."""

    def __init__(
        self, model_path, model_type="resnet", model_version="18", device=None
    ):
        """
        Initialize the predictor with a trained model

        Args:
            model_path (str): Path to the saved model weights (a state_dict
                loadable by ``RegressionModel.load_state_dict``)
            model_type (str): Type of model ('resnet' or 'efficientnet')
            model_version (str): Version of the model (e.g., '18' for ResNet18, 'b0' for EfficientNet-B0)
            device (str, optional): Device to run the model on ('cuda' or 'cpu').
                When not given, 'cuda' is used if available, otherwise 'cpu'.
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        # Fixed 224x224 input, normalized with the ImageNet channel mean/std.
        # NOTE(review): presumably this mirrors the training-time transform — confirm.
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Initialize model and load the saved weights onto the target device.
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted sources (or pass
        # weights_only=True on torch versions that support it).
        self.model = RegressionModel(model_name=model_version, model_type=model_type)
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        # Switch to evaluation mode (affects layers such as dropout / batch-norm).
        self.model.eval()

    def predict(self, image_path):
        """
        Predict stethoscope coordinates from an image

        Args:
            image_path (str | os.PathLike | PIL.Image.Image): Path to the input
                image, or an already-loaded PIL image.

        Returns:
            numpy.ndarray: The model output for the (single) input image, i.e.
                row 0 of the batched prediction; expected to hold the
                predicted (x, y) coordinates.
        """
        if isinstance(image_path, Image.Image):
            image = image_path.convert("RGB")
        else:
            # Image.open is lazy and keeps the file handle open; the context
            # manager guarantees it is closed. convert() returns a new,
            # fully loaded image, so it stays valid after the file closes.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

        # Add a batch dimension of 1 and move to the model's device.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Make prediction without tracking gradients (inference only).
        with torch.no_grad():
            prediction = self.model(image_tensor)

        return prediction[0].cpu().numpy()


def load_model(model_path, model_type="resnet", model_version="18", device=None):
    """
    Build a ready-to-use stethoscope predictor from saved weights.

    This is a thin convenience wrapper around the StethoscopePredictor
    constructor; all arguments are forwarded unchanged.

    Args:
        model_path (str): Location of the saved model weights
        model_type (str): Model family ('resnet' or 'efficientnet')
        model_version (str): Variant within the family (e.g. '18' for ResNet18, 'b0' for EfficientNet-B0)
        device (str, optional): 'cuda' or 'cpu'; chosen automatically when omitted

    Returns:
        StethoscopePredictor: Predictor with the weights loaded
    """
    predictor = StethoscopePredictor(
        model_path=model_path,
        model_type=model_type,
        model_version=model_version,
        device=device,
    )
    return predictor


def predict(predictor, image_path):
    """
    Run an already-loaded predictor on a single image.

    Functional-style alias for ``predictor.predict(image_path)``.

    Args:
        predictor (StethoscopePredictor): Predictor returned by load_model
        image_path (str): Location of the image to run inference on

    Returns:
        Whatever ``predictor.predict`` returns; for StethoscopePredictor this
        is a numpy array holding the predicted (x, y) coordinates.
    """
    coordinates = predictor.predict(image_path)
    return coordinates
