import torch
from PIL import Image
from torchvision import transforms
from .model import RegressionModel
class StethoscopePredictor:
    """Wrap a trained ``RegressionModel`` for single-image inference."""

    def __init__(
        self, model_path, model_type="resnet", model_version="18", device=None
    ):
        """
        Initialize the predictor with a trained model.

        Args:
            model_path (str): Path to the saved model weights (loaded with
                ``torch.load`` and applied via ``load_state_dict``).
            model_type (str): Type of model ('resnet' or 'efficientnet').
            model_version (str): Version of the model (e.g., '18' for ResNet18,
                'b0' for EfficientNet-B0).
            device (str): Device to run the model on ('cuda' or 'cpu'). When
                not given, 'cuda' is used if available, otherwise 'cpu'.
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        # Resize to a fixed 224x224 input and normalize with the commonly used
        # ImageNet per-channel mean/std.
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # Initialize model
        self.model = RegressionModel(model_name=model_version, model_type=model_type)
        # NOTE(security): torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load weights from trusted sources. On torch >= 1.13,
        # consider torch.load(..., weights_only=True).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model = self.model.to(self.device)
        # Put the model in eval mode so layers such as dropout / batch-norm
        # (if present) behave deterministically at inference time.
        self.model.eval()

    def predict(self, image_path):
        """
        Predict stethoscope coordinates from an image.

        Args:
            image_path (str): Path to the input image.

        Returns:
            numpy.ndarray: The model output for this single image (row 0 of the
            batch), moved to CPU. Presumably the predicted (x, y) coordinates;
            the exact shape depends on RegressionModel's output head.
        """
        # Open inside a context manager so the file handle is released right
        # away; convert() returns an independent in-memory image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # Add a leading batch dimension of size 1 and move to the model's device.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            prediction = self.model(image_tensor)
        return prediction[0].cpu().numpy()
def load_model(model_path, model_type="resnet", model_version="18", device=None):
    """
    Build a ready-to-use stethoscope predictor from saved weights.

    This is a convenience wrapper: every argument is forwarded unchanged to
    ``StethoscopePredictor``.

    Args:
        model_path (str): Path to the saved model weights.
        model_type (str): Type of model ('resnet' or 'efficientnet').
        model_version (str): Version of the model (e.g., '18' for ResNet18,
            'b0' for EfficientNet-B0).
        device (str): Device to run the model on ('cuda' or 'cpu'); None lets
            the predictor pick automatically.

    Returns:
        StethoscopePredictor: The constructed predictor.
    """
    predictor = StethoscopePredictor(
        model_path,
        model_type=model_type,
        model_version=model_version,
        device=device,
    )
    return predictor
def predict(predictor, image_path):
    """
    Predict stethoscope coordinates using a loaded model.

    Functional wrapper that simply delegates to ``predictor.predict``.

    Args:
        predictor (StethoscopePredictor): Initialized predictor object, e.g.
            as returned by ``load_model``.
        image_path (str): Path to the input image.

    Returns:
        Whatever ``predictor.predict`` returns. For ``StethoscopePredictor``
        this is a ``numpy.ndarray`` (not a tuple) holding the model output for
        the image.
    """
    return predictor.predict(image_path)