# Newer
# Older
# Demo-Maker / main.py
# @mikado-4410 mikado-4410 on 22 Jan 2025 39 KB [fix]FPSの計算方法を修正
import argparse
import csv
import os
import pickle
import re
import time
from threading import Thread

import cv2
import numpy as np
import pandas as pd
from mmdet.apis import DetInferencer, inference_detector, init_detector

# New imports for RTMPose
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.utils import adapt_mmdet_pipeline

import config

# EARSNetPredictor のみをインポート
from modules.EARSNet.predictor import EARSNetPredictor
from util.calc_ste_position import CalcStethoscopePosition
from util.ears_ai import EarsAI

# Per-estimator drawing colours, read once from the project config module.
(
    CONV_COLOR,
    XGBOOST_COLOR,
    LIGHTGBM_COLOR,
    EARSNET_COLOR,
    CATBOOST_COLOR,
    NGBOOST_COLOR,
) = (
    config.CONV_COLOR,
    config.XGBOOST_COLOR,
    config.LIGHTGBM_COLOR,
    config.EARSNET_COLOR,
    config.CATBOOST_COLOR,
    config.NGBOOST_COLOR,
)

# On/off switches for each pose / detection / regression stage.
(
    CONV_ENABLED,
    XGBOOST_ENABLED,
    LIGHTGBM_ENABLED,
    CATBOOST_ENABLED,
    NGBOOST_ENABLED,
    POSENET_ENABLED,
    RTMPOSE_ENABLED,
    MobileNetV1SSD_ENABLED,
    YOLOX_ENABLED,
    EARSNET_ENABLED,
) = (
    config.CONV_ENABLED,
    config.XGBOOST_ENABLED,
    config.LIGHTGBM_ENABLED,
    config.CATBOOST_ENABLED,
    config.NGBOOST_ENABLED,
    config.POSENET_ENABLED,
    config.RTMPOSE_ENABLED,
    config.MOBILENETV1SSD_ENABLED,
    config.YOLOX_ENABLED,
    config.EARSNET_ENABLED,
)

# When True the regressors are fed normalized coordinates instead of raw pixels.
NORMALIZE_ENABLED = config.NORMALIZE_ENABLED

###############################################################################
# Shared state between the main processing loop and the FPS monitor thread
###############################################################################
# (timestamp, fps) samples appended by fps_monitor, kept so they can be exported later.
fps_history = []
# fps_monitor keeps sampling until this flag becomes True.
stop_fps_thread = False
# Number of frames finished so far; incremented by the main processing loop.
processed_frames = 0


def fps_monitor(interval=1.0):
    """Periodically report the real-time frame-processing rate.

    Meant to run in a background thread. Every ``interval`` seconds it reads
    the module-level ``processed_frames`` counter and computes
    ``new_frames / elapsed_seconds``. It prints that value and appends
    ``(wall_clock_timestamp, fps)`` to the module-level ``fps_history`` list.

    The loop exits once the module-level ``stop_fps_thread`` flag is True. The
    flag is checked before each sleep, so shutdown can take up to one interval
    and records one final sample.

    Args:
        interval: Seconds between samples (default 1.0).
    """
    global processed_frames, stop_fps_thread, fps_history  # only read here

    # Start from the current count so frames processed before this thread
    # started are not attributed to the first interval.
    last_count = processed_frames
    # time.monotonic() cannot jump when the system clock is adjusted (NTP,
    # manual changes), unlike time.time(), so measured intervals stay valid.
    last_tick = time.monotonic()

    while not stop_fps_thread:
        time.sleep(interval)
        tick = time.monotonic()

        current_count = processed_frames
        frames_delta = current_count - last_count
        time_delta = tick - last_tick
        current_fps = frames_delta / time_delta if time_delta > 0 else 0.0

        print(
            f"[FPS Monitor] Real-time FPS: {current_fps:.2f}  (frames: +{frames_delta})"
        )

        # History keeps a wall-clock timestamp so samples can be matched to logs.
        fps_history.append((time.time(), current_fps))

        last_count = current_count
        last_tick = tick


###############################################################################
# 以下は従来の処理 (姿勢推定、聴診器検出、FPS計測など)
###############################################################################
def load_model(model_path, model_type="lgb"):
    """Unpickle and return the trained model stored at *model_path*.

    ``model_type`` is accepted for call-site compatibility only; every model
    is loaded the same way regardless of its value.

    Security: ``pickle`` can execute arbitrary code while loading, so only
    pass files you trust.
    """
    with open(model_path, mode="rb") as handle:
        model = pickle.load(handle)
    return model


def load_scaler(scaler_path):
    """Unpickle and return the fitted feature scaler stored at *scaler_path*.

    Security: unpickling can run arbitrary code; only load trusted files.
    """
    with open(scaler_path, mode="rb") as handle:
        scaler = pickle.load(handle)
    return scaler


def init_yolox():
    """Build an mmdet ``DetInferencer`` for YOLOX from the settings in ``config``.

    Uses ``config.YOLOX_CONFIG_FILE``, ``config.YOLOX_CHECKPOINT_FILE`` and
    ``config.DEVICE``.

    Returns:
        The inferencer, or ``None`` if any step raises. The error is printed,
        not re-raised.
    """
    try:
        from mmengine.registry import DefaultScope

        # Create/activate an "mmdet" DefaultScope before building the model
        # (presumably so registry lookups resolve against mmdet).
        DefaultScope.get_instance("mmdet", scope_name="mmdet")

        return DetInferencer(
            model=config.YOLOX_CONFIG_FILE,
            weights=config.YOLOX_CHECKPOINT_FILE,
            device=config.DEVICE,
        )
    except Exception as err:
        print(f"Error initializing YOLOX: {str(err)}")
        return None


