# Demo-Maker / main.py
import argparse
import csv
import os
import pickle
import re
import time
from threading import Thread

import cv2
import numpy as np
import pandas as pd
from mmdet.apis import DetInferencer, inference_detector, init_detector

# RTMpose
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.utils import adapt_mmdet_pipeline

import config

# EARSNet
from modules.EARSNet.predictor import EARSNetPredictor

# Utilities
from util.calc_ste_position import CalcStethoscopePosition
from util.ears_ai import EarsAI

###############################################################################
# Settings pulled from the project-level `config` module
###############################################################################
# Overlay colour used when drawing each estimator's result
CONV_COLOR = config.CONV_COLOR
XGBOOST_COLOR = config.XGBOOST_COLOR
LIGHTGBM_COLOR = config.LIGHTGBM_COLOR
CATBOOST_COLOR = config.CATBOOST_COLOR
NGBOOST_COLOR = config.NGBOOST_COLOR
EARSNET_COLOR = config.EARSNET_COLOR

# On/off switches for the stethoscope-position estimators
CONV_ENABLED = config.CONV_ENABLED
XGBOOST_ENABLED = config.XGBOOST_ENABLED
LIGHTGBM_ENABLED = config.LIGHTGBM_ENABLED
CATBOOST_ENABLED = config.CATBOOST_ENABLED
NGBOOST_ENABLED = config.NGBOOST_ENABLED
EARSNET_ENABLED = config.EARSNET_ENABLED
# Separate EARSNet model that is fed the torso crop instead of the full frame
EARSNET_CROP_ENABLED = config.EARSNET_CROP_ENABLED

# On/off switches for the pose estimators / object detectors
POSENET_ENABLED = config.POSENET_ENABLED
RTMPOSE_ENABLED = config.RTMPOSE_ENABLED
# NOTE: the local name's casing differs from the config attribute name.
MobileNetV1SSD_ENABLED = config.MOBILENETV1SSD_ENABLED
YOLOX_ENABLED = config.YOLOX_ENABLED

# When true, the XGBoost/LightGBM regressors receive normalised coordinates
NORMALIZE_ENABLED = config.NORMALIZE_ENABLED

# Inference device string handed to the models, e.g. "cuda" or "cpu"
DEVICE = config.DEVICE

###############################################################################
# Shared state between the processing loop and the FPS monitor thread
###############################################################################
fps_history = []  # (wall-clock timestamp, fps) samples appended by fps_monitor
stop_fps_thread = False  # set to True to make fps_monitor leave its loop
processed_frames = 0  # frame counter incremented by the main processing loop


def fps_monitor(interval=1.0):
    """
    Background loop that periodically reports the real-time processing FPS.

    Intended to run in its own thread. Every ``interval`` seconds it reads the
    global ``processed_frames`` counter (incremented by the main thread),
    computes frames per second since the previous sample, prints it, and
    appends ``(wall_clock_timestamp, fps)`` to the global ``fps_history``.
    The loop ends once the global ``stop_fps_thread`` flag is true; the flag
    is checked once per interval, so shutdown can take up to ``interval``
    seconds.

    Args:
        interval: Sampling period in seconds (1.0 -> one report per second).
    """
    global processed_frames, stop_fps_thread, fps_history

    # Start from the current count so that frames processed before the
    # monitor started are not attributed to the first interval.
    last_count = processed_frames
    # Elapsed time uses a monotonic clock: wall-clock jumps (NTP sync, manual
    # changes) must not yield negative or inflated FPS values. The wall-clock
    # timestamp is still what gets stored in fps_history.
    last_tick = time.monotonic()

    while not stop_fps_thread:
        time.sleep(interval)
        now_tick = time.monotonic()
        now = time.time()

        current_count = processed_frames
        frames_delta = current_count - last_count
        time_delta = now_tick - last_tick

        current_fps = frames_delta / time_delta if time_delta > 0 else 0.0

        print(
            f"[FPS Monitor] Real-time FPS: {current_fps:.2f}  (frames: +{frames_delta})"
        )

        fps_history.append((now, current_fps))

        last_count = current_count
        last_tick = now_tick


###############################################################################
# モデルロード系
###############################################################################
def load_model(model_path, model_type="lgb"):
    """
    Load a pickled model object from ``model_path``.

    ``model_type`` is accepted for call-site compatibility but is not used:
    every model is read the same way, with ``pickle.load``.

    Note: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(model_path, mode="rb") as fh:
        model = pickle.load(fh)
    return model


def load_scaler(scaler_path):
    """
    Load a pickled scaler object (e.g. a fitted feature scaler) from disk.

    Note: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(scaler_path, mode="rb") as handle:
        raw = handle.read()
    # pickle.loads ignores any bytes after the pickled object, matching pickle.load.
    return pickle.loads(raw)


###############################################################################
# YOLOX
###############################################################################
def init_yolox():
    """
    Build the mmdet ``DetInferencer`` used for YOLOX stethoscope detection.

    The config file, checkpoint and device come from ``config`` / ``DEVICE``.
    Any exception during setup is printed and ``None`` is returned instead of
    propagating, so callers must check for ``None``.
    """
    try:
        from mmengine.registry import DefaultScope

        # Create/get a DefaultScope instance named "mmdet" before building the
        # inferencer (presumably so registry lookups resolve in mmdet's scope).
        DefaultScope.get_instance("mmdet", scope_name="mmdet")
        return DetInferencer(
            model=config.YOLOX_CONFIG_FILE,
            weights=config.YOLOX_CHECKPOINT_FILE,
            device=DEVICE,
        )
    except Exception as e:
        print(f"Error initializing YOLOX: {str(e)}")
        return None


def draw_polygon_and_detection(image, polygon_vertices, stethoscope_x, stethoscope_y):
    """
    Return a copy of ``image`` with the torso polygon and stethoscope marker.

    The polygon is drawn closed in green with thickness 2. If both coordinates
    are not ``None``, a filled circle (radius 10) plus a white ring (radius 12)
    is drawn at the truncated integer position. Colours are OpenCV-order
    triples. The input image is left untouched.
    """
    canvas = image.copy()
    cv2.polylines(canvas, [polygon_vertices.astype(np.int32)], True, (0, 255, 0), 2)

    if stethoscope_x is None or stethoscope_y is None:
        return canvas

    centre = (int(stethoscope_x), int(stethoscope_y))
    # Filled dot first, then the outline ring on top of it.
    for radius, colour, thickness in ((10, (255, 0, 0), -1), (12, (255, 255, 255), 2)):
        cv2.circle(canvas, centre, radius, colour, thickness)
    return canvas


