# Demo-Maker / main.py
import argparse
import csv
import os
import pickle
import re

import cv2
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from mmdet.apis import DetInferencer, inference_detector, init_detector

# New imports for RTMPose
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.utils import adapt_mmdet_pipeline

import config
from util.calc_ste_position import CalcStethoscopePosition
from util.ears_ai import EarsAI

# Load environment variables from a local .env file (if present).
load_dotenv()

# Get colors from environment variables ("B,G,R" or "R,G,B" comma strings).
# NOTE(review): these tuples are passed straight to OpenCV drawing calls,
# which interpret colors as BGR; the Green/Red/Blue labels below read as
# RGB — confirm the intended channel order against the rendered output.
CONV_COLOR = tuple(
    map(int, os.getenv("CONV_COLOR", "0,255,0").split(","))
)  # Default: Green
XGBOOST_COLOR = tuple(
    map(int, os.getenv("XGBOOST_COLOR", "255,0,0").split(","))
)  # Default: Red
LIGHTGBM_COLOR = tuple(
    map(int, os.getenv("LIGHTGBM_COLOR", "0,0,255").split(","))
)  # Default: Blue

# Get model execution settings.
# NOTE: env var names are case-sensitive and intentionally mixed-case
# ("PoseNet_ENABLED", "RTMPose_ENABLED", "MobileNetV1SSD_ENABLED").
CONV_ENABLED = os.getenv("CONV_ENABLED", "True").lower() == "true"
XGBOOST_ENABLED = os.getenv("XGBOOST_ENABLED", "True").lower() == "true"
LIGHTGBM_ENABLED = os.getenv("LIGHTGBM_ENABLED", "True").lower() == "true"
POSENET_ENABLED = os.getenv("PoseNet_ENABLED", "True").lower() == "true"
RTMPOSE_ENABLED = os.getenv("RTMPose_ENABLED", "False").lower() == "true"
MobileNetV1SSD_ENABLED = os.getenv("MobileNetV1SSD_ENABLED", "False").lower() == "true"
YOLOX_ENABLED = os.getenv("YOLOX_ENABLED", "False").lower() == "true"

# Get normalization setting (whether model inputs use normalized coordinates).
NORMALIZE_ENABLED = os.getenv("NORMALIZE_ENABLED", "False").lower() == "true"


def init_yolox():
    """Build a YOLOX ``DetInferencer`` from the paths in ``config``.

    Returns:
        The initialized inferencer, or ``None`` when initialization fails
        for any reason (missing weights, bad config, device error).
    """
    try:
        # Register "mmdet" as the default scope so MMDetection config
        # components resolve correctly.
        from mmengine.registry import DefaultScope

        DefaultScope.get_instance("mmdet", scope_name="mmdet")

        return DetInferencer(
            model=config.YOLOX_CONFIG_FILE,
            weights=config.YOLOX_CHECKPOINT_FILE,
            device=config.DEVICE,
        )
    except Exception as e:
        print(f"Error initializing YOLOX: {str(e)}")
        return None


def yolox_detector_inference(frame, yolox_inferencer, score_thr=0.3):
    """Run YOLOX on one frame and locate the stethoscope.

    Returns a tuple ``(overlay_image, x, y)`` where ``(x, y)`` is the
    center of the highest-scoring label-0 detection above ``score_thr``,
    or ``(0, 0)`` when nothing is detected. When *yolox_inferencer* is
    ``None``, returns the input frame unchanged with ``(0, 0)``.
    """
    # No inferencer available: nothing to detect.
    if yolox_inferencer is None:
        return frame, 0, 0

    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = yolox_inferencer(inputs=rgb, return_vis=True)
    preds = result["predictions"][0]

    # Track the center of the best-scoring label-0 box above threshold.
    best_x = None
    best_y = None
    best_score = -1
    for idx, (label, score) in enumerate(zip(preds["labels"], preds["scores"])):
        if label == 0 and score >= score_thr and score > best_score:
            bbox = preds["bboxes"][idx]
            best_x = (bbox[0] + bbox[2]) / 2
            best_y = (bbox[1] + bbox[3]) / 2
            best_score = score

    overlay = result["visualization"][0]
    # The visualization comes back RGB; convert to BGR for cv2.imwrite.
    if len(overlay.shape) == 3 and overlay.shape[2] == 3:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)

    # No detection: report the (0, 0) sentinel position.
    if best_x is None or best_y is None:
        best_x = 0
        best_y = 0

    return overlay, best_x, best_y