def draw_polygon_and_detection(image, polygon_vertices, stethoscope_x, stethoscope_y):
    """Return a copy of *image* with the search polygon and stethoscope marked.

    The polygon is drawn as a closed outline with colour (0, 255, 0) and
    thickness 2. The vertices are cast to int32 for OpenCV.

    If both coordinates are not None, the point is drawn as a filled
    radius-10 dot in (255, 0, 0), which is blue when *image* is BGR. A white
    radius-12 ring is drawn around it. *image* itself is never modified.
    """
    canvas = image.copy()
    outline = polygon_vertices.astype(np.int32)
    cv2.polylines(canvas, [outline], True, (0, 255, 0), 2)

    if stethoscope_x is None or stethoscope_y is None:
        return canvas

    centre = (int(stethoscope_x), int(stethoscope_y))
    cv2.circle(canvas, centre, 10, (255, 0, 0), -1)  # filled dot
    cv2.circle(canvas, centre, 12, (255, 255, 255), 2)  # white ring
    return canvas


def expand_points(p1, p2, scale=2.0):
    """Push two points away from (or toward) their shared midpoint.

    Each result lies on the line through *p1* and *p2*, at ``scale`` times its
    original distance from the midpoint. With the default ``scale=2.0`` the
    segment length doubles. This is used to widen the shoulder and hip pairs
    into a torso search polygon.

    Args:
        p1, p2: 2-D points (sequence or array). Only the first two components
            are used, so ``(x, y, score)`` style entries are accepted.
        scale: Stretch factor. 1.0 returns the inputs unchanged, and values
            below 1 shrink the segment.

    Returns:
        Tuple ``(new_p1, new_p2)`` of length-2 numpy arrays.
    """
    a = np.asarray(p1[:2])
    b = np.asarray(p2[:2])

    mid = (a + b) / 2
    offset = (a - mid) * scale  # half-segment vector, rescaled
    return mid + offset, mid - offset


def point_in_polygon(point, vertices):
    """Return True if *point* is inside the polygon, using the even-odd ray-casting rule.

    A horizontal ray is cast from *point* toward +x. Every polygon edge that
    straddles the point's y and crosses the ray to the right of the point
    flips the inside/outside state. Only the first two components of each
    vertex are used. An empty polygon contains nothing.
    """
    px, py = point
    corners = [(v[0], v[1]) for v in vertices]
    # Pair each vertex with its predecessor (the last vertex precedes the first).
    predecessors = corners[-1:] + corners[:-1]

    crossings = 0
    for (xi, yi), (xj, yj) in zip(corners, predecessors):
        straddles = (yi > py) != (yj > py)
        if not straddles:
            continue
        crossing_x = (xj - xi) / (yj - yi) * (py - yi) + xi
        if px < crossing_x:
            crossings += 1

    return crossings % 2 == 1


def yolox_detector_inference(frame, yolox_inferencer, pose_keypoints, score_thr=0.3):
    """Detect the stethoscope with YOLOX and keep the best hit inside the torso.

    A torso polygon is built from keypoint 0 (nose), 5/6 (left/right shoulder)
    and 11/12 (left/right hip). The shoulder and hip pairs are first pushed
    away from their midpoints with ``expand_points``. Candidate detections
    have label 0 (assumed to be the stethoscope class, which depends on the
    training classes) and score >= ``score_thr``. Among candidates whose bbox
    centre lies inside the polygon, the highest-scoring one wins; ties keep
    the first.

    Args:
        frame: BGR image (converted to RGB before inference).
        yolox_inferencer: mmdet ``DetInferencer`` (see ``init_yolox``).
        pose_keypoints: Indexable keypoints. Entries 0, 5, 6, 11 and 12 are
            read as (x, y).
        score_thr: Minimum detection score.

    Returns:
        ``(overlay, x, y)``. ``overlay`` is the YOLOX visualization (converted
        RGB->BGR when it has 3 channels) with the polygon drawn on it, plus a
        marker at the chosen point if one was found. ``x, y`` is the chosen
        bbox centre, or ``0, 0`` when nothing qualified.
    """
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = yolox_inferencer(inputs=frame_rgb, return_vis=True)
    predictions = result["predictions"][0]

    nose = pose_keypoints[0]
    left_shoulder, right_shoulder = pose_keypoints[5], pose_keypoints[6]
    left_hip, right_hip = pose_keypoints[11], pose_keypoints[12]

    # Widen shoulders and hips so the polygon covers the whole chest area.
    expanded_left_shoulder, expanded_right_shoulder = expand_points(
        left_shoulder, right_shoulder
    )
    expanded_left_hip, expanded_right_hip = expand_points(left_hip, right_hip)
    polygon_vertices = np.array(
        [
            nose,
            expanded_left_shoulder,
            expanded_left_hip,
            expanded_right_hip,
            expanded_right_shoulder,
        ]
    )

    stethoscope_x = None
    stethoscope_y = None
    max_score = -1
    for label, score, bbox in zip(
        predictions["labels"], predictions["scores"], predictions["bboxes"]
    ):
        if label != 0 or score < score_thr:
            continue
        center_x = (bbox[0] + bbox[2]) / 2
        center_y = (bbox[1] + bbox[3]) / 2
        if score > max_score and point_in_polygon(
            [center_x, center_y], polygon_vertices
        ):
            stethoscope_x, stethoscope_y, max_score = center_x, center_y, score

    stethoscope_overlay_img = result["visualization"][0]
    if (
        len(stethoscope_overlay_img.shape) == 3
        and stethoscope_overlay_img.shape[2] == 3
    ):
        stethoscope_overlay_img = cv2.cvtColor(
            stethoscope_overlay_img, cv2.COLOR_RGB2BGR
        )

    # Draw BEFORE applying the (0, 0) fallback. Passing None skips the marker,
    # so a missed detection is no longer drawn as a dot at the image origin.
    stethoscope_overlay_img = draw_polygon_and_detection(
        stethoscope_overlay_img, polygon_vertices, stethoscope_x, stethoscope_y
    )

    if stethoscope_x is None or stethoscope_y is None:
        stethoscope_x = 0
        stethoscope_y = 0

    return stethoscope_overlay_img, stethoscope_x, stethoscope_y