def expand_points(p1, p2, scale=2.0):
    """
    Push two points symmetrically away from (or toward) their midpoint.

    Used to widen the shoulder / hip span so the torso polygon covers more of
    the body. Each point's offset from the midpoint is multiplied by ``scale``.
    With the default ``scale=2.0`` (the previously hard-coded factor), the
    returned points are twice as far apart as the inputs.

    Args:
        p1, p2: Indexable 2-D points (x at index 0, y at index 1).
        scale: Factor applied to each point's offset from the midpoint
            (> 1 widens, < 1 shrinks, 1 leaves the points unchanged).

    Returns:
        Tuple ``(new_p1, new_p2)`` of NumPy arrays of shape (2,).
    """
    mid_x = (p1[0] + p2[0]) / 2
    mid_y = (p1[1] + p2[1]) / 2

    # Offset of p1 from the midpoint; p2's offset is its negation.
    vec_x = p1[0] - mid_x
    vec_y = p1[1] - mid_y

    new_p1 = [mid_x + vec_x * scale, mid_y + vec_y * scale]
    new_p2 = [mid_x - vec_x * scale, mid_y - vec_y * scale]

    return np.array(new_p1), np.array(new_p2)


def point_in_polygon(point, vertices):
    """
    Even-odd (ray casting) test: is ``point`` inside the polygon ``vertices``?

    A horizontal ray is cast from ``point`` towards +x, and the edges it
    strictly crosses are counted; an odd count means inside. Points exactly on
    the boundary may be classified either way. An empty vertex list returns
    False.

    Args:
        point: ``(x, y)`` pair.
        vertices: Sequence of ``(x, y)`` vertices in drawing order.

    Returns:
        ``True`` if the point is inside, else ``False``.
    """
    px, py = point
    crossings = 0
    for idx in range(len(vertices)):
        ax, ay = vertices[idx][0], vertices[idx][1]
        # idx - 1 wraps to the last vertex for idx == 0, closing the polygon.
        bx, by = vertices[idx - 1][0], vertices[idx - 1][1]

        # Only edges straddling the horizontal line y = py can be crossed
        # (this also guarantees by != ay, so the division below is safe).
        if (ay > py) == (by > py):
            continue

        x_cross = (bx - ax) / (by - ay) * (py - ay) + ax
        if px < x_cross:
            crossings += 1

    return crossings % 2 == 1


def yolox_detector_inference(frame, yolox_inferencer, pose_keypoints, score_thr=0.3):
    """
    Detect the stethoscope with YOLOX and keep the best detection on the torso.

    A polygon is built from the nose and the shoulder / hip keypoints (the
    shoulders and hips widened with ``expand_points``). Detections are kept if
    they have label 0 and ``score >= score_thr`` and their bbox centre lies
    inside that polygon; the one with the highest score is selected.

    Args:
        frame: Image in OpenCV BGR order; converted to RGB before inference.
        yolox_inferencer: mmdet ``DetInferencer`` (see ``init_yolox``).
        pose_keypoints: Indexable keypoints. Indices 0, 5, 6, 11 and 12 are read
            as nose, left/right shoulder and left/right hip (COCO order assumed).
        score_thr: Minimum detection score.

    Returns:
        ``(overlay_bgr, stethoscope_x, stethoscope_y)``. The coordinates are
        ``(0, 0)`` when no qualifying detection exists; in that case no marker
        is drawn on the overlay.
    """
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = yolox_inferencer(inputs=frame_rgb, return_vis=True)
    predictions = result["predictions"][0]

    # Body parts from the keypoint array (COCO ordering assumed).
    nose = pose_keypoints[0]
    left_shoulder = pose_keypoints[5]
    right_shoulder = pose_keypoints[6]
    left_hip = pose_keypoints[11]
    right_hip = pose_keypoints[12]

    # Widen shoulders and hips outward so the polygon covers the whole torso.
    expanded_left_shoulder, expanded_right_shoulder = expand_points(
        left_shoulder, right_shoulder
    )
    expanded_left_hip, expanded_right_hip = expand_points(left_hip, right_hip)

    polygon_vertices = np.array(
        [
            nose,
            expanded_left_shoulder,
            expanded_left_hip,
            expanded_right_hip,
            expanded_right_shoulder,
        ]
    )

    stethoscope_x = None
    stethoscope_y = None
    max_score = -1
    for label, score, bbox in zip(
        predictions["labels"], predictions["scores"], predictions["bboxes"]
    ):
        # Label 0 is assumed to be the stethoscope class of the trained model.
        if score < score_thr or label != 0:
            continue
        center_x = (bbox[0] + bbox[2]) / 2
        center_y = (bbox[1] + bbox[3]) / 2
        if point_in_polygon([center_x, center_y], polygon_vertices) and score > max_score:
            stethoscope_x = center_x
            stethoscope_y = center_y
            max_score = score

    stethoscope_overlay_img = result["visualization"][0]
    if (
        len(stethoscope_overlay_img.shape) == 3
        and stethoscope_overlay_img.shape[2] == 3
    ):
        # The inferencer visualisation is converted RGB -> BGR for cv2.imwrite.
        stethoscope_overlay_img = cv2.cvtColor(
            stethoscope_overlay_img, cv2.COLOR_RGB2BGR
        )

    # Draw BEFORE applying the (0, 0) fallback: draw_polygon_and_detection skips
    # the marker for None coordinates. Previously the fallback ran first, so a
    # spurious marker was drawn at the image origin whenever nothing was found.
    stethoscope_overlay_img = draw_polygon_and_detection(
        stethoscope_overlay_img, polygon_vertices, stethoscope_x, stethoscope_y
    )

    # Callers store these in CSV rows and normalise them, so keep returning
    # numbers (the historical (0, 0) fallback) rather than None.
    if stethoscope_x is None or stethoscope_y is None:
        stethoscope_x, stethoscope_y = 0, 0

    return stethoscope_overlay_img, stethoscope_x, stethoscope_y


