# Demo-Maker / main.py
import argparse
import csv
import os
import pickle
import re

import cv2
import numpy as np
import pandas as pd
from mmdet.apis import DetInferencer, inference_detector, init_detector

# New imports for RTMPose
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.utils import adapt_mmdet_pipeline

import config
from util.calc_ste_position import CalcStethoscopePosition
from util.ears_ai import EarsAI

# Get colors from config
# Drawing colors for each prediction model's visualization output
# (used by generate_visualizations when plotting predicted positions).
CONV_COLOR = config.CONV_COLOR
XGBOOST_COLOR = config.XGBOOST_COLOR
LIGHTGBM_COLOR = config.LIGHTGBM_COLOR

# Get model execution settings
# Feature flags selecting which pose-estimation / detection / regression
# models run in the pipeline.
CONV_ENABLED = config.CONV_ENABLED
XGBOOST_ENABLED = config.XGBOOST_ENABLED
LIGHTGBM_ENABLED = config.LIGHTGBM_ENABLED
POSENET_ENABLED = config.POSENET_ENABLED
RTMPOSE_ENABLED = config.RTMPOSE_ENABLED
MobileNetV1SSD_ENABLED = config.MOBILENETV1SSD_ENABLED
YOLOX_ENABLED = config.YOLOX_ENABLED

# Get normalization setting
# When True, the regression models receive normalized (rotation/scale
# invariant) coordinates instead of raw pixel coordinates.
NORMALIZE_ENABLED = config.NORMALIZE_ENABLED


def init_yolox():
    """Build the YOLOX DetInferencer from the paths in config.

    Returns:
        A ready DetInferencer, or None if initialization fails.
    """
    try:
        # Make "mmdet" the default registry scope so the config resolves.
        from mmengine.registry import DefaultScope

        DefaultScope.get_instance("mmdet", scope_name="mmdet")

        return DetInferencer(
            model=config.YOLOX_CONFIG_FILE,
            weights=config.YOLOX_CHECKPOINT_FILE,
            device=config.DEVICE,
        )
    except Exception as e:
        print(f"Error initializing YOLOX: {str(e)}")
        return None


def draw_polygon_and_detection(image, polygon_vertices, stethoscope_x, stethoscope_y):
    """Draw the pentagonal body region and, if present, the detected
    stethoscope position onto a copy of *image*.

    Returns:
        The annotated copy; the input image is left untouched.
    """
    canvas = image.copy()

    # Pentagon outline in green.
    pts = polygon_vertices.astype(np.int32)
    cv2.polylines(canvas, [pts], True, (0, 255, 0), 2)

    # Detected stethoscope: filled blue dot with a white rim.
    if stethoscope_x is not None and stethoscope_y is not None:
        center = (int(stethoscope_x), int(stethoscope_y))
        cv2.circle(canvas, center, 10, (255, 0, 0), -1)
        cv2.circle(canvas, center, 12, (255, 255, 255), 2)

    return canvas


def expand_points(p1, p2):
    """Double the distance between two points while keeping their midpoint fixed.

    Returns:
        Tuple of two np.ndarray points, each pushed twice as far from the
        shared midpoint as the corresponding input point.
    """
    a = np.asarray(p1, dtype=float)
    b = np.asarray(p2, dtype=float)
    mid = (a + b) / 2
    # Doubling each point's offset from the midpoint doubles the separation.
    offset = (a - mid) * 2
    return mid + offset, mid - offset


def point_in_polygon(point, vertices):
    """Ray-casting test: True if *point* lies inside the polygon *vertices*."""
    px, py = point
    count = len(vertices)
    inside = False

    prev = count - 1
    for cur in range(count):
        cy = vertices[cur][1]
        prev_y = vertices[prev][1]
        # The edge crosses the horizontal ray through py only when its
        # endpoints straddle that height (this also guarantees the divisor
        # below is non-zero).
        if (cy > py) != (prev_y > py):
            x_cross = (
                (vertices[prev][0] - vertices[cur][0]) * (py - cy) / (prev_y - cy)
                + vertices[cur][0]
            )
            if px < x_cross:
                inside = not inside
        prev = cur

    return inside