def normalize_quadrilateral_with_point(points, extra_point):
    """Normalize four torso points plus one extra point (e.g. the stethoscope).

    Steps:
      1. Translate so the mean of the four torso points is at the origin.
      2. Rotate by minus the mean direction of the shoulder vector
         (left->right shoulder) and the hip vector (left->right hip).
      3. Divide by the longest distance between consecutive points in the order
         LS->RS->LH->RH->LS. Step 3 is skipped when that length is 0.

    The mean direction is a *circular* mean. For a person facing the camera
    both vectors point roughly along -x, i.e. their angles sit near +pi / -pi.
    A plain arithmetic mean of e.g. (pi - e) and (-pi + e) gives ~0 instead of
    ~pi, flipping the output by 180 degrees because of noise. When the two
    angles are less than pi apart, the circular mean equals the arithmetic one.

    NOTE(review): models trained on the previous arithmetic-mean output may
    see different inputs for frames that straddled +-pi. Confirm the training
    pipeline normalizes the same way.

    Args:
        points: 8 values (or a 4x2 array) in the order left_shoulder,
            right_shoulder, left_hip, right_hip, each as (x, y).
        extra_point: (x, y) of the additional point.

    Returns:
        Float64 array of shape (5, 2): the four torso points, then *extra_point*.
    """
    quad = np.asarray(points, dtype=np.float64).reshape(-1, 2)
    extra = np.asarray(extra_point, dtype=np.float64).reshape(-1, 2)
    centered_points = np.vstack([quad, extra]) - quad.mean(axis=0)

    shoulder_vec = centered_points[1] - centered_points[0]
    hip_vec = centered_points[3] - centered_points[2]
    shoulder_angle = np.arctan2(shoulder_vec[1], shoulder_vec[0])
    hip_angle = np.arctan2(hip_vec[1], hip_vec[0])
    # Circular mean: direction of the sum of the two unit vectors.
    average_angle = np.arctan2(
        np.sin(shoulder_angle) + np.sin(hip_angle),
        np.cos(shoulder_angle) + np.cos(hip_angle),
    )

    cos_a = np.cos(-average_angle)
    sin_a = np.sin(-average_angle)
    rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    rotated_points = centered_points @ rotation_matrix.T

    edges = np.roll(rotated_points[:4], -1, axis=0) - rotated_points[:4]
    max_edge_length = np.max(np.linalg.norm(edges, axis=1))
    if max_edge_length == 0:
        return rotated_points  # degenerate input: avoid division by zero
    return rotated_points / max_edge_length


def calculate_rotation_angle(point1, point2):
    """Return the direction, in radians via ``np.arctan2``, of the vector point1 -> point2."""
    dx = point2[0] - point1[0]
    dy = point2[1] - point1[1]
    return np.arctan2(dy, dx)