###############################################################################
# 各種座標変換
###############################################################################
def normalize_quadrilateral_with_point(points, extra_point):
    """
    Normalise the shoulder/hip quadrilateral plus one extra point (stethoscope).

    Steps:
      1. Translate so the centroid of the 4 body points is at the origin.
      2. Rotate by minus the *circular* mean of the shoulder direction
         (left -> right shoulder) and the hip direction (left -> right hip),
         so that mean direction lies along +x.
      3. Divide by the longest of the distances p0-p1, p1-p2, p2-p3 and p3-p0.

    The mean angle used to be ``(a + b) / 2``. That is wrong when the two
    angles straddle the +/-pi branch cut: with a frontal pose the shoulder
    vector points towards -x, so e.g. ``pi - 0.01`` and ``-pi + 0.01``
    averaged to 0 instead of pi. The normalised pose therefore flipped by
    180 degrees depending on the sign of tiny tilts. The circular mean agrees
    with the old formula whenever no wrap-around occurs.
    NOTE(review): models trained on features produced with the old formula
    may need retraining for frames that hit the wrap-around case.

    Args:
        points: Array with 8 values (reshaped to (4, 2)) in the order left
            shoulder, right shoulder, left hip, right hip.
        extra_point: ``(x, y)`` transformed with the same mapping.

    Returns:
        (5, 2) array: the 4 normalised body points, then the extra point.
        If the scale is 0 (all body points coincide), the rotated points are
        returned unscaled to avoid dividing by zero.
    """
    body_points = points.reshape(-1, 2)
    all_points = np.vstack([body_points, extra_point])
    center = np.mean(body_points, axis=0)
    centered_points = all_points - center

    shoulder_vec = centered_points[1] - centered_points[0]
    hip_vec = centered_points[3] - centered_points[2]
    shoulder_angle = np.arctan2(shoulder_vec[1], shoulder_vec[0])
    hip_angle = np.arctan2(hip_vec[1], hip_vec[0])
    # Circular mean: angle of the sum of the two unit direction vectors.
    average_angle = np.arctan2(
        np.sin(shoulder_angle) + np.sin(hip_angle),
        np.cos(shoulder_angle) + np.cos(hip_angle),
    )

    rotation_matrix = np.array(
        [
            [np.cos(-average_angle), -np.sin(-average_angle)],
            [np.sin(-average_angle), np.cos(-average_angle)],
        ]
    )

    rotated_points = np.dot(centered_points, rotation_matrix.T)
    max_edge_length = np.max(
        np.linalg.norm(
            np.roll(rotated_points[:4], -1, axis=0) - rotated_points[:4], axis=1
        )
    )
    if max_edge_length == 0:
        return rotated_points  # avoid division by zero

    return rotated_points / max_edge_length


def calculate_rotation_angle(point1, point2):
    """
    Return the direction of the vector ``point1 -> point2`` in radians.

    Computed with ``np.arctan2(dy, dx)``, so the result lies in [-pi, pi].
    Both points must support element-wise subtraction (e.g. NumPy arrays).
    """
    delta = point2 - point1
    dx, dy = delta[0], delta[1]
    return np.arctan2(dy, dx)