def yolox_detector_inference(frame, yolox_inferencer, pose_keypoints, score_thr=0.3):
    """Detect the stethoscope with YOLOX, keeping only hits inside the
    pentagonal body region derived from the pose keypoints.

    Args:
        frame: Input image (BGR).
        yolox_inferencer: YOLOX DetInferencer instance.
        pose_keypoints: Keypoints detected by RTMPose (COCO index order).
        score_thr: Minimum detection score.

    Returns:
        tuple: (overlay image, detected x, detected y); (0, 0) is reported
        when no valid detection exists.
    """
    # The inferencer expects RGB input.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = yolox_inferencer(inputs=rgb, return_vis=True)
    predictions = result["predictions"][0]

    # Build the pentagon from nose, shoulders and hips (COCO indices
    # 0, 5/6, 11/12), widening the shoulder and hip pairs so the region
    # covers the torso generously.
    nose = pose_keypoints[0]
    wide_l_shoulder, wide_r_shoulder = expand_points(
        pose_keypoints[5], pose_keypoints[6]
    )
    wide_l_hip, wide_r_hip = expand_points(pose_keypoints[11], pose_keypoints[12])
    polygon_vertices = np.array(
        [nose, wide_l_shoulder, wide_l_hip, wide_r_hip, wide_r_shoulder]
    )

    best_x = None
    best_y = None
    best_score = -1

    # Keep the highest-scoring detection whose bbox center is inside the
    # pentagon.  Label 0 denotes the stethoscope class.
    for idx, (label, score) in enumerate(
        zip(predictions["labels"], predictions["scores"])
    ):
        if label != 0 or score < score_thr:
            continue
        bbox = predictions["bboxes"][idx]
        cx = (bbox[0] + bbox[2]) / 2
        cy = (bbox[1] + bbox[3]) / 2
        if point_in_polygon([cx, cy], polygon_vertices) and score > best_score:
            best_x = cx
            best_y = cy
            best_score = score

    # No valid detection: report the sentinel (0, 0).
    if best_x is None or best_y is None:
        best_x = 0
        best_y = 0

    # The visualization image comes back in RGB; convert to BGR for OpenCV.
    overlay = result["visualization"][0]
    if len(overlay.shape) == 3 and overlay.shape[2] == 3:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)

    overlay = draw_polygon_and_detection(overlay, polygon_vertices, best_x, best_y)

    return overlay, best_x, best_y


def load_model(model_path, model_type="lgb"):
    """Load a pickled regression model from *model_path*.

    Args:
        model_path: Path to the pickle file.
        model_type: Unused; kept only for backward compatibility with
            existing call sites.

    Returns:
        The unpickled model object.

    NOTE(security): pickle.load can execute arbitrary code from the file —
    only load model files from trusted sources.
    """
    with open(model_path, "rb") as model_file:
        return pickle.load(model_file)


