diff --git a/config.py b/config.py index 44850ea..f531564 100644 --- a/config.py +++ b/config.py @@ -7,10 +7,7 @@ CONV_COLOR = (0, 255, 0) # Green XGBOOST_COLOR = (255, 0, 0) # Red LIGHTGBM_COLOR = (0, 0, 255) # Blue -RESNET_COLOR = (255, 165, 0) # Orange -EFFICIENTNET_COLOR = (0, 0, 255) # Blue -MOBILENET_COLOR = (255, 0, 0) # Red -SQUEEZENET_COLOR = (128, 0, 128) # Purple +EARSNET_COLOR = (0, 165, 255) # Orange # Model execution settings CONV_ENABLED = True @@ -23,10 +20,7 @@ YOLOX_ENABLED = True # Neural network model settings -RESNET_ENABLED = False -EFFICIENTNET_ENABLED = False -MOBILENET_ENABLED = False -SQUEEZENET_ENABLED = False +EARSNET_ENABLED = True # Processing settings BATCH_SIZE = 16 diff --git a/main.py b/main.py index 12deb59..3a7a025 100644 --- a/main.py +++ b/main.py @@ -18,6 +18,8 @@ from mmpose.utils import adapt_mmdet_pipeline import config +from modules.EARSNet.predictor import load_model as load_earsnet_model +from modules.EARSNet.predictor import predict as predict_earsnet from util.calc_ste_position import CalcStethoscopePosition from util.ears_ai import EarsAI @@ -25,6 +27,7 @@ CONV_COLOR = config.CONV_COLOR XGBOOST_COLOR = config.XGBOOST_COLOR LIGHTGBM_COLOR = config.LIGHTGBM_COLOR +EARSNET_COLOR = config.EARSNET_COLOR # Get model execution settings CONV_ENABLED = config.CONV_ENABLED @@ -34,6 +37,7 @@ RTMPOSE_ENABLED = config.RTMPOSE_ENABLED MobileNetV1SSD_ENABLED = config.MOBILENETV1SSD_ENABLED YOLOX_ENABLED = config.YOLOX_ENABLED +EARSNET_ENABLED = config.EARSNET_ENABLED # Get normalization setting NORMALIZE_ENABLED = config.NORMALIZE_ENABLED @@ -465,6 +469,8 @@ fieldnames.extend(["Xgboost_stethoscope_x", "Xgboost_stethoscope_y"]) if LIGHTGBM_ENABLED: fieldnames.extend(["lightGBM_stethoscope_x", "lightGBM_stethoscope_y"]) + if EARSNET_ENABLED: + fieldnames.extend(["earsnet_stethoscope_x", "earsnet_stethoscope_y"]) if LIGHTGBM_ENABLED: lgb_model_x = load_model( @@ -480,6 +486,13 @@ xg_model_y = load_model( "./models/xg_stethoscope_calc_y_best_model-Fold4.pkl" ) + # Load models + if EARSNET_ENABLED: + earsnet_predictor = load_earsnet_model( + model_path="models/EARSNet/best_model.pth", + model_type="resnet", + model_version="18", + ) with open(csv_path, "w", newline="") as csvfile, open( normalized_csv_path, "w", newline="" @@ -496,6 +509,8 @@ prev_values["lightGBM"] = (180, 180) if XGBOOST_ENABLED: prev_values["Xgboost"] = (180, 180) + if EARSNET_ENABLED: + prev_values["earsnet"] = (180, 180) for i, (row, norm_row) in enumerate(zip(rows, normalized_rows)): source_points = np.array( @@ -577,6 +592,17 @@ norm_row["Xgboost_stethoscope_y"], ) = xg_x, xg_y + if EARSNET_ENABLED: + image_path = os.path.join(base_dir, row["image_file_name"]) + earsnet_coords = predict_earsnet(earsnet_predictor, image_path) + row["earsnet_stethoscope_x"], row["earsnet_stethoscope_y"] = ( + earsnet_coords + ) + ( + norm_row["earsnet_stethoscope_x"], + norm_row["earsnet_stethoscope_y"], + ) = earsnet_coords + for key in prev_values: prev_values[key] = ( row[f"{key}_stethoscope_x"], @@ -605,6 +631,9 @@ dirs["Xgboost"] = "Xgboost" if LIGHTGBM_ENABLED: dirs["lightGBM"] = "lightGBM" + if EARSNET_ENABLED: + dirs["earsnet"] = "earsnet" + dirs["combined"] = "combined" os.makedirs(os.path.join(results_dir, "marked_images"), exist_ok=True) for key in dirs: @@ -618,16 +647,18 @@ ) points = {key: [] for key in dirs.keys() if key != "marked"} - colors = {"conv": CONV_COLOR, "Xgboost": XGBOOST_COLOR, "lightGBM": LIGHTGBM_COLOR} + colors = { + "conv": CONV_COLOR, + "Xgboost": XGBOOST_COLOR, + "lightGBM": LIGHTGBM_COLOR, + "earsnet": EARSNET_COLOR, + } for _, row in df.iterrows(): original_image = cv2.imread( os.path.join(original_images_dir, row["image_file_name"]) ) if original_image is None: - print( - f"Failed to load image: {os.path.join(original_images_dir, row['image_file_name'])}" - ) continue for point in [ @@ -650,10 +681,21 @@ original_image, ) + # Combined trajectoryのための画像 + combined_image_with_traj = body_image.copy() + combined_image_without_traj = body_image.copy() + for key in points: + if key == "combined": + continue + + if f"{key}_stethoscope_x" not in row: + continue + x, y = int(row[f"{key}_stethoscope_x"]), int(row[f"{key}_stethoscope_y"]) points[key].append((x, y)) + # 個別の手法の軌跡 image_with_trajectory = body_image.copy() if len(points[key]) > 1: cv2.polylines( @@ -671,6 +713,18 @@ image_with_trajectory, ) + # Combined imageに軌跡を追加 + if len(points[key]) > 1: + cv2.polylines( + combined_image_with_traj, + [np.array(points[key])], + False, + colors[key], + 2, + ) + cv2.circle(combined_image_with_traj, (x, y), 10, colors[key], -1) + cv2.circle(combined_image_without_traj, (x, y), 10, colors[key], -1) + image_without_trajectory = body_image.copy() cv2.circle(image_without_trajectory, (x, y), 10, colors[key], -1) cv2.imwrite( @@ -682,6 +736,20 @@ image_without_trajectory, ) + # Combined imageを保存 + cv2.imwrite( + os.path.join( + results_dir, "combined_with_trajectory", row["image_file_name"] + ), + combined_image_with_traj, + ) + cv2.imwrite( + os.path.join( + results_dir, "combined_without_trajectory", row["image_file_name"] + ), + combined_image_without_traj, + ) + create_video_from_images( os.path.join(results_dir, "marked_images"), os.path.join(results_dir, "marked_video.mp4"), diff --git a/modules/EARSForDL/EfficientNet.py b/modules/EARSForDL/EfficientNet.py deleted file mode 100644 index 7395171..0000000 --- a/modules/EARSForDL/EfficientNet.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch.nn as nn -from torchvision.models import ( - EfficientNet_B0_Weights, - EfficientNet_B1_Weights, - EfficientNet_B2_Weights, - EfficientNet_B3_Weights, - EfficientNet_B4_Weights, - EfficientNet_B5_Weights, - EfficientNet_B6_Weights, - EfficientNet_B7_Weights, - efficientnet_b0, - efficientnet_b1, - efficientnet_b2, - efficientnet_b3, - efficientnet_b4, - efficientnet_b5, - efficientnet_b6, - efficientnet_b7, -) - - -class RegressionEfficientNet(nn.Module): - def __init__(self, efficientnet_version): - super(RegressionEfficientNet, self).__init__() - - if efficientnet_version == "b0": - self.model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1) - elif efficientnet_version == "b1": - self.model = efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V1) - elif efficientnet_version == "b2": - self.model = efficientnet_b2(weights=EfficientNet_B2_Weights.IMAGENET1K_V1) - elif efficientnet_version == "b3": - self.model = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1) - elif efficientnet_version == "b4": - self.model = efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1) - elif efficientnet_version == "b5": - self.model = efficientnet_b5(weights=EfficientNet_B5_Weights.IMAGENET1K_V1) - elif efficientnet_version == "b6": - self.model = efficientnet_b6(weights=EfficientNet_B6_Weights.IMAGENET1K_V1) - elif efficientnet_version == "b7": - self.model = efficientnet_b7(weights=EfficientNet_B7_Weights.IMAGENET1K_V1) - else: - raise ValueError("Invalid EfficientNet version. Choose from 'b0' to 'b7'.") - - # Modify the final fully connected layer - num_features = self.model.classifier[1].in_features - self.model.classifier = nn.Sequential( - nn.Dropout(p=0.2, inplace=True), nn.Linear(num_features, 2) - ) - - def forward(self, x): - return self.model(x) diff --git a/modules/EARSForDL/MobileNetV2.py b/modules/EARSForDL/MobileNetV2.py deleted file mode 100644 index f7a9673..0000000 --- a/modules/EARSForDL/MobileNetV2.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch.nn as nn -import torchvision.models as models -from torchvision.models import MobileNet_V2_Weights - - -class RegressionMobileNetV2(nn.Module): - def __init__(self, pretrained=True): - super(RegressionMobileNetV2, self).__init__() - - # Load pretrained MobileNetV2 - if pretrained: - self.model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1) - else: - self.model = models.mobilenet_v2(weights=None) - - # Get the number of features from the last layer - num_features = self.model.classifier[1].in_features - - # Replace the classifier with a new one for regression - self.model.classifier = nn.Sequential( - nn.Dropout(p=0.2), nn.Linear(num_features, 2) - ) - - def forward(self, x): - return self.model(x) diff --git a/modules/EARSForDL/ResNet.py b/modules/EARSForDL/ResNet.py deleted file mode 100644 index a2b1214..0000000 --- a/modules/EARSForDL/ResNet.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch.nn as nn -import torchvision.models as models -from torchvision.models import ( - ResNet18_Weights, - ResNet34_Weights, - ResNet50_Weights, - ResNet101_Weights, - ResNet152_Weights, -) - - -class RegressionResNet(nn.Module): - def __init__(self, resnet_depth): - super(RegressionResNet, self).__init__() - if resnet_depth == 18: - self.model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) - elif resnet_depth == 34: - self.model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1) - elif resnet_depth == 50: - self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) - elif resnet_depth == 101: - self.model = models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1) - elif resnet_depth == 152: - self.model = models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1) - else: - raise ValueError("Invalid ResNet depth. Choose from 18, 34, 50, 101, 152.") - - # Modify the final fully connected layer - num_features = self.model.fc.in_features - self.model.fc = nn.Linear(num_features, 2) - - def forward(self, x): - return self.model(x) diff --git a/modules/EARSForDL/SqueezeNet.py b/modules/EARSForDL/SqueezeNet.py deleted file mode 100644 index 5c519e9..0000000 --- a/modules/EARSForDL/SqueezeNet.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import torch.nn as nn -import torchvision.models as models -from torchvision.models import SqueezeNet1_0_Weights, SqueezeNet1_1_Weights - - -class RegressionSqueezeNet(nn.Module): - def __init__(self, version="1_0"): - super(RegressionSqueezeNet, self).__init__() - if version == "1_0": - self.model = models.squeezenet1_0(weights=SqueezeNet1_0_Weights.IMAGENET1K_V1) - elif version == "1_1": - self.model = models.squeezenet1_1(weights=SqueezeNet1_1_Weights.IMAGENET1K_V1) - else: - raise ValueError("Invalid SqueezeNet version. Choose from '1_0' or '1_1'.") - - # Remove the original classifier - self.model.classifier = nn.Sequential( - nn.Dropout(p=0.5), nn.Conv2d(512, 2, kernel_size=1), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)) - ) - - # Initialize the new classifier weights - for m in self.model.classifier: - if isinstance(m, nn.Conv2d): - nn.init.kaiming_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x): - x = self.model(x) - return x.view(x.size(0), -1) diff --git a/modules/EARSNet/__init__.py b/modules/EARSNet/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/modules/EARSNet/__init__.py diff --git a/modules/EARSNet/model.py b/modules/EARSNet/model.py new file mode 100644 index 0000000..aefe8d3 --- /dev/null +++ b/modules/EARSNet/model.py @@ -0,0 +1,83 @@ +import torch.nn as nn +import torchvision.models as models +from torchvision.models import ( + EfficientNet_B0_Weights, + EfficientNet_B1_Weights, + EfficientNet_B2_Weights, + EfficientNet_B3_Weights, + EfficientNet_B4_Weights, + EfficientNet_B5_Weights, + EfficientNet_B6_Weights, + EfficientNet_B7_Weights, + ResNet18_Weights, + ResNet34_Weights, + ResNet50_Weights, + ResNet101_Weights, + ResNet152_Weights, +) + + +class RegressionModel(nn.Module): + def __init__(self, model_name, model_type="resnet"): + super(RegressionModel, self).__init__() + + self.model_type = model_type.lower() + + if self.model_type == "resnet": + self.model = self._init_resnet(model_name) + elif self.model_type == "efficientnet": + self.model = self._init_efficientnet(model_name) + else: + raise ValueError( + "Invalid model type. Choose from 'resnet' or 'efficientnet'." + ) + + # Modify the final fully connected layer + num_features = self._get_num_features() + if self.model_type == "resnet": + self.model.fc = nn.Linear(num_features, 2) + else: # efficientnet + self.model.classifier = nn.Linear(num_features, 2) + + def _init_resnet(self, depth): + if depth == "18": + return models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) + elif depth == "34": + return models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1) + elif depth == "50": + return models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) + elif depth == "101": + return models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1) + elif depth == "152": + return models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1) + else: + raise ValueError("Invalid ResNet depth. Choose from 18, 34, 50, 101, 152.") + + def _init_efficientnet(self, version): + if version == "b0": + return models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1) + elif version == "b1": + return models.efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V1) + elif version == "b2": + return models.efficientnet_b2(weights=EfficientNet_B2_Weights.IMAGENET1K_V1) + elif version == "b3": + return models.efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1) + elif version == "b4": + return models.efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1) + elif version == "b5": + return models.efficientnet_b5(weights=EfficientNet_B5_Weights.IMAGENET1K_V1) + elif version == "b6": + return models.efficientnet_b6(weights=EfficientNet_B6_Weights.IMAGENET1K_V1) + elif version == "b7": + return models.efficientnet_b7(weights=EfficientNet_B7_Weights.IMAGENET1K_V1) + else: + raise ValueError("Invalid EfficientNet version. Choose from 'b0' to 'b7'.") + + def _get_num_features(self): + if self.model_type == "resnet": + return self.model.fc.in_features + else: # efficientnet + return self.model.classifier[1].in_features + + def forward(self, x): + return self.model(x) diff --git a/modules/EARSNet/predictor.py b/modules/EARSNet/predictor.py new file mode 100644 index 0000000..5841bc6 --- /dev/null +++ b/modules/EARSNet/predictor.py @@ -0,0 +1,88 @@ +import torch +from PIL import Image +from torchvision import transforms + +from .model import RegressionModel + + +class StethoscopePredictor: + def __init__( + self, model_path, model_type="resnet", model_version="18", device=None + ): + """ + Initialize the predictor with a trained model + + Args: + model_path (str): Path to the saved model weights + model_type (str): Type of model ('resnet' or 'efficientnet') + model_version (str): Version of the model (e.g., '18' for ResNet18, 'b0' for EfficientNet-B0) + device (str): Device to run the model on ('cuda' or 'cpu') + """ + self.device = ( + device if device else ("cuda" if torch.cuda.is_available() else "cpu") + ) + self.transform = transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + # Initialize model + self.model = RegressionModel(model_name=model_version, model_type=model_type) + self.model.load_state_dict(torch.load(model_path, map_location=self.device)) + self.model = self.model.to(self.device) + self.model.eval() + + def predict(self, image_path): + """ + Predict stethoscope coordinates from an image + + Args: + image_path (str): Path to the input image + + Returns: + tuple: Predicted (x, y) coordinates + """ + # Load and preprocess image + image = Image.open(image_path).convert("RGB") + image_tensor = self.transform(image).unsqueeze(0).to(self.device) + + # Make prediction + with torch.no_grad(): + prediction = self.model(image_tensor) + + return prediction[0].cpu().numpy() + + +def load_model(model_path, model_type="resnet", model_version="18", device=None): + """ + Load a trained stethoscope detection model + + Args: + model_path (str): Path to the saved model weights + model_type (str): Type of model ('resnet' or 'efficientnet') + model_version (str): Version of the model (e.g., '18' for ResNet18, 'b0' for EfficientNet-B0) + device (str): Device to run the model on ('cuda' or 'cpu') + + Returns: + StethoscopePredictor: Initialized predictor object + """ + return StethoscopePredictor(model_path, model_type, model_version, device) + + +def predict(predictor, image_path): + """ + Predict stethoscope coordinates using a loaded model + + Args: + predictor (StethoscopePredictor): Initialized predictor object + image_path (str): Path to the input image + + Returns: + tuple: Predicted (x, y) coordinates + """ + return predictor.predict(image_path)