def video_to_frames(video_path, output_dir):
    """Decode *video_path* and save every frame as ``<n>-frame.png`` in *output_dir*.

    Frames are numbered from 1 in decode order. This naming matches the
    first-integer sort used by ``process_images``. *output_dir* is created if
    needed, and the capture is always released, even on error.

    Returns:
        Number of frames written.

    Raises:
        IOError: If the video cannot be opened or a frame cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise IOError(f"Could not open video file: {video_path}")

        frame_num = 0
        while True:
            success, frame = video.read()
            if not success:
                break
            frame_num += 1
            frame_path = os.path.join(output_dir, f"{frame_num}-frame.png")
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(frame_path, frame):
                raise IOError(f"Could not write frame: {frame_path}")
    finally:
        video.release()

    print(f"All frames saved to {output_dir}")
    return frame_num


def extract_keypoints_rtmpose(pose_results):
    """Return the keypoints of the most visible person in RTMPose results.

    Every instance in every result's ``pred_instances`` is scored by the mean
    of its ``keypoints_visible``. Only scores strictly above 0 count. The
    highest score wins, and on ties the first instance encountered is kept.

    Returns:
        ``best_instance.keypoints[0]``, or ``None`` (after printing a message)
        when *pose_results* is empty or no instance scores above 0.
    """
    if not pose_results:
        print("No pose results found.")
        return None

    scored = [
        (np.mean(instance.keypoints_visible), instance)
        for sample in pose_results
        for instance in sample.pred_instances
    ]
    candidates = [pair for pair in scored if pair[0] > 0]
    if not candidates:
        print("No valid instances found.")
        return None

    # max() returns the first maximal element, preserving first-wins on ties.
    _, winner = max(candidates, key=lambda pair: pair[0])
    return winner.keypoints[0]


def _frame_sort_key(file_name):
    """Sort key for frame files: by the first integer in the name; unnumbered names sort last."""
    match = re.search(r"(\d+)", file_name)
    if match is None:
        return (1, 0, file_name)
    return (0, int(match.group(1)), file_name)


def _predict_xy(model_x, model_y, scaler_x, scaler_y, input_row, input_columns):
    """Return integer (x, y) predictions from per-axis scaler + regressor pairs.

    ``input_row`` is a dict containing at least ``input_columns``. The same
    feature frame is scaled separately for each axis, then predicted.
    """
    features = pd.DataFrame([input_row])[input_columns]
    x_pred = int(model_x.predict(scaler_x.transform(features))[0])
    y_pred = int(model_y.predict(scaler_y.transform(features))[0])
    return x_pred, y_pred


def _write_fps_summary(timings, results_dir):
    """Write per-method timing statistics to ``fps_results.csv`` and print them.

    ``timings`` maps a method name to a list of per-call durations in seconds.
    Methods with no recorded calls are skipped.
    """
    fps_data = []
    for method_name, time_list in timings.items():
        if not time_list:
            continue
        total_time = sum(time_list)
        num_calls = len(time_list)
        avg_time = total_time / num_calls
        fps = 1.0 / avg_time if avg_time > 0 else 0
        fps_data.append(
            {
                "method_name": method_name,
                "num_calls": num_calls,
                "total_time_sec": f"{total_time:.6f}",
                "avg_time_sec": f"{avg_time:.6f}",
                "fps": f"{fps:.2f}",
            }
        )

    fps_csv_path = os.path.join(results_dir, "fps_results.csv")
    with open(fps_csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "method_name",
                "num_calls",
                "total_time_sec",
                "avg_time_sec",
                "fps",
            ],
        )
        writer.writeheader()
        writer.writerows(fps_data)

    print("\n===== FPS Results (subcomponent & pipeline) =====")
    for rowf in fps_data:
        print(
            f"{rowf['method_name']}: calls={rowf['num_calls']}, "
            f"total={rowf['total_time_sec']}s, avg={rowf['avg_time_sec']}s, FPS={rowf['fps']}"
        )


def process_images(args, detector, pose_estimator, visualizer):
    """Run per-frame inference over the extracted PNG frames and save the results.

    Frames are read from ``<args.output_dir>/frames/*.png``, ordered by the
    first integer in the file name. Depending on the module-level ``*_ENABLED``
    flags, each frame goes through:

    * pose estimation (PoseNet or RTMPose) for shoulder/hip coordinates,
    * stethoscope detection (MobileNetV1-SSD and/or YOLOX),
    * the position estimators (conv/affine, XGBoost, LightGBM, EARSNet).

    Under ``<args.output_dir>/results`` it writes:

    * ``results.csv`` / ``results-convert.csv``: raw and normalized coordinates,
      plus the estimator outputs,
    * ``pose_overlay_image/`` and ``stethoscope_overlay_image/``: overlays,
    * ``fps_results.csv``: per-component and per-pipeline timing summary,

    and then calls ``generate_visualizations``. The module-level
    ``processed_frames`` counter is incremented after each handled frame so the
    ``fps_monitor`` thread can report real-time FPS.

    Args:
        args: Namespace with at least ``output_dir``.
        detector: Model passed to mmdet ``inference_detector`` to find people.
            Only used when RTMPose is enabled.
        pose_estimator: Model passed to mmpose ``inference_topdown``. Only used
            when RTMPose is enabled.
        visualizer: mmpose visualizer, or ``None`` to save the plain frame as
            the pose overlay.
    """
    global processed_frames  # shared with the fps_monitor thread

    print("Starting process_images function...")
    ears_ai = EarsAI()
    calc_position = CalcStethoscopePosition()
    base_dir = os.path.join(args.output_dir, "frames")
    results_dir = os.path.join(args.output_dir, "results")
    csv_path = os.path.join(results_dir, "results.csv")
    normalized_csv_path = os.path.join(results_dir, "results-convert.csv")
    pose_overlay_dir = os.path.join(results_dir, "pose_overlay_image")
    stethoscope_overlay_dir = os.path.join(results_dir, "stethoscope_overlay_image")

    for directory in (results_dir, pose_overlay_dir, stethoscope_overlay_dir):
        os.makedirs(directory, exist_ok=True)

    png_files = sorted(
        (f for f in os.listdir(base_dir) if f.lower().endswith(".png")),
        key=_frame_sort_key,
    )
    print(f"Found {len(png_files)} PNG files.")

    rows = []
    normalized_rows = []

    # NOTE: init_yolox() returns None on failure; YOLOX calls below would then fail.
    yolox_inferencer = init_yolox() if YOLOX_ENABLED else None

    # Per-call durations in seconds, keyed by component / pipeline name.
    timings = {
        # single components
        "rtmpose_single": [],
        "posenet_single": [],
        "ssd_single": [],
        "yolox_single": [],
        "conv_single": [],
        "lightgbm_single": [],
        "xgboost_single": [],
        "earsnet_single": [],
        # pipelines (RTMPose + YOLOX detection -> each estimator)
        "pipeline_rtmpose_yolox_conv": [],
        "pipeline_rtmpose_yolox_lightgbm": [],
        "pipeline_rtmpose_yolox_xgboost": [],
        "pipeline_earsnet": [],
    }

    # Pre-load the regression models once, outside the frame loop.
    if LIGHTGBM_ENABLED:
        lgb_model_x = load_model("./models/LightGBM/stethoscope_calc_x_best_model.pkl")
        lgb_model_y = load_model("./models/LightGBM/stethoscope_calc_y_best_model.pkl")
        lgb_scaler_x = load_scaler("./models/LightGBM/scaler-x.pkl")
        lgb_scaler_y = load_scaler("./models/LightGBM/scaler-y.pkl")
    if XGBOOST_ENABLED:
        xg_model_x = load_model("./models/XGBoost/stethoscope_calc_x_best_model.pkl")
        xg_model_y = load_model("./models/XGBoost/stethoscope_calc_y_best_model.pkl")
        xg_scaler_x = load_scaler("./models/XGBoost/scaler-x.pkl")
        xg_scaler_y = load_scaler("./models/XGBoost/scaler-y.pkl")
    # NOTE(review): the CatBoost / NGBoost models are loaded but not yet used
    # in the per-frame loop below.
    if CATBOOST_ENABLED:
        catboost_model_x = load_model(
            "./models/CatBoost/stethoscope_calc_x_best_model.pkl"
        )
        catboost_model_y = load_model(
            "./models/CatBoost/stethoscope_calc_y_best_model.pkl"
        )
    if NGBOOST_ENABLED:
        ngboost_model_x = load_model(
            "./models/NGBoost/stethoscope_calc_x_best_model.pkl"
        )
        ngboost_model_y = load_model(
            "./models/NGBoost/stethoscope_calc_y_best_model.pkl"
        )

    if EARSNET_ENABLED:
        earsnet_predictor = EARSNetPredictor(
            weight_path="models/EARSNet/best_model.pth",
            resnet_depth="18",  # must match the ResNet depth used in training
            pretrained=True,  # must match the training configuration
            device="cuda",  # or "cpu"
        )

    input_columns = [
        "left_shoulder_x",
        "left_shoulder_y",
        "right_shoulder_x",
        "right_shoulder_y",
        "left_hip_x",
        "left_hip_y",
        "right_hip_x",
        "right_hip_y",
        "stethoscope_x",
        "stethoscope_y",
    ]
    body_parts = ("left_shoulder", "right_shoulder", "left_hip", "right_hip")

    # ------------------------------------------------------------------
    # Main loop: one iteration per frame
    # ------------------------------------------------------------------
    for image_file_name in png_files:
        image_path = os.path.join(base_dir, image_file_name)
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Failed to load image: {image_path}")
            continue

        # Start of the detection stage (pose + stethoscope) for pipeline timings.
        pipeline_detection_start = time.perf_counter()

        # (1) Pose estimation (PoseNet or RTMPose): shoulder / hip coordinates.
        left_shoulder = right_shoulder = left_hip = right_hip = (0, 0)
        pose_keypoints = None
        pose_overlay_img = frame.copy()

        if POSENET_ENABLED:
            pose_start = time.perf_counter()
            pose_overlay_img, *landmarks = ears_ai.pose_detect(frame, None)
            timings["posenet_single"].append(time.perf_counter() - pose_start)
            left_shoulder, right_shoulder, left_hip, right_hip = landmarks[:4]

        elif RTMPOSE_ENABLED:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            pose_start = time.perf_counter()
            det_result = inference_detector(detector, frame_rgb)
            pred_instance = det_result.pred_instances.cpu().numpy()
            bboxes = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
            )
            # Keep confident detections with label 0 (assumed to be "person").
            bboxes = bboxes[
                np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)
            ]
            bboxes = bboxes[nms(bboxes, 0.3), :4]
            pose_results = inference_topdown(pose_estimator, frame_rgb, bboxes)
            data_samples = merge_data_samples(pose_results)
            pose_keypoints = extract_keypoints_rtmpose(pose_results)
            timings["rtmpose_single"].append(time.perf_counter() - pose_start)

            if pose_keypoints is None:
                print(f"Failed to extract keypoints for image: {image_path}")
                processed_frames += 1
                continue

            left_shoulder = pose_keypoints[5]
            right_shoulder = pose_keypoints[6]
            left_hip = pose_keypoints[11]
            right_hip = pose_keypoints[12]

            if visualizer is not None:
                visualizer.add_datasample(
                    "result",
                    frame_rgb,
                    data_sample=data_samples,
                    draw_gt=False,
                    draw_heatmap=False,
                    draw_bbox=False,
                    show_kpt_idx=False,
                    skeleton_style="mmpose",
                    show=False,
                    wait_time=0,
                    kpt_thr=0.3,
                )
                pose_overlay_img = visualizer.get_image()  # RGB
            else:
                # No visualizer: save the plain frame. Keep it RGB because the
                # RTMPose overlay is converted RGB->BGR when written below.
                pose_overlay_img = frame_rgb

        # (2) Stethoscope detection (SSD and/or YOLOX).
        stethoscope_overlay_img = frame.copy()
        stethoscope_x = 0
        stethoscope_y = 0

        if MobileNetV1SSD_ENABLED:
            ssd_start = time.perf_counter()
            stethoscope_overlay_img, stethoscope_x, stethoscope_y = ears_ai.ssd_detect(
                frame, None
            )
            timings["ssd_single"].append(time.perf_counter() - ssd_start)

        if YOLOX_ENABLED:
            if RTMPOSE_ENABLED and pose_keypoints is not None:
                yolox_keypoints = pose_keypoints
            elif POSENET_ENABLED:
                # yolox_detector_inference reads entries 0, 5, 6, 11 and 12.
                # The nose (0) is left at (0, 0) here.
                # NOTE(review): the CSV code below treats PoseNet landmarks as
                # (y, x), but they are passed here as (x, y) - confirm.
                yolox_keypoints = [[0, 0] for _ in range(13)]
                yolox_keypoints[5] = left_shoulder
                yolox_keypoints[6] = right_shoulder
                yolox_keypoints[11] = left_hip
                yolox_keypoints[12] = right_hip
            else:
                yolox_keypoints = None

            if yolox_keypoints is not None:
                yolox_start = time.perf_counter()
                (
                    stethoscope_overlay_img,
                    stethoscope_x,
                    stethoscope_y,
                ) = yolox_detector_inference(frame, yolox_inferencer, yolox_keypoints)
                timings["yolox_single"].append(time.perf_counter() - yolox_start)

        # End of the detection stage.
        detection_time = time.perf_counter() - pipeline_detection_start

        # Save overlays.
        if (RTMPOSE_ENABLED or POSENET_ENABLED) and (
            YOLOX_ENABLED or MobileNetV1SSD_ENABLED
        ):
            if RTMPOSE_ENABLED:
                cv2.imwrite(
                    os.path.join(pose_overlay_dir, image_file_name),
                    cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR),
                )
            else:
                cv2.imwrite(
                    os.path.join(pose_overlay_dir, image_file_name),
                    pose_overlay_img,
                )

            cv2.imwrite(
                os.path.join(stethoscope_overlay_dir, image_file_name),
                stethoscope_overlay_img,
            )

        # (3) Collect shoulder / hip / stethoscope coordinates for the CSV.
        # PoseNet landmarks are indexed as (y, x) here (index 1 -> *_x);
        # RTMPose keypoints are (x, y).
        x_idx, y_idx = (1, 0) if POSENET_ENABLED else (0, 1)
        row = {"image_file_name": image_file_name}
        for part, point in zip(
            body_parts, (left_shoulder, right_shoulder, left_hip, right_hip)
        ):
            row[f"{part}_x"] = point[x_idx]
            row[f"{part}_y"] = point[y_idx]
        row["stethoscope_x"] = stethoscope_x
        row["stethoscope_y"] = stethoscope_y

        if EARSNET_ENABLED:
            earsnet_start = time.perf_counter()
            earsnet_x, earsnet_y = earsnet_predictor.predict(image_path)
            earsnet_elapsed = time.perf_counter() - earsnet_start
            timings["earsnet_single"].append(earsnet_elapsed)
            # EARSNet is end-to-end, so its pipeline time equals its own time.
            timings["pipeline_earsnet"].append(earsnet_elapsed)

            row["earsnet_stethoscope_x"] = earsnet_x
            row["earsnet_stethoscope_y"] = earsnet_y

        rows.append(row)

        # (4) Normalize the four torso points plus the stethoscope.
        source_points = np.array(
            [[float(row[f"{part}_x"]), float(row[f"{part}_y"])] for part in body_parts],
            dtype=np.float32,
        )
        stethoscope_point = np.array(
            [float(row["stethoscope_x"]), float(row["stethoscope_y"])]
        )
        normalized_points = normalize_quadrilateral_with_point(
            source_points.flatten(), stethoscope_point
        )

        normalized_row = {"image_file_name": image_file_name}
        for idx, name in enumerate((*body_parts, "stethoscope")):
            normalized_row[f"{name}_x"] = normalized_points[idx, 0]
            normalized_row[f"{name}_y"] = normalized_points[idx, 1]
        if EARSNET_ENABLED:
            normalized_row["earsnet_stethoscope_x"] = row["earsnet_stethoscope_x"]
            normalized_row["earsnet_stethoscope_y"] = row["earsnet_stethoscope_y"]

        normalized_rows.append(normalized_row)

        # (5) Position estimators fed by the RTMPose + YOLOX detections.
        # Their predictions are stored in `row` so the declared CSV columns
        # are actually filled.
        if RTMPOSE_ENABLED and YOLOX_ENABLED:
            model_input = normalized_row if NORMALIZE_ENABLED else row

            if CONV_ENABLED:
                conv_start = time.perf_counter()
                conv_stethoscope = calc_position.calc_affine(
                    source_points, *stethoscope_point
                )
                conv_elapsed = time.perf_counter() - conv_start
                timings["conv_single"].append(conv_elapsed)
                timings["pipeline_rtmpose_yolox_conv"].append(
                    detection_time + conv_elapsed
                )
                # NOTE(review): assumes calc_affine returns an indexable (x, y) pair.
                row["conv_stethoscope_x"] = conv_stethoscope[0]
                row["conv_stethoscope_y"] = conv_stethoscope[1]

            if XGBOOST_ENABLED:
                xg_start = time.perf_counter()
                xg_x_pred, xg_y_pred = _predict_xy(
                    xg_model_x, xg_model_y, xg_scaler_x, xg_scaler_y,
                    model_input, input_columns,
                )
                xg_elapsed = time.perf_counter() - xg_start
                timings["xgboost_single"].append(xg_elapsed)
                timings["pipeline_rtmpose_yolox_xgboost"].append(
                    detection_time + xg_elapsed
                )
                row["Xgboost_stethoscope_x"] = xg_x_pred
                row["Xgboost_stethoscope_y"] = xg_y_pred

            if LIGHTGBM_ENABLED:
                lgb_start = time.perf_counter()
                lgb_x_pred, lgb_y_pred = _predict_xy(
                    lgb_model_x, lgb_model_y, lgb_scaler_x, lgb_scaler_y,
                    model_input, input_columns,
                )
                lgb_elapsed = time.perf_counter() - lgb_start
                timings["lightgbm_single"].append(lgb_elapsed)
                timings["pipeline_rtmpose_yolox_lightgbm"].append(
                    detection_time + lgb_elapsed
                )
                row["lightGBM_stethoscope_x"] = lgb_x_pred
                row["lightGBM_stethoscope_y"] = lgb_y_pred

        processed_frames += 1

    # ------------------------------------------------------------------
    # CSV output
    # ------------------------------------------------------------------
    if rows:
        print(f"Writing {len(rows)} rows to CSV...")
        fieldnames = list(rows[0].keys())
        # Declare every enabled estimator's columns. Skip those already
        # present in the rows so DictWriter doesn't emit duplicate columns.
        for enabled, prefix in (
            (CONV_ENABLED, "conv"),
            (XGBOOST_ENABLED, "Xgboost"),
            (LIGHTGBM_ENABLED, "lightGBM"),
            (CATBOOST_ENABLED, "catboost"),
            (NGBOOST_ENABLED, "ngboost"),
        ):
            if not enabled:
                continue
            for column in (f"{prefix}_stethoscope_x", f"{prefix}_stethoscope_y"):
                if column not in fieldnames:
                    fieldnames.append(column)

        with (
            open(csv_path, "w", newline="", encoding="utf-8") as csvfile,
            open(normalized_csv_path, "w", newline="", encoding="utf-8") as norm_csvfile,
        ):
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            norm_writer = csv.DictWriter(norm_csvfile, fieldnames=fieldnames)
            writer.writeheader()
            norm_writer.writeheader()
            # Missing estimator columns are written as empty cells (DictWriter restval).
            writer.writerows(rows)
            norm_writer.writerows(normalized_rows)

        print(f"Processed and saved results to: {csv_path}")
        print(f"Processed and saved normalized results to: {normalized_csv_path}")

        generate_visualizations(csv_path, base_dir, results_dir)
    else:
        print("No data to write to CSV.")

    # (6) Timing summary per component and pipeline.
    _write_fps_summary(timings, results_dir)


def _draw_marker(image, center, color, radius=10):
    """Draw a filled circle of ``radius`` pixels centred at ``center`` on ``image`` (in place)."""
    cv2.circle(image, center, radius, color, -1)


def _draw_trajectory(image, points, color, thickness=2):
    """Draw an open polyline through ``points`` on ``image`` (in place).

    Nothing is drawn when fewer than two points have been collected.
    """
    if len(points) > 1:
        # cv2.polylines expects 32-bit integer vertices; make the dtype explicit
        # instead of relying on the bindings to convert NumPy's default int.
        cv2.polylines(
            image, [np.array(points, dtype=np.int32)], False, color, thickness
        )


def generate_visualizations(
    csv_path,
    original_images_dir,
    results_dir,
    body_image_path="./images/body/BodyF.png",
):
    """
    Render per-frame images and videos from an inference-results CSV.

    For every CSV row whose ``image_file_name`` exists in
    ``original_images_dir`` and is readable:

    * the landmarks ``left_shoulder``, ``right_shoulder``, ``left_hip``,
      ``right_hip`` and ``stethoscope`` (columns ``<name>_x`` / ``<name>_y``)
      are drawn on the source frame and saved to ``results_dir/marked_images``;
    * for every enabled model ``<key>``, the point in columns
      ``<key>_stethoscope_x`` / ``<key>_stethoscope_y`` is drawn on a copy of
      the body image, once with the trajectory accumulated so far and once
      without, into ``<dir>_with_trajectory`` / ``<dir>_without_trajectory``;
    * all model points are also overlaid into ``combined_with_trajectory`` /
      ``combined_without_trajectory``.

    Coordinates that are NaN (e.g. no stethoscope prediction for a frame) are
    skipped instead of drawn, so NaN values never break the drawing calls.
    Finally, each output directory is turned into an mp4 with
    ``create_video_from_images``.

    Args:
        csv_path: Path of the results CSV.
        original_images_dir: Directory containing the source frames.
        results_dir: Root directory for all generated images and videos.
        body_image_path: Background image used for the stethoscope plots.

    Raises:
        FileNotFoundError: If ``body_image_path`` cannot be read by OpenCV.
    """
    df = pd.read_csv(csv_path)
    body_image = cv2.imread(body_image_path)
    if body_image is None:
        # cv2.imread returns None instead of raising; fail with a clear message
        # rather than an AttributeError on the first ``.copy()`` below.
        raise FileNotFoundError(f"Could not read body image: {body_image_path}")

    # Output sub-directory name per key: "marked" holds the annotated source
    # frames, "combined" overlays every enabled model on one image.
    dirs = {"marked": "marked_images"}
    if CONV_ENABLED:
        dirs["conv"] = "conv"
    if XGBOOST_ENABLED:
        dirs["Xgboost"] = "Xgboost"
    if LIGHTGBM_ENABLED:
        dirs["lightGBM"] = "lightGBM"
    if CATBOOST_ENABLED:
        dirs["catboost"] = "catboost"
    if NGBOOST_ENABLED:
        dirs["ngboost"] = "ngboost"
    if EARSNET_ENABLED:
        dirs["earsnet"] = "earsnet"
    dirs["combined"] = "combined"

    # Create every output directory once, up front (including the combined ones).
    os.makedirs(os.path.join(results_dir, "marked_images"), exist_ok=True)
    for key, dir_name in dirs.items():
        if key == "marked":
            continue
        for suffix in ("_with_trajectory", "_without_trajectory"):
            os.makedirs(os.path.join(results_dir, dir_name + suffix), exist_ok=True)

    # Only real models have ``<key>_stethoscope_x/_y`` prediction columns.
    model_keys = [key for key in dirs if key not in ("marked", "combined")]
    points = {key: [] for key in model_keys}  # trajectory accumulated per model
    colors = {
        "conv": CONV_COLOR,
        "Xgboost": XGBOOST_COLOR,
        "lightGBM": LIGHTGBM_COLOR,
        "catboost": CATBOOST_COLOR,
        "ngboost": NGBOOST_COLOR,
        "earsnet": EARSNET_COLOR,
    }
    default_color = (0, 0, 255)
    landmarks = (
        "left_shoulder",
        "right_shoulder",
        "left_hip",
        "right_hip",
        "stethoscope",
    )

    for _, row in df.iterrows():
        file_name = row["image_file_name"]
        original_image_path = os.path.join(original_images_dir, file_name)
        if not os.path.exists(original_image_path):
            continue

        original_image = cv2.imread(original_image_path)
        if original_image is None:
            continue

        # Mark shoulders, hips and the stethoscope on the source frame.
        for point in landmarks:
            col_x, col_y = f"{point}_x", f"{point}_y"
            if col_x not in row or col_y not in row:
                continue
            val_x, val_y = row[col_x], row[col_y]
            if pd.isna(val_x) or pd.isna(val_y):
                continue
            _draw_marker(original_image, (int(val_x), int(val_y)), (255, 255, 0))

        cv2.imwrite(
            os.path.join(results_dir, "marked_images", file_name),
            original_image,
        )

        # Plot each model's stethoscope estimate on the body image.
        combined_image_with_traj = body_image.copy()
        combined_image_without_traj = body_image.copy()

        for key in model_keys:
            col_x, col_y = f"{key}_stethoscope_x", f"{key}_stethoscope_y"
            if col_x not in row or col_y not in row:
                continue
            val_x, val_y = row[col_x], row[col_y]
            # Skip frames where this model produced no (NaN) prediction.
            if pd.isna(val_x) or pd.isna(val_y):
                continue

            x, y = int(val_x), int(val_y)
            points[key].append((x, y))
            color = colors.get(key, default_color)

            # 1) Per-model image with trajectory
            image_with_trajectory = body_image.copy()
            _draw_trajectory(image_with_trajectory, points[key], color)
            _draw_marker(image_with_trajectory, (x, y), color)
            cv2.imwrite(
                os.path.join(
                    results_dir, f"{dirs[key]}_with_trajectory", file_name
                ),
                image_with_trajectory,
            )

            # 2) Per-model image without trajectory
            image_without_trajectory = body_image.copy()
            _draw_marker(image_without_trajectory, (x, y), color)
            cv2.imwrite(
                os.path.join(
                    results_dir, f"{dirs[key]}_without_trajectory", file_name
                ),
                image_without_trajectory,
            )

            # 3) Combined image with trajectory, 4) combined without trajectory
            _draw_trajectory(combined_image_with_traj, points[key], color)
            _draw_marker(combined_image_with_traj, (x, y), color)
            _draw_marker(combined_image_without_traj, (x, y), color)

        cv2.imwrite(
            os.path.join(results_dir, "combined_with_trajectory", file_name),
            combined_image_with_traj,
        )
        cv2.imwrite(
            os.path.join(results_dir, "combined_without_trajectory", file_name),
            combined_image_without_traj,
        )

    # Turn every image directory into a video.
    create_video_from_images(
        os.path.join(results_dir, "marked_images"),
        os.path.join(results_dir, "marked_video.mp4"),
    )

    for key, dir_name in dirs.items():
        if key == "marked":
            continue
        create_video_from_images(
            os.path.join(results_dir, f"{dir_name}_with_trajectory"),
            os.path.join(results_dir, f"{key}_video_with_trajectory.mp4"),
        )
        create_video_from_images(
            os.path.join(results_dir, f"{dir_name}_without_trajectory"),
            os.path.join(results_dir, f"{key}_video_without_trajectory.mp4"),
        )


_FRAME_NUMBER_RE = re.compile(r"\d+")


def _frame_sort_key(file_name):
    """Return a sort key that orders frame files by the first integer in their name.

    Names containing no digits sort after all numbered names (alphabetically
    among themselves) instead of crashing the sort. The file name is used as
    a tiebreaker so the order is deterministic.
    """
    match = _FRAME_NUMBER_RE.search(file_name)
    if match is None:
        return (1, 0, file_name)
    return (0, int(match.group()), file_name)


def create_video_from_images(image_dir, output_path, fps=30):
    """
    Combine the PNG images in a directory into one mp4v-encoded video.

    Images are ordered by the first integer in their file name (see
    ``_frame_sort_key``). The video frame size is taken from the first image;
    later images of a different size are resized to match, and unreadable
    images are skipped with a message.

    Args:
        image_dir: Directory containing the ``.png`` frames. If it is not a
            directory, nothing happens.
        output_path: Path of the video file to write.
        fps: Frame rate written into the video (default 30).
    """
    if not os.path.isdir(image_dir):
        return

    images = sorted(
        (name for name in os.listdir(image_dir) if name.endswith(".png")),
        key=_frame_sort_key,
    )
    if not images:
        print(f"No images found in {image_dir}")
        return

    first_path = os.path.join(image_dir, images[0])
    first_frame = cv2.imread(first_path)
    if first_frame is None:
        print(f"Could not read image: {first_path}")
        return
    height, width = first_frame.shape[:2]

    video = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    try:
        for name in images:
            img_path = os.path.join(image_dir, name)
            img = cv2.imread(img_path)
            if img is None:
                print(f"Skipping unreadable image: {img_path}")
                continue
            if img.shape[:2] != (height, width):
                # VideoWriter backends may silently drop frames whose size differs
                # from the size the writer was opened with; resize to keep them.
                img = cv2.resize(img, (width, height))
            video.write(img)
    finally:
        # Always finalize the container, even if reading/writing a frame fails.
        video.release()
    print(f"Created video: {output_path}")


def main():
    """
    Command-line entry point: split a video into frames, run inference, save results.

    Options:
        --video_path  Input video file (default: ./video/Media1.mp4).
        --output_dir  Directory for extracted frames and results (default: output).
        --device      Device passed to the RTMPose detector and pose estimator
                      when RTMPOSE_ENABLED (default: cuda:0).

    A daemon thread running ``fps_monitor`` with a 1.0 s interval is started
    before processing. The module-level ``stop_fps_thread`` flag is set and the
    thread is joined in a ``finally`` block, so the monitor is shut down even
    if frame extraction or processing raises.
    """
    # Assigned below to start/stop the FPS monitor thread.
    global stop_fps_thread

    parser = argparse.ArgumentParser(description="Process video and generate results.")
    parser.add_argument(
        "--video_path",
        default="./video/Media1.mp4",
        help="Path to the input video file",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        help="Directory to save output images and results",
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="Device for the RTMPose detector and pose estimator (e.g. cuda:0, cpu)",
    )
    det_config = "modules/rtmpose/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
    det_checkpoint = (
        "models/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
    )
    pose_config = (
        "modules/rtmpose/configs/body_2d_keypoint/rtmpose/body8/"
        "rtmpose-l_8xb256-420e_body8-256x192.py"
    )
    pose_checkpoint = "models/rtmpose/rtmpose-l_simcc-aic-coco_pt-aic-coco_420e-256x192-f016ffe0_20230126.pth"

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # -------------------------
    # 1) Start the FPS monitor thread. The flag is reset first so that a later
    #    call to main() in the same process does not start with a stale "stop"
    #    request (assumes fps_monitor polls stop_fps_thread — see its definition).
    # -------------------------
    stop_fps_thread = False
    fps_thread = Thread(target=fps_monitor, args=(1.0,), daemon=True)
    fps_thread.start()

    try:
        # -------------------------
        # 2) Split the video into frames
        # -------------------------
        frames_dir = os.path.join(args.output_dir, "frames")
        video_to_frames(args.video_path, frames_dir)

        # -------------------------
        # 3) Initialise RTMPose (if enabled) and process the frames
        # -------------------------
        if RTMPOSE_ENABLED:
            detector = init_detector(det_config, det_checkpoint, device=args.device)
            detector.cfg = adapt_mmdet_pipeline(detector.cfg)
            pose_estimator = init_pose_estimator(
                pose_config, pose_checkpoint, device=args.device
            )
            visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
            visualizer.set_dataset_meta(
                pose_estimator.dataset_meta, skeleton_style="mmpose"
            )

            process_images(args, detector, pose_estimator, visualizer)
        else:
            process_images(args, None, None, None)
    finally:
        # -------------------------
        # 4) Ask the monitor thread to stop and wait for it, even on failure
        # -------------------------
        stop_fps_thread = True
        fps_thread.join()

    print("All done.")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()