def video_to_frames(video_path, output_dir):
    """
    Write every frame of a video to ``output_dir`` as ``<n>-frame.png``.

    Frames are numbered from 1 in read order, and ``output_dir`` is created if
    needed. The capture is always released, even if reading or writing raises
    part-way through. Frames that OpenCV fails to write are reported on stdout
    instead of being silently dropped.

    Raises:
        IOError: If the video file cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise IOError(f"Could not open video file: {video_path}")

        frame_num = 0
        while True:
            success, frame = video.read()
            if not success:
                break
            frame_num += 1
            frame_path = os.path.join(output_dir, f"{frame_num}-frame.png")
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(frame_path, frame):
                print(f"Failed to write frame: {frame_path}")
    finally:
        video.release()

    print(f"All frames saved to {output_dir}")


###############################################################################
# RTMpose キーポイント抽出
###############################################################################
def extract_keypoints_rtmpose(pose_results):
    """
    Pick the keypoints of the most visible pose instance from mmpose results.

    Every instance in every ``result.pred_instances`` is scored by the mean of
    its ``keypoints_visible``. The instance with the highest strictly positive
    score wins (the first one on ties), and its ``keypoints[0]`` is returned.

    Returns:
        The selected instance's ``keypoints[0]``, or ``None`` (with a printed
        message) when there are no results or no instance scores above 0.
    """
    if not pose_results:
        print("No pose results found.")
        return None

    # (mean visibility, instance) for every instance, in encounter order,
    # keeping only strictly positive scores.
    candidates = [
        (score, candidate)
        for score, candidate in (
            (np.mean(inst.keypoints_visible), inst)
            for result in pose_results
            for inst in result.pred_instances
        )
        if score > 0
    ]
    if not candidates:
        print("No valid instances found.")
        return None

    # max() only replaces on a strictly greater key, so the first best wins.
    _, best_instance = max(candidates, key=lambda pair: pair[0])
    return best_instance.keypoints[0]


###############################################################################
# 胴体クロップ生成
###############################################################################
def crop_body_from_keypoints(
    frame, left_shoulder, right_shoulder, left_hip, right_hip, margin=20
):
    """
    Crop a rough torso box around the shoulder and hip keypoints.

    The box is the axis-aligned bounding box of the four points (coordinates
    truncated with ``int``), grown by ``margin`` pixels on every side and
    clipped to the image. Works for colour (H, W, C) and grayscale (H, W)
    frames. If the keypoints lie entirely outside the image, the crop may be
    empty.

    Args:
        frame: Image array whose first two dimensions are (height, width).
        left_shoulder, right_shoulder, left_hip, right_hip: ``(x, y)`` points.
        margin: Padding in pixels added on each side (default 20).

    Returns:
        ``(cropped_frame, (xmin, ymin))``: a copy of the cropped region, and the
        top-left corner of the crop in original-image coordinates (used to map
        crop coordinates back to the full frame).
    """
    # Only height/width are needed, so 2-D grayscale frames work as well.
    h, w = frame.shape[:2]

    xs = [left_shoulder[0], right_shoulder[0], left_hip[0], right_hip[0]]
    ys = [left_shoulder[1], right_shoulder[1], left_hip[1], right_hip[1]]

    xmin = max(0, int(min(xs)) - margin)
    xmax = min(w, int(max(xs)) + margin)
    ymin = max(0, int(min(ys)) - margin)
    ymax = min(h, int(max(ys)) + margin)

    cropped_frame = frame[ymin:ymax, xmin:xmax].copy()

    return cropped_frame, (xmin, ymin)


###############################################################################
# メイン処理
###############################################################################
def process_images(args, detector, pose_estimator, visualizer):
    """
    Run the per-frame inference pipeline over ``<output_dir>/frames/*.png``.

    Frames are processed in numeric order of the first number in their file
    name. For each frame, the enabled stages run and are timed:
      * pose estimation: RTMPose (``detector`` + ``pose_estimator``) or
        PoseNet (``EarsAI``),
      * YOLOX stethoscope detection,
      * EARSNet on the full frame and/or on a torso crop,
      * the Conv / XGBoost / LightGBM position estimators (RTMPose + YOLOX only).

    Everything is written under ``<output_dir>/results``: ``results.csv`` (raw
    coordinates), ``results-convert.csv`` (normalised coordinates), overlay
    images, torso crops and ``fps_results.csv`` (per-stage timing summary).
    The global ``processed_frames`` counter is incremented for the FPS monitor.

    Args:
        args: Namespace providing at least ``output_dir``.
        detector: mmdet person detector used before RTMPose top-down inference.
        pose_estimator: mmpose model passed to ``inference_topdown``.
        visualizer: mmpose visualizer used for pose overlays, or ``None`` to
            skip them.
    """
    global processed_frames
    ears_ai = EarsAI()
    calc_position = CalcStethoscopePosition()

    base_dir = os.path.join(args.output_dir, "frames")
    results_dir = os.path.join(args.output_dir, "results")
    csv_path = os.path.join(results_dir, "results.csv")
    normalized_csv_path = os.path.join(results_dir, "results-convert.csv")
    pose_overlay_dir = os.path.join(results_dir, "pose_overlay_image")
    stethoscope_overlay_dir = os.path.join(results_dir, "stethoscope_overlay_image")
    # Torso crops fed to the crop-EARSNet model are also saved for inspection.
    cropped_dir = os.path.join(results_dir, "cropped_images")

    for directory in (
        results_dir,
        cropped_dir,
        pose_overlay_dir,
        stethoscope_overlay_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    # Sort numerically on the first digit run of the name (e.g. "<n>-frame.png"
    # as written by video_to_frames) so "10-..." comes after "9-...".
    png_files = sorted(
        [f for f in os.listdir(base_dir) if f.lower().endswith(".png")],
        key=lambda x: int(re.search(r"(\d+)", x).group(1)),
    )
    print(f"Found {len(png_files)} PNG files in {base_dir}.")

    rows = []
    normalized_rows = []

    # ------------------------------------------
    # YOLOX initialisation
    # ------------------------------------------
    yolox_inferencer = None
    if YOLOX_ENABLED:
        yolox_inferencer = init_yolox()

    # ------------------------------------------
    # Per-stage durations in seconds (one entry per call)
    # ------------------------------------------
    timings = {
        # Single-component inference
        "rtmpose_single": [],
        "yolox_single": [],
        "conv_single": [],
        "lightgbm_single": [],
        "xgboost_single": [],
        "earsnet_single": [],
        "earsnet_cropped_single": [],
        # Pipelines
        "pipeline_rtmpose_yolox_conv": [],
        "pipeline_rtmpose_yolox_lightgbm": [],
        "pipeline_rtmpose_yolox_xgboost": [],
        "pipeline_earsnet": [],  # EARSNet alone (no RTMPose / YOLOX)
        "pipeline_earsnet_cropped": [],  # RTMPose + EARSNet (torso crop)
    }

    # ------------------------------------------
    # Load every enabled model up front
    # ------------------------------------------
    if LIGHTGBM_ENABLED:
        lgb_model_x = load_model("./models/LightGBM/stethoscope_calc_x_best_model.pkl")
        lgb_model_y = load_model("./models/LightGBM/stethoscope_calc_y_best_model.pkl")
        lgb_scaler_x = load_scaler("./models/LightGBM/scaler-x.pkl")
        lgb_scaler_y = load_scaler("./models/LightGBM/scaler-y.pkl")

    if XGBOOST_ENABLED:
        xg_model_x = load_model("./models/XGBoost/stethoscope_calc_x_best_model.pkl")
        xg_model_y = load_model("./models/XGBoost/stethoscope_calc_y_best_model.pkl")
        xg_scaler_x = load_scaler("./models/XGBoost/scaler-x.pkl")
        xg_scaler_y = load_scaler("./models/XGBoost/scaler-y.pkl")

    # NOTE(review): the CatBoost / NGBoost models are loaded but not used in
    # this function.
    if CATBOOST_ENABLED:
        catboost_model_x = load_model(
            "./models/CatBoost/stethoscope_calc_x_best_model.pkl"
        )
        catboost_model_y = load_model(
            "./models/CatBoost/stethoscope_calc_y_best_model.pkl"
        )

    if NGBOOST_ENABLED:
        ngboost_model_x = load_model(
            "./models/NGBoost/stethoscope_calc_x_best_model.pkl"
        )
        ngboost_model_y = load_model(
            "./models/NGBoost/stethoscope_calc_y_best_model.pkl"
        )

    # Regular EARSNet (full frame, no crop)
    if EARSNET_ENABLED:
        earsnet_predictor = EARSNetPredictor(
            weight_path="models/EARSNet/best_model.pth",
            resnet_depth="18",
            pretrained=True,
            device=DEVICE,
        )

    # Separate EARSNet model for torso-cropped images
    if EARSNET_CROP_ENABLED:
        earsnet_cropped_predictor = EARSNetPredictor(
            weight_path="models/EARSNet/crop/best_model.pth",  # expected model file
            resnet_depth="18",
            pretrained=True,
            device=DEVICE,
        )

    # Feature columns fed to the XGBoost / LightGBM scalers and models
    input_columns = [
        "left_shoulder_x",
        "left_shoulder_y",
        "right_shoulder_x",
        "right_shoulder_y",
        "left_hip_x",
        "left_hip_y",
        "right_hip_x",
        "right_hip_y",
        "stethoscope_x",
        "stethoscope_y",
    ]

    # ------------------------------------------------------------
    # Main loop (one iteration per frame)
    # ------------------------------------------------------------
    # perf_counter is monotonic and high-resolution, which is what short
    # duration measurements need (time.time can jump and is coarser).
    for image_file_name in png_files:
        image_path = os.path.join(base_dir, image_file_name)
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Failed to load image: {image_path}")
            continue

        # (A) Pose estimation
        rtmpose_time = 0.0
        if RTMPOSE_ENABLED:
            start_time_rtmpose = time.perf_counter()
            # ===== RTMPose inference (person detection + top-down pose) =====
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            det_result = inference_detector(detector, frame_rgb)
            pred_instance = det_result.pred_instances.cpu().numpy()

            bboxes = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
            )
            # Keep person boxes only (label 0 assumed to be "person") with score > 0.3
            bboxes = bboxes[
                np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)
            ]
            bboxes = bboxes[nms(bboxes, 0.3), :4]

            pose_results = inference_topdown(pose_estimator, frame_rgb, bboxes)
            data_samples = merge_data_samples(pose_results)
            pose_keypoints = extract_keypoints_rtmpose(pose_results)

            rtmpose_time = time.perf_counter() - start_time_rtmpose
            timings["rtmpose_single"].append(rtmpose_time)

            if pose_keypoints is None:
                print(f"Failed to extract keypoints for image: {image_path}")
                processed_frames += 1
                continue

            # The pose overlay is optional. get_image() used to be called even
            # when visualizer was None, which raised AttributeError.
            if visualizer is not None:
                visualizer.add_datasample(
                    "result",
                    frame_rgb,
                    data_sample=data_samples,
                    draw_gt=False,
                    draw_heatmap=False,
                    draw_bbox=False,
                    show_kpt_idx=False,
                    skeleton_style="mmpose",
                    show=False,
                    wait_time=0,
                    kpt_thr=0.3,
                )
                pose_overlay_img = visualizer.get_image()  # RGB
                pose_overlay_bgr = cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR)
                cv2.imwrite(
                    os.path.join(pose_overlay_dir, image_file_name), pose_overlay_bgr
                )

            # COCO keypoint order: 5/6 = left/right shoulder, 11/12 = left/right hip
            left_shoulder = (pose_keypoints[5][0], pose_keypoints[5][1])
            right_shoulder = (pose_keypoints[6][0], pose_keypoints[6][1])
            left_hip = (pose_keypoints[11][0], pose_keypoints[11][1])
            right_hip = (pose_keypoints[12][0], pose_keypoints[12][1])

        elif POSENET_ENABLED:
            # Legacy PoseNet path (its duration is recorded under "rtmpose_single")
            start_time_rtmpose = time.perf_counter()
            pose_overlay_img, *landmarks = ears_ai.pose_detect(frame, None)
            rtmpose_time = time.perf_counter() - start_time_rtmpose
            timings["rtmpose_single"].append(rtmpose_time)

            # landmarks = [left_shoulder, right_shoulder, left_hip, right_hip]
            left_shoulder = landmarks[0]
            right_shoulder = landmarks[1]
            left_hip = landmarks[2]
            right_hip = landmarks[3]

            # pose_overlay_img is presumably already BGR — TODO confirm in EarsAI
            cv2.imwrite(
                os.path.join(pose_overlay_dir, image_file_name), pose_overlay_img
            )
        else:
            # Neither RTMPose nor PoseNet is enabled
            left_shoulder = (0, 0)
            right_shoulder = (0, 0)
            left_hip = (0, 0)
            right_hip = (0, 0)

        # (B) YOLOX stethoscope detection (optional)
        yolox_time = 0.0
        stethoscope_x, stethoscope_y = 0, 0
        if YOLOX_ENABLED:
            # With RTMPose enabled, reaching this point implies pose_keypoints
            # was set (non-None) for this frame; otherwise we `continue`d above.
            if RTMPOSE_ENABLED:
                start_time_yolox = time.perf_counter()
                stethoscope_overlay_img, stethoscope_x, stethoscope_y = (
                    yolox_detector_inference(frame, yolox_inferencer, pose_keypoints)
                )
                yolox_time = time.perf_counter() - start_time_yolox
                timings["yolox_single"].append(yolox_time)

                cv2.imwrite(
                    os.path.join(stethoscope_overlay_dir, image_file_name),
                    stethoscope_overlay_img,
                )

            elif POSENET_ENABLED:
                # Build a COCO-shaped keypoint list from the PoseNet landmarks.
                # Unfilled entries (including the nose at index 0) stay [0, 0];
                # each entry is a distinct list to avoid shared-list aliasing.
                pose_keypoints_pose_net = [[0, 0] for _ in range(13)]
                pose_keypoints_pose_net[5] = (left_shoulder[0], left_shoulder[1])
                pose_keypoints_pose_net[6] = (right_shoulder[0], right_shoulder[1])
                pose_keypoints_pose_net[11] = (left_hip[0], left_hip[1])
                pose_keypoints_pose_net[12] = (right_hip[0], right_hip[1])

                start_time_yolox = time.perf_counter()
                stethoscope_overlay_img, stethoscope_x, stethoscope_y = (
                    yolox_detector_inference(
                        frame, yolox_inferencer, pose_keypoints_pose_net
                    )
                )
                yolox_time = time.perf_counter() - start_time_yolox
                timings["yolox_single"].append(yolox_time)

                cv2.imwrite(
                    os.path.join(stethoscope_overlay_dir, image_file_name),
                    stethoscope_overlay_img,
                )

        # Combined (pose + YOLOX) detection time, used by the pipelines below
        detection_time_rtmpose_yolox = rtmpose_time + yolox_time

        # Raw (pixel) coordinates for the CSV
        row = {
            "image_file_name": image_file_name,
            "left_shoulder_x": left_shoulder[0],
            "left_shoulder_y": left_shoulder[1],
            "right_shoulder_x": right_shoulder[0],
            "right_shoulder_y": right_shoulder[1],
            "left_hip_x": left_hip[0],
            "left_hip_y": left_hip[1],
            "right_hip_x": right_hip[0],
            "right_hip_y": right_hip[1],
            "stethoscope_x": stethoscope_x,
            "stethoscope_y": stethoscope_y,
        }

        # (C) EARSNet on the full frame; pipeline_earsnet excludes pose/YOLOX
        if EARSNET_ENABLED:
            start_time_earsnet = time.perf_counter()
            earsnet_x, earsnet_y = earsnet_predictor.predict(image_path)
            earsnet_time = time.perf_counter() - start_time_earsnet

            timings["earsnet_single"].append(earsnet_time)
            timings["pipeline_earsnet"].append(earsnet_time)

            row["earsnet_stethoscope_x"] = earsnet_x
            row["earsnet_stethoscope_y"] = earsnet_y

        # (D) EARSNet on the torso crop (pose + cropped EARSNet)
        if EARSNET_CROP_ENABLED:
            # 1) Build the crop and save it; the predictor reads it from disk
            cropped_img, (crop_xmin, crop_ymin) = crop_body_from_keypoints(
                frame, left_shoulder, right_shoulder, left_hip, right_hip
            )
            cropped_filename = os.path.splitext(image_file_name)[0] + "_cropped.png"
            cv2.imwrite(os.path.join(cropped_dir, cropped_filename), cropped_img)

            # 2) Cropped-image EARSNet
            start_time_earsnet_cropped = time.perf_counter()
            earsnet_cropped_x, earsnet_cropped_y = earsnet_cropped_predictor.predict(
                os.path.join(cropped_dir, cropped_filename)
            )
            earsnet_cropped_time = time.perf_counter() - start_time_earsnet_cropped
            timings["earsnet_cropped_single"].append(earsnet_cropped_time)

            # pipeline_earsnet_cropped = pose time + cropped EARSNet time
            pipeline_earsnet_cropped_time = rtmpose_time + earsnet_cropped_time
            timings["pipeline_earsnet_cropped"].append(pipeline_earsnet_cropped_time)

            # 3) Map to original-image coordinates.
            # NOTE(review): crop_xmin / crop_ymin are computed for exactly this
            # purpose but are NOT added here. If the crop model predicts in
            # crop-pixel coordinates, this should be
            # earsnet_cropped_x + crop_xmin / earsnet_cropped_y + crop_ymin.
            # Left unchanged until the model's output convention is confirmed.
            global_x = earsnet_cropped_x
            global_y = earsnet_cropped_y

            row["earsnet_crop_stethoscope_x"] = global_x
            row["earsnet_crop_stethoscope_y"] = global_y

        # (E) Normalisation relative to the shoulder/hip quadrilateral
        source_points = np.array(
            [
                [float(row["left_shoulder_x"]), float(row["left_shoulder_y"])],
                [float(row["right_shoulder_x"]), float(row["right_shoulder_y"])],
                [float(row["left_hip_x"]), float(row["left_hip_y"])],
                [float(row["right_hip_x"]), float(row["right_hip_y"])],
            ],
            dtype=np.float32,
        )

        stethoscope_point = np.array(
            [float(row["stethoscope_x"]), float(row["stethoscope_y"])]
        )
        normalized_points = normalize_quadrilateral_with_point(
            source_points.flatten(), stethoscope_point
        )

        normalized_row = {
            "image_file_name": image_file_name,
            "left_shoulder_x": normalized_points[0, 0],
            "left_shoulder_y": normalized_points[0, 1],
            "right_shoulder_x": normalized_points[1, 0],
            "right_shoulder_y": normalized_points[1, 1],
            "left_hip_x": normalized_points[2, 0],
            "left_hip_y": normalized_points[2, 1],
            "right_hip_x": normalized_points[3, 0],
            "right_hip_y": normalized_points[3, 1],
            "stethoscope_x": normalized_points[4, 0],
            "stethoscope_y": normalized_points[4, 1],
        }

        if EARSNET_ENABLED:
            stetho_point_earsnet = np.array(
                [
                    float(row.get("earsnet_stethoscope_x", 0)),
                    float(row.get("earsnet_stethoscope_y", 0)),
                ]
            )
            norm_earsnet = normalize_quadrilateral_with_point(
                source_points.flatten(), stetho_point_earsnet
            )
            normalized_row["earsnet_stethoscope_x"] = norm_earsnet[4, 0]
            normalized_row["earsnet_stethoscope_y"] = norm_earsnet[4, 1]

        if EARSNET_CROP_ENABLED:
            stetho_point_crop = np.array(
                [
                    float(row.get("earsnet_crop_stethoscope_x", 0)),
                    float(row.get("earsnet_crop_stethoscope_y", 0)),
                ]
            )
            norm_earsnet_crop = normalize_quadrilateral_with_point(
                source_points.flatten(), stetho_point_crop
            )
            normalized_row["earsnet_crop_stethoscope_x"] = norm_earsnet_crop[4, 0]
            normalized_row["earsnet_crop_stethoscope_y"] = norm_earsnet_crop[4, 1]

        rows.append(row)
        normalized_rows.append(normalized_row)

        # (F) Pipelines (RTMPose + YOLOX -> Conv / XGBoost / LightGBM):
        # pipeline time = detection_time_rtmpose_yolox + that model's own time.
        if RTMPOSE_ENABLED and YOLOX_ENABLED:
            # Conv (affine mapping)
            if CONV_ENABLED:
                start_conv = time.perf_counter()
                source_pts = np.array(
                    [
                        [float(row[f"{pos}_x"]), float(row[f"{pos}_y"])]
                        for pos in [
                            "left_shoulder",
                            "right_shoulder",
                            "left_hip",
                            "right_hip",
                        ]
                    ],
                    dtype=np.float32,
                )
                stetho_pt = np.array(
                    [float(row["stethoscope_x"]), float(row["stethoscope_y"])]
                )
                _ = calc_position.calc_affine(source_pts, *stetho_pt)
                conv_time = time.perf_counter() - start_conv
                timings["conv_single"].append(conv_time)

                timings["pipeline_rtmpose_yolox_conv"].append(
                    detection_time_rtmpose_yolox + conv_time
                )

            # XGBoost
            if XGBOOST_ENABLED:
                xg_start = time.perf_counter()
                if NORMALIZE_ENABLED:
                    input_data_xg = pd.DataFrame([normalized_rows[-1]])
                else:
                    input_data_xg = pd.DataFrame([rows[-1]])
                X_scaled_x = xg_scaler_x.transform(input_data_xg[input_columns])
                _ = xg_model_x.predict(X_scaled_x)[0]
                X_scaled_y = xg_scaler_y.transform(input_data_xg[input_columns])
                _ = xg_model_y.predict(X_scaled_y)[0]
                xg_time = time.perf_counter() - xg_start
                timings["xgboost_single"].append(xg_time)

                timings["pipeline_rtmpose_yolox_xgboost"].append(
                    detection_time_rtmpose_yolox + xg_time
                )

            # LightGBM
            if LIGHTGBM_ENABLED:
                lgb_start = time.perf_counter()
                if NORMALIZE_ENABLED:
                    input_data_lgb = pd.DataFrame([normalized_rows[-1]])
                else:
                    input_data_lgb = pd.DataFrame([rows[-1]])
                X_scaled_x = lgb_scaler_x.transform(input_data_lgb[input_columns])
                _ = lgb_model_x.predict(X_scaled_x)[0]
                X_scaled_y = lgb_scaler_y.transform(input_data_lgb[input_columns])
                _ = lgb_model_y.predict(X_scaled_y)[0]
                lgb_time = time.perf_counter() - lgb_start
                timings["lightgbm_single"].append(lgb_time)

                timings["pipeline_rtmpose_yolox_lightgbm"].append(
                    detection_time_rtmpose_yolox + lgb_time
                )

        processed_frames += 1

    # ========================================================================
    # Write the raw and normalised CSVs
    # ========================================================================
    if rows:
        # All rows share the same keys because the feature flags are constant.
        fieldnames = list(rows[0].keys())
        norm_fieldnames = list(normalized_rows[0].keys())

        with (
            open(csv_path, "w", newline="") as csvfile,
            open(normalized_csv_path, "w", newline="") as norm_csvfile,
        ):
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

            norm_writer = csv.DictWriter(norm_csvfile, fieldnames=norm_fieldnames)
            norm_writer.writeheader()

            for row_, norm_row_ in zip(rows, normalized_rows):
                writer.writerow(row_)
                norm_writer.writerow(norm_row_)

        print(f"Processed and saved results to: {csv_path}")
        print(f"Processed and saved normalized results to: {normalized_csv_path}")

        # Drawings and video output
        generate_visualizations(csv_path, base_dir, results_dir)
    else:
        print("No data to write to CSV.")

    # ========================================================================
    # FPS summary per sub-component and pipeline (total / mean) + CSV
    # ========================================================================
    fps_data = []
    for method_name, time_list in timings.items():
        if not time_list:
            continue
        total_time = sum(time_list)
        num_calls = len(time_list)  # > 0: empty lists were skipped above
        avg_time = total_time / num_calls
        fps = 1.0 / avg_time if avg_time > 0 else 0
        fps_data.append(
            {
                "method_name": method_name,
                "num_calls": num_calls,
                "total_time_sec": f"{total_time:.6f}",
                "avg_time_sec": f"{avg_time:.6f}",
                "fps": f"{fps:.2f}",
            }
        )

    fps_csv_path = os.path.join(results_dir, "fps_results.csv")
    with open(fps_csv_path, "w", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "method_name",
                "num_calls",
                "total_time_sec",
                "avg_time_sec",
                "fps",
            ],
        )
        writer.writeheader()
        for rowf in fps_data:
            writer.writerow(rowf)

    print("\n===== FPS Results (subcomponent & pipeline) =====")
    for rowf in fps_data:
        print(
            f"{rowf['method_name']}: calls={rowf['num_calls']}, "
            f"total={rowf['total_time_sec']}s, avg={rowf['avg_time_sec']}s, FPS={rowf['fps']}"
        )


###############################################################################
# 可視化・動画化
###############################################################################
def generate_visualizations(csv_path, original_images_dir, results_dir):
    """
    Render per-frame visualizations from the estimation CSV and encode videos.

    For every CSV row whose source frame exists in ``original_images_dir``:

    * the shoulder / hip / stethoscope keypoints present in the row are drawn
      on the original frame, which is saved to ``<results_dir>/marked_images``;
    * each enabled estimator's stethoscope position (columns
      ``<key>_stethoscope_x`` / ``<key>_stethoscope_y``) is drawn on the body
      template ``./images/body/BodyF.png``: one image per estimator with the
      accumulated trajectory, one without it, plus "combined" images that
      overlay every estimator. The EARSNet crop model is included when
      ``EARSNET_CROP_ENABLED`` is set.

    If the body template cannot be read, a message is printed and only the
    marked source frames (and their video) are produced.

    Afterwards the marked frames and each per-estimator directory are encoded
    into MP4 videos in ``results_dir`` via ``create_video_from_images``.

    Args:
        csv_path: Path of the CSV written by the processing step.
        original_images_dir: Directory holding the frames named by the CSV
            ``image_file_name`` column.
        results_dir: Output directory for the generated images and videos.
    """
    df = pd.read_csv(csv_path)
    body_image_path = "./images/body/BodyF.png"
    body_image = cv2.imread(body_image_path)
    if body_image is None:
        # cv2.imread signals failure by returning None; without this check the
        # first body_image.copy() below would crash with an AttributeError.
        # The marked source frames do not need the template, so keep making them.
        print(
            f"Failed to read body template: {body_image_path} "
            "(body-position visualizations will be skipped)"
        )

    # Output sub-directory name per visualization key. "marked" holds the
    # annotated source frames; "combined" overlays every estimator.
    dirs = {"marked": "marked_images"}
    if CONV_ENABLED:
        dirs["conv"] = "conv"
    if XGBOOST_ENABLED:
        dirs["Xgboost"] = "Xgboost"
    if LIGHTGBM_ENABLED:
        dirs["lightGBM"] = "lightGBM"
    if CATBOOST_ENABLED:
        dirs["catboost"] = "catboost"
    if NGBOOST_ENABLED:
        dirs["ngboost"] = "ngboost"
    if EARSNET_ENABLED:
        dirs["earsnet"] = "earsnet"
    if EARSNET_CROP_ENABLED:
        dirs["earsnet_crop"] = "earsnet_crop"
    dirs["combined"] = "combined"

    # Create every output directory once, up front (including the combined ones).
    marked_dir = os.path.join(results_dir, "marked_images")
    os.makedirs(marked_dir, exist_ok=True)
    if body_image is not None:
        for key, dir_name in dirs.items():
            if key == "marked":
                continue
            for suffix in ("with_trajectory", "without_trajectory"):
                os.makedirs(
                    os.path.join(results_dir, f"{dir_name}_{suffix}"), exist_ok=True
                )

    # Accumulated stethoscope positions (the trajectory) per estimator.
    points = {key: [] for key in dirs if key not in ("marked", "combined")}

    # Drawing color per estimator (OpenCV BGR order).
    colors = {
        "conv": CONV_COLOR,
        "Xgboost": XGBOOST_COLOR,
        "lightGBM": LIGHTGBM_COLOR,
        "catboost": CATBOOST_COLOR,
        "ngboost": NGBOOST_COLOR,
        "earsnet": EARSNET_COLOR,
        "earsnet_crop": (255, 51, 255),  # pinkish
    }

    for _, row in df.iterrows():
        image_name = row["image_file_name"]
        original_image_path = os.path.join(original_images_dir, image_name)
        if not os.path.exists(original_image_path):
            continue
        original_image = cv2.imread(original_image_path)
        if original_image is None:
            continue

        # Annotate the source frame with the detected keypoints.
        _mark_keypoints(original_image, row)
        cv2.imwrite(os.path.join(marked_dir, image_name), original_image)

        if body_image is None:
            continue

        combined_with_traj = body_image.copy()
        combined_without_traj = body_image.copy()

        for key, trajectory in points.items():
            col_x = f"{key}_stethoscope_x"
            col_y = f"{key}_stethoscope_y"
            if col_x not in row or col_y not in row:
                continue
            val_x, val_y = row[col_x], row[col_y]
            if pd.isna(val_x) or pd.isna(val_y):
                continue

            trajectory.append((int(val_x), int(val_y)))
            color = colors.get(key, (0, 0, 255))

            # Per-estimator frames: with and without the accumulated path.
            for draw_path, suffix in (
                (True, "with_trajectory"),
                (False, "without_trajectory"),
            ):
                image = body_image.copy()
                _draw_stethoscope_position(image, trajectory, color, draw_path)
                cv2.imwrite(
                    os.path.join(results_dir, f"{dirs[key]}_{suffix}", image_name),
                    image,
                )

            # Overlay this estimator onto the combined frames as well.
            _draw_stethoscope_position(combined_with_traj, trajectory, color, True)
            _draw_stethoscope_position(combined_without_traj, trajectory, color, False)

        cv2.imwrite(
            os.path.join(results_dir, "combined_with_trajectory", image_name),
            combined_with_traj,
        )
        cv2.imwrite(
            os.path.join(results_dir, "combined_without_trajectory", image_name),
            combined_without_traj,
        )

    # Encode the generated frame sequences into videos.
    create_video_from_images(
        marked_dir,
        os.path.join(results_dir, "marked_video.mp4"),
    )
    for key, dir_name in dirs.items():
        if key in ("marked", "combined"):
            continue
        create_video_from_images(
            os.path.join(results_dir, f"{dir_name}_with_trajectory"),
            os.path.join(results_dir, f"{key}_video_with_trajectory.mp4"),
        )
        create_video_from_images(
            os.path.join(results_dir, f"{dir_name}_without_trajectory"),
            os.path.join(results_dir, f"{key}_video_without_trajectory.mp4"),
        )


def _mark_keypoints(image, row):
    """Draw the shoulder/hip/stethoscope points present in ``row`` onto ``image`` in place."""
    for point in (
        "left_shoulder",
        "right_shoulder",
        "left_hip",
        "right_hip",
        "stethoscope",
    ):
        col_x = f"{point}_x"
        col_y = f"{point}_y"
        # ``in`` on a pandas Series tests its index, i.e. the CSV column names.
        if col_x not in row or col_y not in row:
            continue
        val_x, val_y = row[col_x], row[col_y]
        if pd.isna(val_x) or pd.isna(val_y):
            continue
        cv2.circle(image, (int(val_x), int(val_y)), 10, (255, 255, 0), -1)


def _draw_stethoscope_position(image, trajectory, color, draw_trajectory):
    """Draw the newest point of ``trajectory`` (and optionally its path) on ``image`` in place."""
    if draw_trajectory and len(trajectory) > 1:
        # cv2.polylines expects int32 (CV_32S) points; be explicit rather than
        # relying on numpy's platform-dependent default integer type.
        cv2.polylines(
            image, [np.array(trajectory, dtype=np.int32)], False, color, 2
        )
    cv2.circle(image, trajectory[-1], 10, color, -1)


def create_video_from_images(image_dir, output_path, fps=30):
    """
    Encode the ``*.png`` images in ``image_dir`` into an MP4 video.

    Frames are ordered by the first run of digits in each file name, so
    ``frame_2.png`` comes before ``frame_10.png``. Names without digits go
    last, and ties are broken by name so the order is deterministic.

    The first image fixes the video size. Later frames of a different size
    are resized to it, because cv2.VideoWriter drops mismatched frames
    without raising.

    Args:
        image_dir: Directory scanned (non-recursively) for ``.png`` files.
            If it does not exist the function returns without output.
        output_path: Destination video path, written with the ``mp4v`` codec.
        fps: Frame rate of the output video (default 30, as before).
    """
    if not os.path.exists(image_dir):
        return
    images = sorted(
        (name for name in os.listdir(image_dir) if name.endswith(".png")),
        key=_frame_sort_key,
    )

    if not images:
        print(f"No images found in {image_dir}")
        return

    frame = cv2.imread(os.path.join(image_dir, images[0]))
    if frame is None:
        print(f"Failed to read the first image in {image_dir}")
        return
    height, width = frame.shape[:2]

    video = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    if not video.isOpened():
        # Without this check nothing is written but success is still reported.
        print(f"Failed to open video writer for {output_path}")
        return

    try:
        for name in images:
            img = cv2.imread(os.path.join(image_dir, name))
            if img is None:
                continue
            if img.shape[:2] != (height, width):
                img = cv2.resize(img, (width, height))
            video.write(img)
    finally:
        video.release()
    print(f"Created video: {output_path}")


def _frame_sort_key(filename):
    """Order file names by their first integer (digit-less names last), then by name."""
    match = re.search(r"(\d+)", filename)
    if match is None:
        return (1, 0, filename)
    return (0, int(match.group()), filename)


def main():
    """
    Command-line entry point: run the whole demo pipeline on one video.

    Steps:
      1. start the background FPS monitor thread (``fps_monitor``, 1 s interval);
      2. split the input video into frames under ``<output_dir>/frames``;
      3. when ``RTMPOSE_ENABLED``, build the RTMDet person detector, the
         RTMPose estimator and its visualizer (otherwise all three are None);
      4. run ``process_images``.

    The FPS monitor is always told to stop and is joined, even if a step raises.

    The RTMDet/RTMPose config and checkpoint paths can be overridden on the
    command line. Their defaults are the paths that used to be hard-coded here.
    """
    global stop_fps_thread

    parser = argparse.ArgumentParser(description="Process video and generate results.")
    parser.add_argument(
        "--video_path",
        default="./video/Media1.mp4",
        help="Path to the input video file",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        help="Directory to save output images and results",
    )
    # RTMPose model files (only used when RTMPOSE_ENABLED).
    parser.add_argument(
        "--det_config",
        default="modules/rtmpose/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py",
        help="MMDetection config of the person detector (RTMPose only)",
    )
    parser.add_argument(
        "--det_checkpoint",
        default="models/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth",
        help="Checkpoint of the person detector (RTMPose only)",
    )
    parser.add_argument(
        "--pose_config",
        default=(
            "modules/rtmpose/configs/body_2d_keypoint/rtmpose/body8/"
            "rtmpose-l_8xb256-420e_body8-256x192.py"
        ),
        help="MMPose config of the pose estimator (RTMPose only)",
    )
    parser.add_argument(
        "--pose_checkpoint",
        default=(
            "models/rtmpose/"
            "rtmpose-l_simcc-aic-coco_pt-aic-coco_420e-256x192-f016ffe0_20230126.pth"
        ),
        help="Checkpoint of the pose estimator (RTMPose only)",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # 1) Start the FPS monitor thread.
    fps_thread = Thread(target=fps_monitor, args=(1.0,), daemon=True)
    fps_thread.start()

    try:
        # 2) Split the video into frames.
        frames_dir = os.path.join(args.output_dir, "frames")
        video_to_frames(args.video_path, frames_dir)

        # 3) Initialize RTMPose only when it is enabled.
        detector = pose_estimator = visualizer = None
        if RTMPOSE_ENABLED:
            detector = init_detector(
                args.det_config, args.det_checkpoint, device=DEVICE
            )
            detector.cfg = adapt_mmdet_pipeline(detector.cfg)
            pose_estimator = init_pose_estimator(
                args.pose_config, args.pose_checkpoint, device=DEVICE
            )
            visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
            visualizer.set_dataset_meta(
                pose_estimator.dataset_meta, skeleton_style="mmpose"
            )

        process_images(args, detector, pose_estimator, visualizer)
    finally:
        # 4) Always signal the FPS monitor to stop and wait for it, even on error.
        stop_fps_thread = True
        fps_thread.join()

    print("All done.")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()