def load_model(model_path, model_type="lgb"):
    """Deserialize and return a pickled model from *model_path*.

    ``model_type`` is currently unused; it is kept for backward
    compatibility with existing callers.

    NOTE(security): ``pickle.load`` can execute arbitrary code from the
    file — only load model files from trusted sources.
    """
    with open(model_path, "rb") as model_file:
        return pickle.load(model_file)


def normalize_quadrilateral_with_point(points, extra_point):
    """Center, rotate, and scale four torso keypoints plus one extra point.

    The quadrilateral (shoulders + hips) is translated to its own centroid
    (the extra point is excluded from the mean), rotated by the negated
    average of the shoulder-line and hip-line angles, and divided by the
    length of its longest edge, yielding pose-invariant coordinates.

    Returns a (5, 2) array: four quadrilateral points then the extra point.
    """
    quad = points.reshape(-1, 2)
    stacked = np.vstack([quad, extra_point])
    # Translate so the quadrilateral's centroid sits at the origin.
    centered = stacked - np.mean(quad, axis=0)

    # Average the shoulder-edge and hip-edge orientations.
    angle = (
        calculate_rotation_angle(centered[0], centered[1])
        + calculate_rotation_angle(centered[2], centered[3])
    ) / 2

    cos_a = np.cos(-angle)
    sin_a = np.sin(-angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    rotated = centered @ rotation.T

    # Normalize by the longest side of the rotated quadrilateral.
    edges = np.roll(rotated[:4], -1, axis=0) - rotated[:4]
    longest_edge = np.max(np.linalg.norm(edges, axis=1))
    return rotated / longest_edge


def calculate_rotation_angle(point1, point2):
    """Return the angle in radians of the vector *point1* -> *point2*."""
    dx, dy = point2 - point1
    return np.arctan2(dy, dx)


def video_to_frames(video_path, output_dir):
    """Split *video_path* into PNG frames under *output_dir*.

    Frames are written as ``<n>-frame.png`` with 1-based numbering; the
    output directory is created if it does not exist.

    Raises:
        IOError: if the video file cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    count = 0
    while True:
        ok, image = capture.read()
        if not ok:
            break
        count += 1
        cv2.imwrite(os.path.join(output_dir, f"{count}-frame.png"), image)

    capture.release()
    print(f"All frames saved to {output_dir}")


def extract_keypoints_rtmpose(pose_results):
    """Return the keypoint array of the most-visible detected person.

    Scans every instance in every RTMPose result and selects the one whose
    mean ``keypoints_visible`` is highest, then returns its first keypoint
    set. Returns ``None`` when *pose_results* is empty or no instance has a
    mean visibility above zero.
    """
    if not pose_results:
        print("No pose results found.")
        return None

    best = None
    best_visibility = 0
    for sample in pose_results:
        for candidate in sample.pred_instances:
            visibility = np.mean(candidate.keypoints_visible)
            if visibility > best_visibility:
                best_visibility = visibility
                best = candidate

    if best is None:
        print("No valid instances found.")
        return None

    return best.keypoints[0]


def process_images(args, detector, pose_estimator, visualizer):
    """Run the full per-frame pipeline and write result CSVs and videos.

    Reads PNG frames from ``<output_dir>/frames``, estimates the four torso
    keypoints (PoseNet or RTMPose), detects the stethoscope position
    (MobileNetV1-SSD or YOLOX), writes overlay images, then emits
    ``results.csv`` (raw pixel coordinates) and ``results-convert.csv``
    (coordinates normalized via ``normalize_quadrilateral_with_point``),
    each augmented with Conv/XGBoost/LightGBM predicted positions, and
    finally renders visualization videos.

    Args:
        args: Parsed CLI namespace; only ``args.output_dir`` is used.
        detector: MMDetection person detector (RTMPose path) or ``None``.
        pose_estimator: RTMPose top-down estimator or ``None``.
        visualizer: MMPose visualizer or ``None``.
    """
    print("Starting process_images function...")
    ears_ai = EarsAI()
    calc_position = CalcStethoscopePosition()
    base_dir = os.path.join(args.output_dir, "frames")
    results_dir = os.path.join(args.output_dir, "results")
    csv_path = os.path.join(results_dir, "results.csv")
    normalized_csv_path = os.path.join(results_dir, "results-convert.csv")
    pose_overlay_dir = os.path.join(results_dir, "pose_overlay_image")
    stethoscope_overlay_dir = os.path.join(results_dir, "stethoscope_overlay_image")

    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(pose_overlay_dir, exist_ok=True)
    os.makedirs(stethoscope_overlay_dir, exist_ok=True)

    # FIX: the YOLOX inferencer was previously (re-)initialized inside the
    # per-frame loop, reloading the model for every single image — and even
    # when YOLOX was disabled. Initialize it once, only when needed.
    yolox_inferencer = init_yolox() if YOLOX_ENABLED else None

    # Sort frames numerically by the number embedded in the filename.
    png_files = sorted(
        [f for f in os.listdir(base_dir) if f.lower().endswith(".png")],
        key=lambda x: int(re.search(r"(\d+)", x).group(1)),
    )
    print(f"Found {len(png_files)} PNG files.")

    rows = []
    normalized_rows = []
    for image_file_name in png_files:
        print(f"Processing image: {image_file_name}")
        image_path = os.path.join(base_dir, image_file_name)
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Failed to load image: {image_path}")
            continue

        if POSENET_ENABLED:
            # PoseNet path: landmarks come back as (y, x) pairs — note the
            # swapped indexing when the CSV row is built below.
            pose_overlay_img, *landmarks = ears_ai.pose_detect(frame, None)
            left_shoulder = landmarks[0]
            right_shoulder = landmarks[1]
            left_hip = landmarks[2]
            right_hip = landmarks[3]

        elif RTMPOSE_ENABLED:
            # Person detection -> label/score filter -> NMS -> top-down pose.
            det_result = inference_detector(detector, frame)
            pred_instance = det_result.pred_instances.cpu().numpy()
            bboxes = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
            )
            bboxes = bboxes[
                np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)
            ]
            bboxes = bboxes[nms(bboxes, 0.3), :4]
            pose_results = inference_topdown(pose_estimator, frame, bboxes)
            data_samples = merge_data_samples(pose_results)
            keypoints = extract_keypoints_rtmpose(pose_results)
            if keypoints is None:
                print(f"Failed to extract keypoints for image: {image_path}")
                continue
            # COCO order is 5=L-shoulder, 6=R-shoulder, 11=L-hip, 12=R-hip.
            # The pairs are deliberately swapped here (original Japanese
            # note "テレコ確認" = the left/right swap was verified).
            left_shoulder = keypoints[6]
            right_shoulder = keypoints[5]
            left_hip = keypoints[12]
            right_hip = keypoints[11]
            if visualizer is not None:
                visualizer.add_datasample(
                    "result",
                    frame,
                    data_sample=data_samples,
                    draw_gt=False,
                    draw_heatmap=False,
                    draw_bbox=False,
                    show_kpt_idx=False,
                    skeleton_style="mmpose",
                    show=False,
                    wait_time=0,
                    kpt_thr=0.3,
                )
            pose_overlay_img = visualizer.get_image()
        else:
            print(
                "No pose estimation method enabled. Please enable either PoseNet or RTMPose."
            )
            continue

        # FIX: default values so the pipeline still produces a row (with the
        # (0, 0) sentinel position) instead of raising NameError when no
        # stethoscope detector is enabled.
        stethoscope_overlay_img = frame
        stethoscope_x = 0
        stethoscope_y = 0

        if MobileNetV1SSD_ENABLED:
            stethoscope_overlay_img, stethoscope_x, stethoscope_y = ears_ai.ssd_detect(
                frame, None
            )

        if YOLOX_ENABLED:
            stethoscope_overlay_img, stethoscope_x, stethoscope_y = (
                yolox_detector_inference(frame, yolox_inferencer)
            )

        cv2.imwrite(
            os.path.join(pose_overlay_dir, image_file_name),
            cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR),
        )
        cv2.imwrite(
            os.path.join(stethoscope_overlay_dir, image_file_name),
            cv2.cvtColor(stethoscope_overlay_img, cv2.COLOR_RGB2BGR),
        )

        if POSENET_ENABLED:
            # PoseNet landmarks are (y, x): index 1 is x, index 0 is y.
            row = {
                "image_file_name": image_file_name,
                "left_shoulder_x": left_shoulder[1],
                "left_shoulder_y": left_shoulder[0],
                "right_shoulder_x": right_shoulder[1],
                "right_shoulder_y": right_shoulder[0],
                "left_hip_x": left_hip[1],
                "left_hip_y": left_hip[0],
                "right_hip_x": right_hip[1],
                "right_hip_y": right_hip[0],
                "stethoscope_x": stethoscope_x,
                "stethoscope_y": stethoscope_y,
            }
        elif RTMPOSE_ENABLED:
            # RTMPose keypoints are (x, y).
            row = {
                "image_file_name": image_file_name,
                "left_shoulder_x": left_shoulder[0],
                "left_shoulder_y": left_shoulder[1],
                "right_shoulder_x": right_shoulder[0],
                "right_shoulder_y": right_shoulder[1],
                "left_hip_x": left_hip[0],
                "left_hip_y": left_hip[1],
                "right_hip_x": right_hip[0],
                "right_hip_y": right_hip[1],
                "stethoscope_x": stethoscope_x,
                "stethoscope_y": stethoscope_y,
            }
        else:
            print(
                "No pose estimation method enabled. Please enable either PoseNet or RTMPose."
            )
            continue

        rows.append(row)

        # Build the pose-invariant (normalized) version of the same row.
        source_points = np.array(
            [
                [float(row[f"{pos}_x"]), float(row[f"{pos}_y"])]
                for pos in ["left_shoulder", "right_shoulder", "left_hip", "right_hip"]
            ],
            dtype=np.float32,
        )
        stethoscope_point = np.array(
            [float(row["stethoscope_x"]), float(row["stethoscope_y"])]
        )
        normalized_points = normalize_quadrilateral_with_point(
            source_points.flatten(), stethoscope_point
        )

        normalized_row = {
            "image_file_name": image_file_name,
            "left_shoulder_x": normalized_points[0, 0],
            "left_shoulder_y": normalized_points[0, 1],
            "right_shoulder_x": normalized_points[1, 0],
            "right_shoulder_y": normalized_points[1, 1],
            "left_hip_x": normalized_points[2, 0],
            "left_hip_y": normalized_points[2, 1],
            "right_hip_x": normalized_points[3, 0],
            "right_hip_y": normalized_points[3, 1],
            "stethoscope_x": normalized_points[4, 0],
            "stethoscope_y": normalized_points[4, 1],
        }
        normalized_rows.append(normalized_row)

    if rows:
        print(f"Writing {len(rows)} rows to CSV...")
        fieldnames = list(rows[0].keys())
        if CONV_ENABLED:
            fieldnames.extend(["conv_stethoscope_x", "conv_stethoscope_y"])
        if XGBOOST_ENABLED:
            fieldnames.extend(["Xgboost_stethoscope_x", "Xgboost_stethoscope_y"])
        if LIGHTGBM_ENABLED:
            fieldnames.extend(["lightGBM_stethoscope_x", "lightGBM_stethoscope_y"])

        if LIGHTGBM_ENABLED:
            lgb_model_x = load_model(
                "./models/lgb_stethoscope_calc_x_best_model-Fold4.pkl"
            )
            lgb_model_y = load_model(
                "./models/lgb_stethoscope_calc_y_best_model-Fold4.pkl"
            )
        if XGBOOST_ENABLED:
            xg_model_x = load_model(
                "./models/xg_stethoscope_calc_x_best_model-Fold4.pkl"
            )
            xg_model_y = load_model(
                "./models/xg_stethoscope_calc_y_best_model-Fold4.pkl"
            )

        with open(csv_path, "w", newline="") as csvfile, open(
            normalized_csv_path, "w", newline=""
        ) as norm_csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            norm_writer = csv.DictWriter(norm_csvfile, fieldnames=fieldnames)
            writer.writeheader()
            norm_writer.writeheader()

            # Last known predicted positions; reused when the stethoscope
            # detection is the (0, 0) sentinel. (180, 180) is the initial
            # fallback position — presumably the chart center; TODO confirm.
            prev_values = {}
            if CONV_ENABLED:
                prev_values["conv"] = (180, 180)
            if LIGHTGBM_ENABLED:
                prev_values["lightGBM"] = (180, 180)
            if XGBOOST_ENABLED:
                prev_values["Xgboost"] = (180, 180)

            for i, (row, norm_row) in enumerate(zip(rows, normalized_rows)):
                source_points = np.array(
                    [
                        [float(row[f"{pos}_x"]), float(row[f"{pos}_y"])]
                        for pos in [
                            "left_shoulder",
                            "right_shoulder",
                            "left_hip",
                            "right_hip",
                        ]
                    ],
                    dtype=np.float32,
                )
                stethoscope_point = np.array(
                    [float(row["stethoscope_x"]), float(row["stethoscope_y"])]
                )

                if stethoscope_point[0] == 0 and stethoscope_point[1] == 0:
                    # No detection this frame: carry forward previous values.
                    for key in prev_values:
                        row[f"{key}_stethoscope_x"], row[f"{key}_stethoscope_y"] = (
                            prev_values[key]
                        )
                        (
                            norm_row[f"{key}_stethoscope_x"],
                            norm_row[f"{key}_stethoscope_y"],
                        ) = prev_values[key]
                else:
                    if CONV_ENABLED:
                        conv_stethoscope = calc_position.calc_affine(
                            source_points, *stethoscope_point
                        )
                        row["conv_stethoscope_x"], row["conv_stethoscope_y"] = (
                            conv_stethoscope
                        )
                        (
                            norm_row["conv_stethoscope_x"],
                            norm_row["conv_stethoscope_y"],
                        ) = conv_stethoscope

                    # Choose normalized or raw coordinates as model input.
                    if NORMALIZE_ENABLED:
                        input_data = pd.DataFrame([norm_row])
                    else:
                        input_data = pd.DataFrame([row])

                    input_columns = [
                        f"{pos}_{coord}"
                        for pos in [
                            "left_shoulder",
                            "right_shoulder",
                            "left_hip",
                            "right_hip",
                            "stethoscope",
                        ]
                        for coord in ["x", "y"]
                    ]

                    if LIGHTGBM_ENABLED:
                        lgb_x = int(lgb_model_x.predict(input_data[input_columns])[0])
                        lgb_y = int(lgb_model_y.predict(input_data[input_columns])[0])
                        row["lightGBM_stethoscope_x"], row["lightGBM_stethoscope_y"] = (
                            lgb_x,
                            lgb_y,
                        )
                        (
                            norm_row["lightGBM_stethoscope_x"],
                            norm_row["lightGBM_stethoscope_y"],
                        ) = lgb_x, lgb_y

                    if XGBOOST_ENABLED:
                        xg_x = int(xg_model_x.predict(input_data[input_columns])[0])
                        xg_y = int(xg_model_y.predict(input_data[input_columns])[0])
                        row["Xgboost_stethoscope_x"], row["Xgboost_stethoscope_y"] = (
                            xg_x,
                            xg_y,
                        )
                        (
                            norm_row["Xgboost_stethoscope_x"],
                            norm_row["Xgboost_stethoscope_y"],
                        ) = xg_x, xg_y

                    # Remember this frame's predictions for sentinel frames.
                    for key in prev_values:
                        prev_values[key] = (
                            row[f"{key}_stethoscope_x"],
                            row[f"{key}_stethoscope_y"],
                        )

                writer.writerow(row)
                norm_writer.writerow(norm_row)

        print(f"Processed and saved results to: {csv_path}")
        print(f"Processed and saved normalized results to: {normalized_csv_path}")

        generate_visualizations(csv_path, base_dir, results_dir)
    else:
        print("No data to write to CSV.")


def generate_visualizations(csv_path, original_images_dir, results_dir):
    """Draw keypoint markers and per-model trajectory images, then videos.

    For each row of the results CSV: writes a "marked" copy of the original
    frame with the five detected points circled, plus per-model
    (conv / Xgboost / lightGBM) body-chart images with and without the
    trajectory polyline accumulated so far, and finally assembles each
    image series into an mp4 via create_video_from_images().
    """
    df = pd.read_csv(csv_path)
    # Static body silhouette used as the canvas for trajectory images.
    # NOTE(review): cv2.imread returns None if the asset is missing, which
    # would make .copy() below fail — confirm the file exists at runtime.
    body_image = cv2.imread("./images/body/BodyF.png")

    # Map of output subdirectories, keyed by model; only enabled models
    # (per the *_ENABLED env flags) get trajectory directories.
    dirs = {"marked": "marked_images"}
    if CONV_ENABLED:
        dirs["conv"] = "conv"
    if XGBOOST_ENABLED:
        dirs["Xgboost"] = "Xgboost"
    if LIGHTGBM_ENABLED:
        dirs["lightGBM"] = "lightGBM"

    os.makedirs(os.path.join(results_dir, "marked_images"), exist_ok=True)
    for key in dirs:
        if key != "marked":
            os.makedirs(
                os.path.join(results_dir, f"{dirs[key]}_with_trajectory"), exist_ok=True
            )
            os.makedirs(
                os.path.join(results_dir, f"{dirs[key]}_without_trajectory"),
                exist_ok=True,
            )

    # Accumulated (x, y) trajectory per model, grown row by row.
    points = {key: [] for key in dirs.keys() if key != "marked"}
    colors = {"conv": CONV_COLOR, "Xgboost": XGBOOST_COLOR, "lightGBM": LIGHTGBM_COLOR}

    for _, row in df.iterrows():
        original_image = cv2.imread(
            os.path.join(original_images_dir, row["image_file_name"])
        )
        if original_image is None:
            print(
                f"Failed to load image: {os.path.join(original_images_dir, row['image_file_name'])}"
            )
            continue

        # Circle each detected keypoint plus the stethoscope on the frame.
        for point in [
            "left_shoulder",
            "right_shoulder",
            "left_hip",
            "right_hip",
            "stethoscope",
        ]:
            cv2.circle(
                original_image,
                (int(row[f"{point}_x"]), int(row[f"{point}_y"])),
                10,
                (255, 255, 0),
                -1,
            )

        cv2.imwrite(
            os.path.join(results_dir, "marked_images", row["image_file_name"]),
            original_image,
        )

        # For each enabled model, plot its predicted stethoscope position
        # on the body chart, with and without the trajectory so far.
        for key in points:
            x, y = int(row[f"{key}_stethoscope_x"]), int(row[f"{key}_stethoscope_y"])
            points[key].append((x, y))

            image_with_trajectory = body_image.copy()
            if len(points[key]) > 1:
                cv2.polylines(
                    image_with_trajectory,
                    [np.array(points[key])],
                    False,
                    colors[key],
                    2,
                )
            cv2.circle(image_with_trajectory, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(
                    results_dir, f"{dirs[key]}_with_trajectory", row["image_file_name"]
                ),
                image_with_trajectory,
            )

            image_without_trajectory = body_image.copy()
            cv2.circle(image_without_trajectory, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(
                    results_dir,
                    f"{dirs[key]}_without_trajectory",
                    row["image_file_name"],
                ),
                image_without_trajectory,
            )

    # Assemble every image series into a video.
    create_video_from_images(
        os.path.join(results_dir, "marked_images"),
        os.path.join(results_dir, "marked_video.mp4"),
    )

    for key in dirs:
        if key != "marked":
            create_video_from_images(
                os.path.join(results_dir, f"{dirs[key]}_with_trajectory"),
                os.path.join(results_dir, f"{key}_video_with_trajectory.mp4"),
            )
            create_video_from_images(
                os.path.join(results_dir, f"{dirs[key]}_without_trajectory"),
                os.path.join(results_dir, f"{key}_video_without_trajectory.mp4"),
            )


def create_video_from_images(image_dir, output_path, fps=30):
    """Assemble the numerically-sorted PNG frames in *image_dir* into an mp4.

    Args:
        image_dir: Directory whose ``*.png`` files embed their frame number
            in the filename (used for sort order).
        output_path: Destination ``.mp4`` path.
        fps: Output frame rate. Defaults to 30, the previously hard-coded
            value, so existing callers are unaffected.
    """
    images = sorted(
        [img for img in os.listdir(image_dir) if img.endswith(".png")],
        key=lambda x: int(re.search(r"(\d+)", x).group()),
    )

    if not images:
        print(f"No images found in {image_dir}")
        return

    # FIX: cv2.imread returns None on unreadable files; previously this
    # raised AttributeError on ``.shape`` below.
    frame = cv2.imread(os.path.join(image_dir, images[0]))
    if frame is None:
        print(f"Failed to load image: {os.path.join(image_dir, images[0])}")
        return
    height, width, _ = frame.shape

    video = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )

    for image in images:
        img = cv2.imread(os.path.join(image_dir, image))
        # Skip frames that fail to load rather than writing None.
        if img is not None:
            video.write(img)

    video.release()
    print(f"Created video: {output_path}")


def main():
    """CLI entry point: split the input video into frames, then run the
    detection/estimation pipeline over them.

    FIX: the argument parser was previously constructed twice in a row;
    the duplicate (discarded) instantiation has been removed. The RTMPose
    models now use ``config.DEVICE`` for consistency with init_yolox()
    instead of a hard-coded "cuda:0".
    """
    parser = argparse.ArgumentParser(description="Process video and generate results.")
    parser.add_argument(
        "--video_path",
        default="./video/Test3-1.mp4",
        help="Path to the input video file",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        help="Directory to save output images and results",
    )

    # RTMDet / RTMPose model files (used only when RTMPose_ENABLED=true).
    det_config = "modules/rtmpose/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
    det_checkpoint = (
        "models/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
    )
    pose_config = "modules/rtmpose/configs/body_2d_keypoint/rtmpose/body8/rtmpose-l_8xb256-420e_body8-256x192.py"
    pose_checkpoint = "models/rtmpose/rtmpose-l_simcc-aic-coco_pt-aic-coco_420e-256x192-f016ffe0_20230126.pth"

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    frames_dir = os.path.join(args.output_dir, "frames")
    video_to_frames(args.video_path, frames_dir)

    if RTMPOSE_ENABLED:
        detector = init_detector(det_config, det_checkpoint, device=config.DEVICE)
        detector.cfg = adapt_mmdet_pipeline(detector.cfg)
        pose_estimator = init_pose_estimator(
            pose_config, pose_checkpoint, device=config.DEVICE
        )
        visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
        visualizer.set_dataset_meta(
            pose_estimator.dataset_meta, skeleton_style="mmpose"
        )

        process_images(args, detector, pose_estimator, visualizer)
    else:
        process_images(args, None, None, None)


# Script entry point: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()