diff --git a/main.py b/main.py index 65c1e1e..42aec28 100644 --- a/main.py +++ b/main.py @@ -20,8 +20,9 @@ from mmpose.utils import adapt_mmdet_pipeline import config -from modules.EARSNet.predictor import load_model as load_earsnet_model -from modules.EARSNet.predictor import predict as predict_earsnet + +# --- EARSNetPredictor のみをインポート --- +from modules.EARSNet.predictor import EARSNetPredictor from util.calc_ste_position import CalcStethoscopePosition from util.ears_ai import EarsAI @@ -390,11 +391,14 @@ ngboost_model_y = load_model( "./models/NGBoost/stethoscope_calc_y_best_model.pkl" ) + + # EARSNET を使用する場合のみ初期化 if EARSNET_ENABLED: - earsnet_predictor = load_earsnet_model( - model_path="models/EARSNet/best_model.pth", - model_type="resnet", - model_version="18", + earsnet_predictor = EARSNetPredictor( + weight_path="models/EARSNet/best_model.pth", + resnet_depth="18", # 学習時と同じResNet深度 + pretrained=True, # 学習時の設定に合わせる + device="cuda", # or "cpu" ) input_columns = [ @@ -432,7 +436,7 @@ if POSENET_ENABLED: # ▼ PoseNet start_time_pose = time.time() - pose_overlay_img, *landmarks = EarsAI().pose_detect(frame, None) + pose_overlay_img, *landmarks = ears_ai.pose_detect(frame, None) end_time_pose = time.time() timings["rtmpose_single"].append(end_time_pose - start_time_pose) @@ -464,7 +468,6 @@ if pose_keypoints is None: print(f"Failed to extract keypoints for image: {image_path}") - # 処理ができなかった場合もフレーム処理は一応完了するのでincrement processed_frames += 1 continue @@ -500,15 +503,18 @@ # SSD (MobileNetV1SSD) if MobileNetV1SSD_ENABLED: start_time_ssd = time.time() - stethoscope_overlay_img, stethoscope_x, stethoscope_y = EarsAI().ssd_detect( + stethoscope_overlay_img, stethoscope_x, stethoscope_y = ears_ai.ssd_detect( frame, None ) end_time_ssd = time.time() - # もし単体測定したい場合は "ssd_single" などに追加 # YOLOX if YOLOX_ENABLED: - if RTMPOSE_ENABLED and pose_keypoints is not None: + if ( + RTMPOSE_ENABLED + and "pose_keypoints" in locals() + and pose_keypoints is not None + ): start_time_yolox = time.time() (stethoscope_overlay_img, stethoscope_x, stethoscope_y) = ( yolox_detector_inference(frame, yolox_inferencer, pose_keypoints) @@ -517,7 +523,6 @@ timings["yolox_single"].append(end_time_yolox - start_time_yolox) elif POSENET_ENABLED: - # PoseNetの場合 keypoints形式を整える pose_keypoints_pose_net = [[0, 0]] * 13 pose_keypoints_pose_net[5] = left_shoulder pose_keypoints_pose_net[6] = right_shoulder @@ -539,14 +544,12 @@ if (RTMPOSE_ENABLED or POSENET_ENABLED) and ( YOLOX_ENABLED or MobileNetV1SSD_ENABLED ): - # PoseはRGB->BGR に変換(RTMPOSE時) if RTMPOSE_ENABLED: cv2.imwrite( os.path.join(pose_overlay_dir, image_file_name), cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR), ) else: - # PoseNetなら BGR のまま cv2.imwrite( os.path.join(pose_overlay_dir, image_file_name), pose_overlay_img, @@ -561,7 +564,6 @@ # (3) CSV用に肩・腰・聴診器座標をまとめる # ============================================================ if POSENET_ENABLED: - # PoseNet が (y, x) row = { "image_file_name": image_file_name, "left_shoulder_x": left_shoulder[1], @@ -576,7 +578,6 @@ "stethoscope_y": stethoscope_y, } else: - # RTMPOSE の場合 (x, y) row = { "image_file_name": image_file_name, "left_shoulder_x": left_shoulder[0], @@ -591,6 +592,20 @@ "stethoscope_y": stethoscope_y, } + # --------------------------------------------------------- + # (C) EARSNET (ここでフレーム毎に実行し、rowに書き込む) + # --------------------------------------------------------- + if EARSNET_ENABLED: + start_earsnet = time.time() + earsnet_x, earsnet_y = earsnet_predictor.predict(image_path) + timings["earsnet_single"].append(time.time() - start_earsnet) + + # row にEARSNET座標を格納 + row["earsnet_stethoscope_x"] = earsnet_x + row["earsnet_stethoscope_y"] = earsnet_y + + # row を保存 + # --------------------------------------------------------- rows.append(row) # 正規化 @@ -607,6 +622,19 @@ normalized_points = normalize_quadrilateral_with_point( source_points.flatten(), stethoscope_point ) + + # EARSNet も同様に stethoscope_x,y を正規化 → ここでは省略例 + # しかし "耳" は別枠の場合もあるので、行うなら同様に対応 + # 例: earsnet_points = np.array([earsnet_x, earsnet_y]) + # ... で正規化etc. + + # 一旦 stethoscope用だけ + earsnet_x_n, earsnet_y_n = 0, 0 + if EARSNET_ENABLED: + # 224×224にリサイズしている場合、単純計算だけでなく + # ここでは "ただの合成例" として省略 + pass + normalized_row = { "image_file_name": image_file_name, "left_shoulder_x": normalized_points[0, 0], @@ -620,188 +648,17 @@ "stethoscope_x": normalized_points[4, 0], "stethoscope_y": normalized_points[4, 1], } + if EARSNET_ENABLED: + normalized_row["earsnet_stethoscope_x"] = earsnet_x # 必要に応じて計算 + normalized_row["earsnet_stethoscope_y"] = earsnet_y + normalized_rows.append(normalized_row) - # ============================================================ - # (4) 各パイプラインの FPS計測(例: RTMPOSE+YOLOX+conv, etc.) - # ============================================================ - # 以下は例として「改めて同じフレームをRTMPOSE+YOLOX+各種手法」で測定。 - # パイプラインごとにRTMPOSEとYOLOXを呼び直すため、処理時間は増加します。 - # もし重複呼び出しを避けたければ実装を見直してください。 - # -- (A) RTMPOSE + YOLOX + conv - if RTMPOSE_ENABLED and YOLOX_ENABLED and CONV_ENABLED: - start_pipeline = time.time() - - # 1) rtmpose (再度推定) - start_rtmpose = time.time() - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - det_result = inference_detector(detector, frame_rgb) - pred_instance = det_result.pred_instances.cpu().numpy() - bboxes = np.concatenate( - (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1 - ) - bboxes = bboxes[ - np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3) - ] - bboxes = bboxes[nms(bboxes, 0.3), :4] - pose_results = inference_topdown(pose_estimator, frame_rgb, bboxes) - keypoints_conv_pipeline = extract_keypoints_rtmpose(pose_results) - end_rtmpose = time.time() - timings["rtmpose_single"].append(end_rtmpose - start_rtmpose) - - # 2) YOLOX (再度推定) - s_x_tmp, s_y_tmp = 0, 0 - if keypoints_conv_pipeline is not None: - start_yolox = time.time() - _, s_x_tmp, s_y_tmp = yolox_detector_inference( - frame, yolox_inferencer, keypoints_conv_pipeline - ) - end_yolox = time.time() - timings["yolox_single"].append(end_yolox - start_yolox) - - # 3) conv - if keypoints_conv_pipeline is not None: - start_conv = time.time() - source_pts = np.array( - [ - keypoints_conv_pipeline[5], - keypoints_conv_pipeline[6], - keypoints_conv_pipeline[11], - keypoints_conv_pipeline[12], - ], - dtype=np.float32, - ) - if s_x_tmp != 0 or s_y_tmp != 0: - _ = calc_position.calc_affine(source_pts, s_x_tmp, s_y_tmp) - end_conv = time.time() - timings["conv_single"].append(end_conv - start_conv) - - end_pipeline = time.time() - timings["pipeline_rtmpose_yolox_conv"].append(end_pipeline - start_pipeline) - - # -- (B) RTMPOSE + YOLOX + LightGBM - if RTMPOSE_ENABLED and YOLOX_ENABLED and LIGHTGBM_ENABLED: - start_pipeline = time.time() - # 1) RTMPOSE - start_rtmpose = time.time() - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - det_result = inference_detector(detector, frame_rgb) - pred_instance = det_result.pred_instances.cpu().numpy() - bboxes = np.concatenate( - (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1 - ) - bboxes = bboxes[ - np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3) - ] - bboxes = bboxes[nms(bboxes, 0.3), :4] - pose_results = inference_topdown(pose_estimator, frame_rgb, bboxes) - keypoints_lgb_pipeline = extract_keypoints_rtmpose(pose_results) - end_rtmpose = time.time() - timings["rtmpose_single"].append(end_rtmpose - start_rtmpose) - - # 2) YOLOX - s_x_tmp, s_y_tmp = 0, 0 - if keypoints_lgb_pipeline is not None: - start_yolox = time.time() - _, s_x_tmp, s_y_tmp = yolox_detector_inference( - frame, yolox_inferencer, keypoints_lgb_pipeline - ) - end_yolox = time.time() - timings["yolox_single"].append(end_yolox - start_yolox) - - # 3) LightGBM - if s_x_tmp != 0 or s_y_tmp != 0: - input_data = ( - pd.DataFrame([row]) - if not NORMALIZE_ENABLED - else pd.DataFrame([normalized_row]) - ) - start_lgb = time.time() - # x 座標予測 - X_scaled_x = lgb_scaler_x.transform(input_data[input_columns]) - _ = lgb_model_x.predict(X_scaled_x) - # y 座標予測 - X_scaled_y = lgb_scaler_y.transform(input_data[input_columns]) - _ = lgb_model_y.predict(X_scaled_y) - end_lgb = time.time() - timings["lightgbm_single"].append(end_lgb - start_lgb) - - end_pipeline = time.time() - timings["pipeline_rtmpose_yolox_lightgbm"].append( - end_pipeline - start_pipeline - ) - - # -- (C) RTMPOSE + YOLOX + XGBoost - if RTMPOSE_ENABLED and YOLOX_ENABLED and XGBOOST_ENABLED: - start_pipeline = time.time() - # 1) RTMPOSE - start_rtmpose = time.time() - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - det_result = inference_detector(detector, frame_rgb) - pred_instance = det_result.pred_instances.cpu().numpy() - bboxes = np.concatenate( - (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1 - ) - bboxes = bboxes[ - np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3) - ] - bboxes = bboxes[nms(bboxes, 0.3), :4] - pose_results = inference_topdown(pose_estimator, frame_rgb, bboxes) - keypoints_xgb_pipeline = extract_keypoints_rtmpose(pose_results) - end_rtmpose = time.time() - timings["rtmpose_single"].append(end_rtmpose - start_rtmpose) - - # 2) YOLOX - s_x_tmp, s_y_tmp = 0, 0 - if keypoints_xgb_pipeline is not None: - start_yolox = time.time() - _, s_x_tmp, s_y_tmp = yolox_detector_inference( - frame, yolox_inferencer, keypoints_xgb_pipeline - ) - end_yolox = time.time() - timings["yolox_single"].append(end_yolox - start_yolox) - - # 3) XGBoost - if s_x_tmp != 0 or s_y_tmp != 0: - input_data = ( - pd.DataFrame([row]) - if not NORMALIZE_ENABLED - else pd.DataFrame([normalized_row]) - ) - start_xgb = time.time() - # x座標 - X_scaled_x = xg_scaler_x.transform(input_data[input_columns]) - _ = xg_model_x.predict(X_scaled_x) - # y座標 - X_scaled_y = xg_scaler_y.transform(input_data[input_columns]) - _ = xg_model_y.predict(X_scaled_y) - end_xgb = time.time() - timings["xgboost_single"].append(end_xgb - start_xgb) - - end_pipeline = time.time() - timings["pipeline_rtmpose_yolox_xgboost"].append( - end_pipeline - start_pipeline - ) - - # -- (D) EARSNET パイプライン(EARSNET 単体) - if EARSNET_ENABLED: - start_pipeline_earsnet = time.time() - # EARSNET の単体推論 - start_earsnet = time.time() - _ = predict_earsnet(earsnet_predictor, image_path) - end_earsnet = time.time() - timings["earsnet_single"].append(end_earsnet - start_earsnet) - - end_pipeline_earsnet = time.time() - timings["pipeline_earsnet"].append( - end_pipeline_earsnet - start_pipeline_earsnet - ) - - # フレームごとの処理が完了したらカウンタをインクリメント + # フレームごとの処理完了 processed_frames += 1 # ======================================================================== - # (5) 各フレームの位置推定(Conv, LightGBM, XGBoost, CatBoost, NGBoost, EARSNET) + # (5) 各フレームの位置推定(Conv, LightGBM, XGBoost, CatBoost, NGBoost) # → CSV 書き込み # ======================================================================== if rows: @@ -817,8 +674,7 @@ fieldnames.extend(["catboost_stethoscope_x", "catboost_stethoscope_y"]) if NGBOOST_ENABLED: fieldnames.extend(["ngboost_stethoscope_x", "ngboost_stethoscope_y"]) - if EARSNET_ENABLED: - fieldnames.extend(["earsnet_stethoscope_x", "earsnet_stethoscope_y"]) + # EARSNETはすでに row に earsnet_stethoscope_x,y があるのでOK os.makedirs(results_dir, exist_ok=True) @@ -845,8 +701,7 @@ prev_values["catboost"] = (180, 180) if NGBOOST_ENABLED: prev_values["ngboost"] = (180, 180) - if EARSNET_ENABLED: - prev_values["earsnet"] = (180, 180) + # EARSNETは前回値利用しないなら不要 for row, norm_row in zip(rows, normalized_rows): input_data = ( @@ -857,6 +712,7 @@ # 聴診器未検出の場合 if row["stethoscope_x"] == 0 and row["stethoscope_y"] == 0: + # 省略: conv/lightgbm/xgboostなどで前回値代入 for key in prev_values: row[f"{key}_stethoscope_x"], row[f"{key}_stethoscope_y"] = ( prev_values[key] @@ -912,7 +768,10 @@ ( norm_row["lightGBM_stethoscope_x"], norm_row["lightGBM_stethoscope_y"], - ) = (lgb_x_pred, lgb_y_pred) + ) = ( + lgb_x_pred, + lgb_y_pred, + ) end_time_lgb = time.time() timings["lightgbm_single"].append(end_time_lgb - start_time_lgb) prev_values["lightGBM"] = (lgb_x_pred, lgb_y_pred) @@ -931,7 +790,10 @@ ( norm_row["Xgboost_stethoscope_x"], norm_row["Xgboost_stethoscope_y"], - ) = (xg_x_pred, xg_y_pred) + ) = ( + xg_x_pred, + xg_y_pred, + ) end_time_xgb = time.time() timings["xgboost_single"].append(end_time_xgb - start_time_xgb) prev_values["Xgboost"] = (xg_x_pred, xg_y_pred) @@ -952,7 +814,10 @@ ( norm_row["catboost_stethoscope_x"], norm_row["catboost_stethoscope_y"], - ) = (catboost_x, catboost_y) + ) = ( + catboost_x, + catboost_y, + ) end_time_cat = time.time() # timings["catboost_single"].append( ... ) # 必要なら追加 prev_values["catboost"] = (catboost_x, catboost_y) @@ -973,28 +838,14 @@ ( norm_row["ngboost_stethoscope_x"], norm_row["ngboost_stethoscope_y"], - ) = (ngboost_x, ngboost_y) + ) = ( + ngboost_x, + ngboost_y, + ) end_time_ngb = time.time() # timings["ngboost_single"].append( ... ) # 必要なら追加 prev_values["ngboost"] = (ngboost_x, ngboost_y) - # EARSNET (再度実行する場合) - if EARSNET_ENABLED: - start_time_enet = time.time() - earsnet_coords = predict_earsnet(earsnet_predictor, image_path) - row["earsnet_stethoscope_x"], row["earsnet_stethoscope_y"] = ( - earsnet_coords - ) - ( - norm_row["earsnet_stethoscope_x"], - norm_row["earsnet_stethoscope_y"], - ) = earsnet_coords - end_time_enet = time.time() - timings["earsnet_single"].append( - end_time_enet - start_time_enet - ) - prev_values["earsnet"] = earsnet_coords - writer.writerow(row) norm_writer.writerow(norm_row) @@ -1295,13 +1146,6 @@ stop_fps_thread = True fps_thread.join() - # もし fps_history をCSV保存したい場合はここで行う - # with open("fps_history.csv", "w", newline="") as f: - # writer = csv.writer(f) - # writer.writerow(["timestamp", "fps"]) - # for timestamp, fps_value in fps_history: - # writer.writerow([timestamp, fps_value]) - print("All done.") diff --git a/modules/EARSNet/model.py b/modules/EARSNet/model.py index aefe8d3..10f689a 100644 --- a/modules/EARSNet/model.py +++ b/modules/EARSNet/model.py @@ -1,14 +1,6 @@ import torch.nn as nn import torchvision.models as models from torchvision.models import ( - EfficientNet_B0_Weights, - EfficientNet_B1_Weights, - EfficientNet_B2_Weights, - EfficientNet_B3_Weights, - EfficientNet_B4_Weights, - EfficientNet_B5_Weights, - EfficientNet_B6_Weights, - EfficientNet_B7_Weights, ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, @@ -18,66 +10,49 @@ class RegressionModel(nn.Module): - def __init__(self, model_name, model_type="resnet"): + def __init__(self, resnet_depth: str, pretrained: bool = True): + """ + Args: + resnet_depth (str): "18", "34", "50", "101", or "152" + pretrained (bool): True if using ImageNet pretrained weights, False for scratch + """ super(RegressionModel, self).__init__() - - self.model_type = model_type.lower() - - if self.model_type == "resnet": - self.model = self._init_resnet(model_name) - elif self.model_type == "efficientnet": - self.model = self._init_efficientnet(model_name) - else: - raise ValueError( - "Invalid model type. Choose from 'resnet' or 'efficientnet'." - ) + self.model = self._init_resnet(resnet_depth, pretrained) # Modify the final fully connected layer - num_features = self._get_num_features() - if self.model_type == "resnet": - self.model.fc = nn.Linear(num_features, 2) - else: # efficientnet - self.model.classifier = nn.Linear(num_features, 2) + num_features = self.model.fc.in_features + self.model.fc = nn.Linear(num_features, 2) - def _init_resnet(self, depth): + def _init_resnet(self, depth: str, pretrained: bool): + # pretrained=True => Use ImageNet weights + # pretrained=False => weights=None (scratch) if depth == "18": - return models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) + if pretrained: + return models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) + else: + return models.resnet18(weights=None) elif depth == "34": - return models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1) + if pretrained: + return models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1) + else: + return models.resnet34(weights=None) elif depth == "50": - return models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) + if pretrained: + return models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) + else: + return models.resnet50(weights=None) elif depth == "101": - return models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1) + if pretrained: + return models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1) + else: + return models.resnet101(weights=None) elif depth == "152": - return models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1) + if pretrained: + return models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1) + else: + return models.resnet152(weights=None) else: raise ValueError("Invalid ResNet depth. Choose from 18, 34, 50, 101, 152.") - def _init_efficientnet(self, version): - if version == "b0": - return models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1) - elif version == "b1": - return models.efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V1) - elif version == "b2": - return models.efficientnet_b2(weights=EfficientNet_B2_Weights.IMAGENET1K_V1) - elif version == "b3": - return models.efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1) - elif version == "b4": - return models.efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1) - elif version == "b5": - return models.efficientnet_b5(weights=EfficientNet_B5_Weights.IMAGENET1K_V1) - elif version == "b6": - return models.efficientnet_b6(weights=EfficientNet_B6_Weights.IMAGENET1K_V1) - elif version == "b7": - return models.efficientnet_b7(weights=EfficientNet_B7_Weights.IMAGENET1K_V1) - else: - raise ValueError("Invalid EfficientNet version. Choose from 'b0' to 'b7'.") - - def _get_num_features(self): - if self.model_type == "resnet": - return self.model.fc.in_features - else: # efficientnet - return self.model.classifier[1].in_features - def forward(self, x): return self.model(x) diff --git a/modules/EARSNet/predictor.py b/modules/EARSNet/predictor.py index 5841bc6..e47d72f 100644 --- a/modules/EARSNet/predictor.py +++ b/modules/EARSNet/predictor.py @@ -1,88 +1,94 @@ +import os + import torch from PIL import Image from torchvision import transforms +# 学習時に使った RegressionModel と同じものを import +# (train.py 内では "model.py" を読んでいる想定) from .model import RegressionModel -class StethoscopePredictor: +class EARSNetPredictor: def __init__( - self, model_path, model_type="resnet", model_version="18", device=None + self, + weight_path: str, + resnet_depth: str = "18", + pretrained: bool = True, + device: str = None, ): """ - Initialize the predictor with a trained model + 学習時と同じ構造・重みをもつモデルをロードし、 + 224×224スケールでの推論を行えるようにするクラス。 Args: - model_path (str): Path to the saved model weights - model_type (str): Type of model ('resnet' or 'efficientnet') - model_version (str): Version of the model (e.g., '18' for ResNet18, 'b0' for EfficientNet-B0) - device (str): Device to run the model on ('cuda' or 'cpu') + weight_path (str): 学習済みモデルの pth ファイルパス (例: "best_model.pth") + resnet_depth (str): "18","34","50","101","152" など + pretrained (bool): True なら ImageNet 事前学習ウェイトベース + device (str): 'cuda' or 'cpu' (指定しない場合、自動判定) """ self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") - ) - self.transform = transforms.Compose( - [ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), - ] + torch.device(device) + if device + else torch.device("cuda" if torch.cuda.is_available() else "cpu") ) - # Initialize model - self.model = RegressionModel(model_name=model_version, model_type=model_type) - self.model.load_state_dict(torch.load(model_path, map_location=self.device)) - self.model = self.model.to(self.device) + # 学習時と同じ ResNet depth & 前処理設定でモデルを用意 + self.model = RegressionModel( + resnet_depth=resnet_depth, + pretrained=pretrained, + ).to(self.device) + + # 学習済みウェイトをロード + if not os.path.isfile(weight_path): + raise FileNotFoundError(f"Weight file not found: {weight_path}") + self.model.load_state_dict(torch.load(weight_path, map_location=self.device)) self.model.eval() - def predict(self, image_path): - """ - Predict stethoscope coordinates from an image + # 学習時に使った画像サイズ・正規化パラメータ + self.input_size = (224, 224) + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] - Args: - image_path (str): Path to the input image + # 今回は "224×224 の座標" をそのまま返す簡易実装のため、 + # 元画像→224×224 へリサイズしたスケーリング係数は保持しない + + def _preprocess(self, image: Image.Image): + """ + 画像を224x224にリサイズ → テンソル化 → 正規化 + """ + # 画像を224x224にリサイズ + resized_image = image.resize(self.input_size, Image.BILINEAR) + + # Tensor化 & 正規化 + x = transforms.ToTensor()(resized_image) + x = transforms.Normalize(mean=self.mean, std=self.std)(x) + + return x # スケール係数は返さない + + def predict(self, image_path: str): + """ + 1枚の画像ファイルパスに対し、(224×224座標系) での (x, y) を推論して返す。 Returns: - tuple: Predicted (x, y) coordinates + (pred_x_224, pred_y_224): 224x224座標系での推定結果 """ - # Load and preprocess image + if not os.path.isfile(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + + # 画像を開く image = Image.open(image_path).convert("RGB") - image_tensor = self.transform(image).unsqueeze(0).to(self.device) - # Make prediction + # 前処理 (224x224にリサイズ & 正規化) + input_tensor = self._preprocess(image) + + # 推論 + input_tensor = input_tensor.unsqueeze(0).to(self.device) # (B=1, C, H, W) with torch.no_grad(): - prediction = self.model(image_tensor) + pred = self.model(input_tensor) # shape: (1,2) - return prediction[0].cpu().numpy() + pred_x_224 = pred[0][0].item() + pred_y_224 = pred[0][1].item() - -def load_model(model_path, model_type="resnet", model_version="18", device=None): - """ - Load a trained stethoscope detection model - - Args: - model_path (str): Path to the saved model weights - model_type (str): Type of model ('resnet' or 'efficientnet') - model_version (str): Version of the model (e.g., '18' for ResNet18, 'b0' for EfficientNet-B0) - device (str): Device to run the model on ('cuda' or 'cpu') - - Returns: - StethoscopePredictor: Initialized predictor object - """ - return StethoscopePredictor(model_path, model_type, model_version, device) - - -def predict(predictor, image_path): - """ - Predict stethoscope coordinates using a loaded model - - Args: - predictor (StethoscopePredictor): Initialized predictor object - image_path (str): Path to the input image - - Returns: - tuple: Predicted (x, y) coordinates - """ - return predictor.predict(image_path) + # 「224×224の座標」をそのまま返す + return pred_x_224, pred_y_224