def normalize_quadrilateral_with_point(points, extra_point):
    """Map the quadrilateral (plus one extra point) into a canonical frame.

    The five points are centered on the quadrilateral's centroid, rotated so
    the average of the point0->point1 and point2->point3 edge angles becomes
    horizontal, and scaled so the longest quadrilateral edge has length 1.

    Args:
        points: Flattened array of 4 (x, y) quadrilateral vertices.
        extra_point: Additional (x, y) point transformed alongside them.

    Returns:
        (5, 2) array of normalized coordinates (quad first, extra last).
    """
    quad = points.reshape(-1, 2)
    stacked = np.vstack([quad, extra_point])
    centroid = quad.mean(axis=0)
    centered = stacked - centroid

    # Average the two edge orientations (angle of the vector between each
    # vertex pair, via arctan2).
    d01 = centered[1] - centered[0]
    d23 = centered[3] - centered[2]
    angle = (np.arctan2(d01[1], d01[0]) + np.arctan2(d23[1], d23[0])) / 2

    # Rotate by -angle to level the averaged orientation.
    cos_a = np.cos(-angle)
    sin_a = np.sin(-angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    aligned = centered @ rotation.T

    # Scale so the longest edge of the rotated quadrilateral is 1.
    edges = np.roll(aligned[:4], -1, axis=0) - aligned[:4]
    longest = np.max(np.linalg.norm(edges, axis=1))
    return aligned / longest


def calculate_rotation_angle(point1, point2):
    """Return the angle (radians) of the vector from *point1* to *point2*."""
    dx = point2[0] - point1[0]
    dy = point2[1] - point1[1]
    return np.arctan2(dy, dx)


def video_to_frames(video_path, output_dir):
    """Dump every frame of *video_path* into *output_dir* as "<n>-frame.png".

    Frame numbering is 1-based.

    Raises:
        IOError: If the video file cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    index = 0
    while True:
        ok, frame = capture.read()
        if not ok:
            break
        index += 1
        cv2.imwrite(os.path.join(output_dir, f"{index}-frame.png"), frame)

    capture.release()
    print(f"All frames saved to {output_dir}")


def extract_keypoints_rtmpose(pose_results):
    """Pick the detected instance with the highest mean keypoint visibility.

    Args:
        pose_results: List of pose data samples, each exposing pred_instances.

    Returns:
        The chosen instance's keypoints array, or None when the input is
        empty or no instance beats zero mean visibility.
    """
    if not pose_results:
        print("No pose results found.")
        return None

    best = None
    best_visibility = 0
    for sample in pose_results:
        for candidate in sample.pred_instances:
            visibility = np.mean(candidate.keypoints_visible)
            if visibility > best_visibility:
                best_visibility = visibility
                best = candidate

    if best is None:
        print("No valid instances found.")
        return None

    return best.keypoints[0]


def process_images(args, detector, pose_estimator, visualizer):
    """Run pose and stethoscope detection on every frame, then write CSVs.

    Pipeline per PNG under <output_dir>/frames:
      1. Estimate body keypoints with PoseNet and/or RTMPose (config flags).
      2. Detect the stethoscope with MobileNetV1-SSD and/or YOLOX.
      3. Save pose / stethoscope overlay images.
      4. Record raw and normalized keypoint rows.
    The rows are then written to results.csv / results-convert.csv,
    optionally augmented with conv / XGBoost / LightGBM position estimates,
    and per-frame visualizations plus videos are generated.

    Args:
        args: Parsed CLI namespace; args.output_dir is used.
        detector: mmdet person detector (required when RTMPOSE_ENABLED).
        pose_estimator: RTMPose model (required when RTMPOSE_ENABLED).
        visualizer: mmpose visualizer (required when RTMPOSE_ENABLED).
    """
    print("Starting process_images function...")
    ears_ai = EarsAI()
    calc_position = CalcStethoscopePosition()
    base_dir = os.path.join(args.output_dir, "frames")
    results_dir = os.path.join(args.output_dir, "results")
    csv_path = os.path.join(results_dir, "results.csv")
    normalized_csv_path = os.path.join(results_dir, "results-convert.csv")
    pose_overlay_dir = os.path.join(results_dir, "pose_overlay_image")
    stethoscope_overlay_dir = os.path.join(results_dir, "stethoscope_overlay_image")

    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(pose_overlay_dir, exist_ok=True)
    os.makedirs(stethoscope_overlay_dir, exist_ok=True)

    # Frames are named "<n>-frame.png"; sort numerically, not lexically.
    png_files = sorted(
        [f for f in os.listdir(base_dir) if f.lower().endswith(".png")],
        key=lambda x: int(re.search(r"(\d+)", x).group(1)),
    )
    print(f"Found {len(png_files)} PNG files.")

    # FIX: build the YOLOX inferencer once.  The original called
    # init_yolox() inside the frame loop, re-creating the model (and
    # re-loading its weights) for every single image.
    yolox_inferencer = init_yolox() if YOLOX_ENABLED else None

    rows = []
    normalized_rows = []
    for image_file_name in png_files:
        print(f"Processing image: {image_file_name}")
        image_path = os.path.join(base_dir, image_file_name)
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Failed to load image: {image_path}")
            continue

        if POSENET_ENABLED:
            # pose_detect returns the overlay followed by the four landmarks.
            pose_overlay_img, *landmarks = ears_ai.pose_detect(frame, None)
            left_shoulder = landmarks[0]
            right_shoulder = landmarks[1]
            left_hip = landmarks[2]
            right_hip = landmarks[3]

        if RTMPOSE_ENABLED:
            # Person detection -> score/label filter -> NMS -> top-down pose.
            det_result = inference_detector(detector, frame)
            pred_instance = det_result.pred_instances.cpu().numpy()
            bboxes = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
            )
            # Keep person boxes (label 0) with score above 0.3.
            bboxes = bboxes[
                np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)
            ]
            bboxes = bboxes[nms(bboxes, 0.3), :4]
            pose_results = inference_topdown(pose_estimator, frame, bboxes)
            data_samples = merge_data_samples(pose_results)
            pose_keypoints = extract_keypoints_rtmpose(pose_results)
            if pose_keypoints is None:
                print(f"Failed to extract keypoints for image: {image_path}")
                continue

            # NOTE: left/right are deliberately swapped relative to the COCO
            # indices (5/6 = shoulders, 11/12 = hips) — mirrored-view
            # correction carried over from the original code.
            left_shoulder = pose_keypoints[6]
            right_shoulder = pose_keypoints[5]
            left_hip = pose_keypoints[12]
            right_hip = pose_keypoints[11]

            if visualizer is not None:
                visualizer.add_datasample(
                    "result",
                    frame,
                    data_sample=data_samples,
                    draw_gt=False,
                    draw_heatmap=False,
                    draw_bbox=False,
                    show_kpt_idx=False,
                    skeleton_style="mmpose",
                    show=False,
                    wait_time=0,
                    kpt_thr=0.3,
                )
            pose_overlay_img = visualizer.get_image()

        if MobileNetV1SSD_ENABLED:
            stethoscope_overlay_img, stethoscope_x, stethoscope_y = ears_ai.ssd_detect(
                frame, None
            )

        # NOTE(review): pose_keypoints is only bound when RTMPOSE_ENABLED —
        # enabling YOLOX without RTMPose raises NameError here (pre-existing).
        if YOLOX_ENABLED and pose_keypoints is not None:
            stethoscope_overlay_img, stethoscope_x, stethoscope_y = (
                yolox_detector_inference(
                    frame,
                    yolox_inferencer,
                    pose_keypoints,
                )
            )

        cv2.imwrite(
            os.path.join(pose_overlay_dir, image_file_name),
            cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR),
        )
        cv2.imwrite(
            os.path.join(stethoscope_overlay_dir, image_file_name),
            cv2.cvtColor(stethoscope_overlay_img, cv2.COLOR_RGB2BGR),
        )

        if POSENET_ENABLED:
            # PoseNet landmarks are indexed [row, col]: index 1 feeds the
            # *_x column and index 0 the *_y column.
            row = {
                "image_file_name": image_file_name,
                "left_shoulder_x": left_shoulder[1],
                "left_shoulder_y": left_shoulder[0],
                "right_shoulder_x": right_shoulder[1],
                "right_shoulder_y": right_shoulder[0],
                "left_hip_x": left_hip[1],
                "left_hip_y": left_hip[0],
                "right_hip_x": right_hip[1],
                "right_hip_y": right_hip[0],
                "stethoscope_x": stethoscope_x,
                "stethoscope_y": stethoscope_y,
            }
        elif RTMPOSE_ENABLED:
            # RTMPose keypoints are already (x, y).
            row = {
                "image_file_name": image_file_name,
                "left_shoulder_x": left_shoulder[0],
                "left_shoulder_y": left_shoulder[1],
                "right_shoulder_x": right_shoulder[0],
                "right_shoulder_y": right_shoulder[1],
                "left_hip_x": left_hip[0],
                "left_hip_y": left_hip[1],
                "right_hip_x": right_hip[0],
                "right_hip_y": right_hip[1],
                "stethoscope_x": stethoscope_x,
                "stethoscope_y": stethoscope_y,
            }
        else:
            print(
                "No pose estimation method enabled. Please enable either PoseNet or RTMPose."
            )
            continue

        rows.append(row)

        # Normalize the shoulder/hip quadrilateral plus the stethoscope point
        # into a rotation- and scale-invariant frame for the regressors.
        source_points = np.array(
            [
                [float(row[f"{pos}_x"]), float(row[f"{pos}_y"])]
                for pos in ["left_shoulder", "right_shoulder", "left_hip", "right_hip"]
            ],
            dtype=np.float32,
        )
        stethoscope_point = np.array(
            [float(row["stethoscope_x"]), float(row["stethoscope_y"])]
        )
        normalized_points = normalize_quadrilateral_with_point(
            source_points.flatten(), stethoscope_point
        )

        normalized_row = {
            "image_file_name": image_file_name,
            "left_shoulder_x": normalized_points[0, 0],
            "left_shoulder_y": normalized_points[0, 1],
            "right_shoulder_x": normalized_points[1, 0],
            "right_shoulder_y": normalized_points[1, 1],
            "left_hip_x": normalized_points[2, 0],
            "left_hip_y": normalized_points[2, 1],
            "right_hip_x": normalized_points[3, 0],
            "right_hip_y": normalized_points[3, 1],
            "stethoscope_x": normalized_points[4, 0],
            "stethoscope_y": normalized_points[4, 1],
        }
        normalized_rows.append(normalized_row)

    if rows:
        print(f"Writing {len(rows)} rows to CSV...")
        fieldnames = list(rows[0].keys())
        if CONV_ENABLED:
            fieldnames.extend(["conv_stethoscope_x", "conv_stethoscope_y"])
        if XGBOOST_ENABLED:
            fieldnames.extend(["Xgboost_stethoscope_x", "Xgboost_stethoscope_y"])
        if LIGHTGBM_ENABLED:
            fieldnames.extend(["lightGBM_stethoscope_x", "lightGBM_stethoscope_y"])

        # Load the per-axis regression models once, up front.
        if LIGHTGBM_ENABLED:
            lgb_model_x = load_model(
                "./models/lgb_stethoscope_calc_x_best_model-Fold4.pkl"
            )
            lgb_model_y = load_model(
                "./models/lgb_stethoscope_calc_y_best_model-Fold4.pkl"
            )
        if XGBOOST_ENABLED:
            xg_model_x = load_model(
                "./models/xg_stethoscope_calc_x_best_model-Fold4.pkl"
            )
            xg_model_y = load_model(
                "./models/xg_stethoscope_calc_y_best_model-Fold4.pkl"
            )

        with open(csv_path, "w", newline="") as csvfile, open(
            normalized_csv_path, "w", newline=""
        ) as norm_csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            norm_writer = csv.DictWriter(norm_csvfile, fieldnames=fieldnames)
            writer.writeheader()
            norm_writer.writeheader()

            # Last prediction per model, carried over frames where no
            # stethoscope was detected; (180, 180) is the initial fallback.
            prev_values = {}
            if CONV_ENABLED:
                prev_values["conv"] = (180, 180)
            if LIGHTGBM_ENABLED:
                prev_values["lightGBM"] = (180, 180)
            if XGBOOST_ENABLED:
                prev_values["Xgboost"] = (180, 180)

            for row, norm_row in zip(rows, normalized_rows):
                source_points = np.array(
                    [
                        [float(row[f"{pos}_x"]), float(row[f"{pos}_y"])]
                        for pos in [
                            "left_shoulder",
                            "right_shoulder",
                            "left_hip",
                            "right_hip",
                        ]
                    ],
                    dtype=np.float32,
                )
                stethoscope_point = np.array(
                    [float(row["stethoscope_x"]), float(row["stethoscope_y"])]
                )

                if stethoscope_point[0] == 0 and stethoscope_point[1] == 0:
                    # (0, 0) is the "no detection" sentinel: reuse each
                    # model's previous prediction for both output rows.
                    for key in prev_values:
                        row[f"{key}_stethoscope_x"], row[f"{key}_stethoscope_y"] = (
                            prev_values[key]
                        )
                        (
                            norm_row[f"{key}_stethoscope_x"],
                            norm_row[f"{key}_stethoscope_y"],
                        ) = prev_values[key]
                else:
                    if CONV_ENABLED:
                        conv_stethoscope = calc_position.calc_affine(
                            source_points, *stethoscope_point
                        )
                        row["conv_stethoscope_x"], row["conv_stethoscope_y"] = (
                            conv_stethoscope
                        )
                        (
                            norm_row["conv_stethoscope_x"],
                            norm_row["conv_stethoscope_y"],
                        ) = conv_stethoscope

                    # Feed normalized or raw coordinates to the regressors,
                    # depending on the NORMALIZE_ENABLED config flag.
                    if NORMALIZE_ENABLED:
                        input_data = pd.DataFrame([norm_row])
                    else:
                        input_data = pd.DataFrame([row])

                    input_columns = [
                        f"{pos}_{coord}"
                        for pos in [
                            "left_shoulder",
                            "right_shoulder",
                            "left_hip",
                            "right_hip",
                            "stethoscope",
                        ]
                        for coord in ["x", "y"]
                    ]

                    if LIGHTGBM_ENABLED:
                        lgb_x = int(lgb_model_x.predict(input_data[input_columns])[0])
                        lgb_y = int(lgb_model_y.predict(input_data[input_columns])[0])
                        row["lightGBM_stethoscope_x"], row["lightGBM_stethoscope_y"] = (
                            lgb_x,
                            lgb_y,
                        )
                        (
                            norm_row["lightGBM_stethoscope_x"],
                            norm_row["lightGBM_stethoscope_y"],
                        ) = lgb_x, lgb_y

                    if XGBOOST_ENABLED:
                        xg_x = int(xg_model_x.predict(input_data[input_columns])[0])
                        xg_y = int(xg_model_y.predict(input_data[input_columns])[0])
                        row["Xgboost_stethoscope_x"], row["Xgboost_stethoscope_y"] = (
                            xg_x,
                            xg_y,
                        )
                        (
                            norm_row["Xgboost_stethoscope_x"],
                            norm_row["Xgboost_stethoscope_y"],
                        ) = xg_x, xg_y

                    # Remember this frame's predictions for carry-over.
                    for key in prev_values:
                        prev_values[key] = (
                            row[f"{key}_stethoscope_x"],
                            row[f"{key}_stethoscope_y"],
                        )

                writer.writerow(row)
                norm_writer.writerow(norm_row)

        print(f"Processed and saved results to: {csv_path}")
        print(f"Processed and saved normalized results to: {normalized_csv_path}")

        generate_visualizations(csv_path, base_dir, results_dir)
    else:
        print("No data to write to CSV.")


def generate_visualizations(csv_path, original_images_dir, results_dir):
    """Render per-frame visualization images and summary videos from the CSV.

    For each row of *csv_path* this writes:
      - a "marked" copy of the original frame with the four body keypoints
        and the detected stethoscope circled, and
      - for each enabled model (conv / Xgboost / lightGBM), the model's
        predicted position drawn on a body illustration, once with the
        accumulated trajectory polyline and once without.
    Finally each image directory is assembled into an mp4 video.
    """
    df = pd.read_csv(csv_path)
    # Static body illustration used as the canvas for model predictions.
    body_image = cv2.imread("./images/body/BodyF.png")

    # Map model key -> output sub-directory base name.
    dirs = {"marked": "marked_images"}
    if CONV_ENABLED:
        dirs["conv"] = "conv"
    if XGBOOST_ENABLED:
        dirs["Xgboost"] = "Xgboost"
    if LIGHTGBM_ENABLED:
        dirs["lightGBM"] = "lightGBM"

    os.makedirs(os.path.join(results_dir, "marked_images"), exist_ok=True)
    for key in dirs:
        if key != "marked":
            os.makedirs(
                os.path.join(results_dir, f"{dirs[key]}_with_trajectory"), exist_ok=True
            )
            os.makedirs(
                os.path.join(results_dir, f"{dirs[key]}_without_trajectory"),
                exist_ok=True,
            )

    # Accumulated predicted positions per model, for the trajectory polyline.
    points = {key: [] for key in dirs.keys() if key != "marked"}
    colors = {"conv": CONV_COLOR, "Xgboost": XGBOOST_COLOR, "lightGBM": LIGHTGBM_COLOR}

    for _, row in df.iterrows():
        original_image = cv2.imread(
            os.path.join(original_images_dir, row["image_file_name"])
        )
        if original_image is None:
            print(
                f"Failed to load image: {os.path.join(original_images_dir, row['image_file_name'])}"
            )
            continue

        # Circle every recorded keypoint plus the raw stethoscope detection.
        for point in [
            "left_shoulder",
            "right_shoulder",
            "left_hip",
            "right_hip",
            "stethoscope",
        ]:
            cv2.circle(
                original_image,
                (int(row[f"{point}_x"]), int(row[f"{point}_y"])),
                10,
                (255, 255, 0),
                -1,
            )

        cv2.imwrite(
            os.path.join(results_dir, "marked_images", row["image_file_name"]),
            original_image,
        )

        # Draw each model's prediction on the body illustration, with and
        # without the trajectory so far.
        for key in points:
            x, y = int(row[f"{key}_stethoscope_x"]), int(row[f"{key}_stethoscope_y"])
            points[key].append((x, y))

            image_with_trajectory = body_image.copy()
            if len(points[key]) > 1:
                cv2.polylines(
                    image_with_trajectory,
                    [np.array(points[key])],
                    False,
                    colors[key],
                    2,
                )
            cv2.circle(image_with_trajectory, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(
                    results_dir, f"{dirs[key]}_with_trajectory", row["image_file_name"]
                ),
                image_with_trajectory,
            )

            image_without_trajectory = body_image.copy()
            cv2.circle(image_without_trajectory, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(
                    results_dir,
                    f"{dirs[key]}_without_trajectory",
                    row["image_file_name"],
                ),
                image_without_trajectory,
            )

    # Assemble the saved image sequences into videos.
    create_video_from_images(
        os.path.join(results_dir, "marked_images"),
        os.path.join(results_dir, "marked_video.mp4"),
    )

    for key in dirs:
        if key != "marked":
            create_video_from_images(
                os.path.join(results_dir, f"{dirs[key]}_with_trajectory"),
                os.path.join(results_dir, f"{key}_video_with_trajectory.mp4"),
            )
            create_video_from_images(
                os.path.join(results_dir, f"{dirs[key]}_without_trajectory"),
                os.path.join(results_dir, f"{key}_video_without_trajectory.mp4"),
            )


def create_video_from_images(image_dir, output_path):
    """Assemble the numbered PNGs in *image_dir* into a 30 fps mp4.

    Frames are ordered by the first integer in each file name.  Unreadable
    images are skipped with a warning instead of crashing the writer.

    Args:
        image_dir: Directory containing the .png frames.
        output_path: Destination .mp4 path.
    """
    images = sorted(
        [img for img in os.listdir(image_dir) if img.endswith(".png")],
        key=lambda x: int(re.search(r"(\d+)", x).group()),
    )

    if not images:
        print(f"No images found in {image_dir}")
        return

    # FIX: the original dereferenced .shape without checking for a failed
    # imread, crashing on a corrupt first frame.
    frame = cv2.imread(os.path.join(image_dir, images[0]))
    if frame is None:
        print(f"Failed to read first image in {image_dir}; no video created")
        return
    height, width, _ = frame.shape

    video = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    try:
        for image in images:
            img = cv2.imread(os.path.join(image_dir, image))
            if img is None:
                # FIX: writing None would raise inside VideoWriter.
                print(f"Skipping unreadable image: {image}")
                continue
            video.write(img)
    finally:
        # Always finalize the container, even if a write fails mid-way.
        video.release()
    print(f"Created video: {output_path}")


def main():
    """CLI entry point: split the input video into frames, then run the
    detection/estimation pipeline over them."""
    # FIX: the original constructed the ArgumentParser twice in a row,
    # discarding the first instance.
    parser = argparse.ArgumentParser(description="Process video and generate results.")
    parser.add_argument(
        "--video_path",
        default="./video/Test3-1.mp4",
        help="Path to the input video file",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        help="Directory to save output images and results",
    )

    # RTMDet person detector + RTMPose body model paths (used only when
    # RTMPOSE_ENABLED).
    det_config = "modules/rtmpose/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
    det_checkpoint = (
        "models/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
    )
    pose_config = "modules/rtmpose/configs/body_2d_keypoint/rtmpose/body8/rtmpose-l_8xb256-420e_body8-256x192.py"
    pose_checkpoint = "models/rtmpose/rtmpose-l_simcc-aic-coco_pt-aic-coco_420e-256x192-f016ffe0_20230126.pth"

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    frames_dir = os.path.join(args.output_dir, "frames")
    video_to_frames(args.video_path, frames_dir)

    if RTMPOSE_ENABLED:
        detector = init_detector(det_config, det_checkpoint, device="cuda:0")
        detector.cfg = adapt_mmdet_pipeline(detector.cfg)
        pose_estimator = init_pose_estimator(
            pose_config, pose_checkpoint, device="cuda:0"
        )
        visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
        visualizer.set_dataset_meta(
            pose_estimator.dataset_meta, skeleton_style="mmpose"
        )

        process_images(args, detector, pose_estimator, visualizer)
    else:
        process_images(args, None, None, None)


# Run the CLI pipeline when executed as a script (not on import).
if __name__ == "__main__":
    main()