# Demo-Maker / main.py
import cv2
import os
import csv
import re
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import torch
from PIL import Image
import argparse
from torchvision import transforms
from util.ears_ai import EarsAI
from util.calc_ste_position import CalcStethoscopePosition
from modules.EARSForDL.model import RegressionResNet
from dotenv import load_dotenv

# New imports for RTMPose
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import adapt_mmdet_pipeline
from mmdet.apis import inference_detector, init_detector

# Load environment variables from a local .env file (if present) before
# reading any configuration below.
load_dotenv()

# Get colors from environment variables
# Each color is a comma-separated triple parsed into an (int, int, int) tuple.
CONV_COLOR = tuple(map(int, os.getenv("CONV_COLOR", "0,255,0").split(",")))  # Default: Green
XGBOOST_COLOR = tuple(map(int, os.getenv("XGBOOST_COLOR", "255,0,0").split(",")))  # Default: Red
LIGHTGBM_COLOR = tuple(map(int, os.getenv("LIGHTGBM_COLOR", "0,0,255").split(",")))  # Default: Blue

# Get model execution settings
# Flags toggle which predictors / pose estimators run in process_images().
CONV_ENABLED = os.getenv("CONV_ENABLED", "True").lower() == "true"
XGBOOST_ENABLED = os.getenv("XGBOOST_ENABLED", "True").lower() == "true"
LIGHTGBM_ENABLED = os.getenv("LIGHTGBM_ENABLED", "True").lower() == "true"
# NOTE(review): the two names below use mixed case ("PoseNet_ENABLED",
# "RTMPose_ENABLED") unlike the upper-case names above — confirm the
# deployment .env uses exactly these spellings.
POSENET_ENABLED = os.getenv("PoseNet_ENABLED", "True").lower() == "true"
RTMPOSE_ENABLED = os.getenv("RTMPose_ENABLED", "False").lower() == "true"

# Get normalization setting
# When true, the ML models receive the normalized quadrilateral coordinates
# instead of raw pixel coordinates.
NORMALIZE_ENABLED = os.getenv("NORMALIZE_ENABLED", "False").lower() == "true"


def load_model(model_path, model_type="lgb"):
    """Deserialize and return a pickled model stored at *model_path*.

    The *model_type* parameter is accepted for caller compatibility but is
    not consulted by the loading logic.
    """
    model_file = open(model_path, "rb")
    try:
        return pickle.load(model_file)
    finally:
        model_file.close()


def normalize_quadrilateral_with_point(points, extra_point):
    """Normalize a shoulder/hip quadrilateral plus one extra point.

    The quadrilateral (flattened as x0,y0,...,x3,y3 in *points*) and the
    *extra_point* are centered on the quadrilateral's centroid, rotated so
    the average shoulder/hip orientation is horizontal, and scaled by the
    longest quadrilateral edge.  Returns a (5, 2) array of normalized
    coordinates.
    """
    quad = points.reshape(-1, 2)
    stacked = np.vstack([quad, extra_point])
    centered = stacked - quad.mean(axis=0)

    # Average the shoulder-line and hip-line orientations.
    angle = 0.5 * (
        calculate_rotation_angle(centered[0], centered[1])
        + calculate_rotation_angle(centered[2], centered[3])
    )

    cos_a = np.cos(-angle)
    sin_a = np.sin(-angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    rotated = np.dot(centered, rotation.T)

    # Scale everything by the longest edge of the rotated quadrilateral.
    edges = np.roll(rotated[:4], -1, axis=0) - rotated[:4]
    longest_edge = np.max(np.linalg.norm(edges, axis=1))
    return rotated / longest_edge


def calculate_rotation_angle(point1, point2):
    """Return the angle in radians of the vector from *point1* to *point2*."""
    dx = point2[0] - point1[0]
    dy = point2[1] - point1[1]
    return np.arctan2(dy, dx)


def video_to_frames(video_path, output_dir):
    """Decode *video_path* and write every frame as "<n>-frame.png" (1-based) in *output_dir*.

    Raises IOError when the video cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    index = 0
    success, frame = capture.read()
    while success:
        index += 1
        cv2.imwrite(os.path.join(output_dir, f"{index}-frame.png"), frame)
        success, frame = capture.read()

    capture.release()
    print(f"All frames saved to {output_dir}")


def preprocess_image(image_path):
    """Load *image_path* as RGB and return a normalized 1x3x224x224 tensor.

    Uses the standard ImageNet mean/std normalization.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    image = Image.open(image_path).convert("RGB")
    return pipeline(image).unsqueeze(0)


def extract_keypoints_rtmpose(pose_results):
    """Return the keypoint array of the most-visible detected person.

    Scans every instance in *pose_results*, scores each by the mean of its
    keypoint-visibility values, and returns the best instance's keypoints
    (first keypoint set).  Prints a notice and returns None when there are
    no results or no instance scores above zero.
    """
    if not pose_results:
        print("No pose results found.")
        return None

    best_instance = None
    best_score = 0
    for result in pose_results:
        for instance in result.pred_instances:
            score = np.mean(instance.keypoints_visible)
            if score > best_score:
                best_score = score
                best_instance = instance

    if best_instance is None:
        print("No valid instances found.")
        return None

    return best_instance.keypoints[0]


def process_images(args, detector, pose_estimator, visualizer):
    """Run pose + stethoscope detection on every extracted frame and write results.

    For each PNG in <output_dir>/frames this:
      1. estimates body keypoints with PoseNet or RTMPose (whichever flag is on),
      2. detects the stethoscope position via EarsAI's SSD detector,
      3. saves pose/stethoscope overlay images,
      4. accumulates raw and normalized coordinate rows,
      5. predicts chart positions with each enabled model (affine "conv",
         XGBoost, LightGBM) and writes results.csv / results-convert.csv,
         then renders visualizations.

    Args:
        args: parsed CLI namespace providing output_dir.
        detector / pose_estimator / visualizer: RTMPose components; all may
            be None when PoseNet is used instead.

    Bug fix: visualizer.get_image() is now only called when the visualizer
    exists; previously a None visualizer crashed with AttributeError.
    """
    print("Starting process_images function...")
    ears_ai = EarsAI()
    calc_position = CalcStethoscopePosition()
    base_dir = os.path.join(args.output_dir, "frames")
    results_dir = os.path.join(args.output_dir, "results")
    csv_path = os.path.join(results_dir, "results.csv")
    normalized_csv_path = os.path.join(results_dir, "results-convert.csv")
    pose_overlay_dir = os.path.join(results_dir, "pose_overlay_image")
    stethoscope_overlay_dir = os.path.join(results_dir, "stethoscope_overlay_image")

    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(pose_overlay_dir, exist_ok=True)
    os.makedirs(stethoscope_overlay_dir, exist_ok=True)

    # Frames are named "<n>-frame.png"; sort numerically, not lexically.
    png_files = sorted(
        [f for f in os.listdir(base_dir) if f.lower().endswith(".png")],
        key=lambda x: int(re.search(r"(\d+)", x).group(1)),
    )
    print(f"Found {len(png_files)} PNG files.")

    rows = []
    normalized_rows = []
    for image_file_name in png_files:
        print(f"Processing image: {image_file_name}")
        image_path = os.path.join(base_dir, image_file_name)
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Failed to load image: {image_path}")
            continue

        if POSENET_ENABLED:
            # pose_detect returns the overlay image followed by the landmarks.
            # NOTE(review): landmarks appear to be (y, x) pairs — see the
            # swapped indexing when building the CSV row below; confirm
            # against EarsAI.pose_detect.
            pose_overlay_img, *landmarks = ears_ai.pose_detect(frame, None)
            left_shoulder = landmarks[0]
            right_shoulder = landmarks[1]
            left_hip = landmarks[2]
            right_hip = landmarks[3]

        elif RTMPOSE_ENABLED:
            # Person detection: keep label-0 (person) boxes with score > 0.3,
            # then suppress overlaps before top-down pose estimation.
            det_result = inference_detector(detector, frame)
            pred_instance = det_result.pred_instances.cpu().numpy()
            bboxes = np.concatenate((pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
            bboxes = bboxes[np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)]
            bboxes = bboxes[nms(bboxes, 0.3), :4]
            pose_results = inference_topdown(pose_estimator, frame, bboxes)
            data_samples = merge_data_samples(pose_results)
            keypoints = extract_keypoints_rtmpose(pose_results)
            if keypoints is None:
                print(f"Failed to extract keypoints for image: {image_path}")
                continue
            # COCO keypoint indices: 5/6 shoulders, 11/12 hips.
            left_shoulder = keypoints[5]
            right_shoulder = keypoints[6]
            left_hip = keypoints[11]
            right_hip = keypoints[12]
            if visualizer is not None:
                visualizer.add_datasample(
                    "result",
                    frame,
                    data_sample=data_samples,
                    draw_gt=False,
                    draw_heatmap=False,
                    draw_bbox=False,
                    show_kpt_idx=False,
                    skeleton_style="mmpose",
                    show=False,
                    wait_time=0,
                    kpt_thr=0.3,
                )
                pose_overlay_img = visualizer.get_image()
            else:
                # Fall back to the raw frame when no visualizer is available
                # (previously this path crashed on visualizer.get_image()).
                pose_overlay_img = frame
        else:
            print("No pose estimation method enabled. Please enable either PoseNet or RTMPose.")
            continue
        stethoscope_overlay_img, stethoscope_x, stethoscope_y = ears_ai.ssd_detect(frame, None)

        # Overlay images come back in RGB; convert for cv2.imwrite (BGR).
        cv2.imwrite(os.path.join(pose_overlay_dir, image_file_name), cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR))
        cv2.imwrite(
            os.path.join(stethoscope_overlay_dir, image_file_name),
            cv2.cvtColor(stethoscope_overlay_img, cv2.COLOR_RGB2BGR),
        )

        if POSENET_ENABLED:
            # PoseNet landmarks are indexed [1]=x, [0]=y (swapped order).
            row = {
                "image_file_name": image_file_name,
                "left_shoulder_x": left_shoulder[1],
                "left_shoulder_y": left_shoulder[0],
                "right_shoulder_x": right_shoulder[1],
                "right_shoulder_y": right_shoulder[0],
                "left_hip_x": left_hip[1],
                "left_hip_y": left_hip[0],
                "right_hip_x": right_hip[1],
                "right_hip_y": right_hip[0],
                "stethoscope_x": stethoscope_x,
                "stethoscope_y": stethoscope_y,
            }
        elif RTMPOSE_ENABLED:
            # RTMPose keypoints are already (x, y).
            row = {
                "image_file_name": image_file_name,
                "left_shoulder_x": left_shoulder[0],
                "left_shoulder_y": left_shoulder[1],
                "right_shoulder_x": right_shoulder[0],
                "right_shoulder_y": right_shoulder[1],
                "left_hip_x": left_hip[0],
                "left_hip_y": left_hip[1],
                "right_hip_x": right_hip[0],
                "right_hip_y": right_hip[1],
                "stethoscope_x": stethoscope_x,
                "stethoscope_y": stethoscope_y,
            }
        else:
            print("No pose estimation method enabled. Please enable either PoseNet or RTMPose.")
            continue

        rows.append(row)

        # Build the normalized counterpart of this row.
        source_points = np.array(
            [
                [float(row[f"{pos}_x"]), float(row[f"{pos}_y"])]
                for pos in ["left_shoulder", "right_shoulder", "left_hip", "right_hip"]
            ],
            dtype=np.float32,
        )
        stethoscope_point = np.array([float(row["stethoscope_x"]), float(row["stethoscope_y"])])
        normalized_points = normalize_quadrilateral_with_point(source_points.flatten(), stethoscope_point)

        normalized_row = {
            "image_file_name": image_file_name,
            "left_shoulder_x": normalized_points[0, 0],
            "left_shoulder_y": normalized_points[0, 1],
            "right_shoulder_x": normalized_points[1, 0],
            "right_shoulder_y": normalized_points[1, 1],
            "left_hip_x": normalized_points[2, 0],
            "left_hip_y": normalized_points[2, 1],
            "right_hip_x": normalized_points[3, 0],
            "right_hip_y": normalized_points[3, 1],
            "stethoscope_x": normalized_points[4, 0],
            "stethoscope_y": normalized_points[4, 1],
        }
        normalized_rows.append(normalized_row)

    if rows:
        print(f"Writing {len(rows)} rows to CSV...")
        fieldnames = list(rows[0].keys())
        if CONV_ENABLED:
            fieldnames.extend(["conv_stethoscope_x", "conv_stethoscope_y"])
        if XGBOOST_ENABLED:
            fieldnames.extend(["Xgboost_stethoscope_x", "Xgboost_stethoscope_y"])
        if LIGHTGBM_ENABLED:
            fieldnames.extend(["lightGBM_stethoscope_x", "lightGBM_stethoscope_y"])

        if LIGHTGBM_ENABLED:
            lgb_model_x = load_model("./models/lgb_stethoscope_calc_x_best_model-Fold4.pkl")
            lgb_model_y = load_model("./models/lgb_stethoscope_calc_y_best_model-Fold4.pkl")
        if XGBOOST_ENABLED:
            xg_model_x = load_model("./models/xg_stethoscope_calc_x_best_model-Fold4.pkl")
            xg_model_y = load_model("./models/xg_stethoscope_calc_y_best_model-Fold4.pkl")

        with open(csv_path, "w", newline="") as csvfile, open(normalized_csv_path, "w", newline="") as norm_csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            norm_writer = csv.DictWriter(norm_csvfile, fieldnames=fieldnames)
            writer.writeheader()
            norm_writer.writeheader()

            # Last known prediction per model, used when the stethoscope was
            # not detected in a frame; (180, 180) is the initial fallback.
            prev_values = {}
            if CONV_ENABLED:
                prev_values["conv"] = (180, 180)
            if LIGHTGBM_ENABLED:
                prev_values["lightGBM"] = (180, 180)
            if XGBOOST_ENABLED:
                prev_values["Xgboost"] = (180, 180)

            for i, (row, norm_row) in enumerate(zip(rows, normalized_rows)):
                source_points = np.array(
                    [
                        [float(row[f"{pos}_x"]), float(row[f"{pos}_y"])]
                        for pos in ["left_shoulder", "right_shoulder", "left_hip", "right_hip"]
                    ],
                    dtype=np.float32,
                )
                stethoscope_point = np.array([float(row["stethoscope_x"]), float(row["stethoscope_y"])])

                if stethoscope_point[0] == 0 and stethoscope_point[1] == 0:
                    # (0, 0) means "stethoscope not detected" — reuse the
                    # previous predictions for every enabled model.
                    for key in prev_values:
                        row[f"{key}_stethoscope_x"], row[f"{key}_stethoscope_y"] = prev_values[key]
                        norm_row[f"{key}_stethoscope_x"], norm_row[f"{key}_stethoscope_y"] = prev_values[key]
                else:
                    if CONV_ENABLED:
                        conv_stethoscope = calc_position.calc_affine(source_points, *stethoscope_point)
                        row["conv_stethoscope_x"], row["conv_stethoscope_y"] = conv_stethoscope
                        norm_row["conv_stethoscope_x"], norm_row["conv_stethoscope_y"] = conv_stethoscope

                    # The ML models consume either raw or normalized coords
                    # depending on NORMALIZE_ENABLED.
                    if NORMALIZE_ENABLED:
                        input_data = pd.DataFrame([norm_row])
                    else:
                        input_data = pd.DataFrame([row])

                    input_columns = [
                        f"{pos}_{coord}"
                        for pos in ["left_shoulder", "right_shoulder", "left_hip", "right_hip", "stethoscope"]
                        for coord in ["x", "y"]
                    ]

                    if LIGHTGBM_ENABLED:
                        lgb_x = int(lgb_model_x.predict(input_data[input_columns])[0])
                        lgb_y = int(lgb_model_y.predict(input_data[input_columns])[0])
                        row["lightGBM_stethoscope_x"], row["lightGBM_stethoscope_y"] = lgb_x, lgb_y
                        norm_row["lightGBM_stethoscope_x"], norm_row["lightGBM_stethoscope_y"] = lgb_x, lgb_y

                    if XGBOOST_ENABLED:
                        xg_x = int(xg_model_x.predict(input_data[input_columns])[0])
                        xg_y = int(xg_model_y.predict(input_data[input_columns])[0])
                        row["Xgboost_stethoscope_x"], row["Xgboost_stethoscope_y"] = xg_x, xg_y
                        norm_row["Xgboost_stethoscope_x"], norm_row["Xgboost_stethoscope_y"] = xg_x, xg_y

                    for key in prev_values:
                        prev_values[key] = (row[f"{key}_stethoscope_x"], row[f"{key}_stethoscope_y"])

                writer.writerow(row)
                norm_writer.writerow(norm_row)

        print(f"Processed and saved results to: {csv_path}")
        print(f"Processed and saved normalized results to: {normalized_csv_path}")

        generate_visualizations(csv_path, base_dir, results_dir)
    else:
        print("No data to write to CSV.")


def generate_visualizations(csv_path, original_images_dir, results_dir):
    """Render per-frame overlays and trajectory charts, then assemble videos.

    For every row of the results CSV this writes:
      * the original frame with the four body keypoints and the detected
        stethoscope circled ("marked_images"),
      * for each enabled predictor, a body-chart image with and without the
        accumulated trajectory polyline.
    Finally one video per image folder is created.
    """
    results = pd.read_csv(csv_path)
    body_image = cv2.imread("./images/body/BodyF.png")

    # Predictor key -> output folder name; "marked" is handled separately.
    dirs = {"marked": "marked_images"}
    if CONV_ENABLED:
        dirs["conv"] = "conv"
    if XGBOOST_ENABLED:
        dirs["Xgboost"] = "Xgboost"
    if LIGHTGBM_ENABLED:
        dirs["lightGBM"] = "lightGBM"

    predictor_keys = [k for k in dirs if k != "marked"]

    os.makedirs(os.path.join(results_dir, "marked_images"), exist_ok=True)
    for key in predictor_keys:
        os.makedirs(os.path.join(results_dir, f"{dirs[key]}_with_trajectory"), exist_ok=True)
        os.makedirs(os.path.join(results_dir, f"{dirs[key]}_without_trajectory"), exist_ok=True)

    trajectories = {key: [] for key in predictor_keys}
    colors = {"conv": CONV_COLOR, "Xgboost": XGBOOST_COLOR, "lightGBM": LIGHTGBM_COLOR}

    for _, row in results.iterrows():
        frame_path = os.path.join(original_images_dir, row["image_file_name"])
        original_image = cv2.imread(frame_path)
        if original_image is None:
            print(f"Failed to load image: {frame_path}")
            continue

        # Mark keypoints and the detected stethoscope on the raw frame.
        for point in ["left_shoulder", "right_shoulder", "left_hip", "right_hip", "stethoscope"]:
            center = (int(row[f"{point}_x"]), int(row[f"{point}_y"]))
            cv2.circle(original_image, center, 10, (255, 255, 0), -1)

        cv2.imwrite(os.path.join(results_dir, "marked_images", row["image_file_name"]), original_image)

        for key in trajectories:
            x = int(row[f"{key}_stethoscope_x"])
            y = int(row[f"{key}_stethoscope_y"])
            trajectories[key].append((x, y))

            with_traj = body_image.copy()
            if len(trajectories[key]) > 1:
                cv2.polylines(with_traj, [np.array(trajectories[key])], False, colors[key], 2)
            cv2.circle(with_traj, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(results_dir, f"{dirs[key]}_with_trajectory", row["image_file_name"]), with_traj
            )

            without_traj = body_image.copy()
            cv2.circle(without_traj, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(results_dir, f"{dirs[key]}_without_trajectory", row["image_file_name"]),
                without_traj,
            )

    create_video_from_images(os.path.join(results_dir, "marked_images"), os.path.join(results_dir, "marked_video.mp4"))

    for key in predictor_keys:
        create_video_from_images(
            os.path.join(results_dir, f"{dirs[key]}_with_trajectory"),
            os.path.join(results_dir, f"{key}_video_with_trajectory.mp4"),
        )
        create_video_from_images(
            os.path.join(results_dir, f"{dirs[key]}_without_trajectory"),
            os.path.join(results_dir, f"{key}_video_without_trajectory.mp4"),
        )


def create_video_from_images(image_dir, output_path):
    """Stitch the numbered PNGs in *image_dir* into a 30-fps mp4 at *output_path*.

    Frames are ordered by the first integer in each file name.  Prints a
    notice and returns without writing anything when no PNGs are found.
    """
    frame_names = [name for name in os.listdir(image_dir) if name.endswith(".png")]
    frame_names.sort(key=lambda name: int(re.search(r"(\d+)", name).group()))

    if not frame_names:
        print(f"No images found in {image_dir}")
        return

    # The first frame fixes the output resolution.
    first_frame = cv2.imread(os.path.join(image_dir, frame_names[0]))
    height, width = first_frame.shape[:2]

    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height))
    for name in frame_names:
        writer.write(cv2.imread(os.path.join(image_dir, name)))

    writer.release()
    print(f"Created video: {output_path}")


def main():
    """CLI entry point: split the input video into frames, then process them.

    Parses --video_path and --output_dir, extracts frames, and runs the
    detection pipeline — initializing the RTMPose detector/estimator/
    visualizer only when RTMPOSE_ENABLED is set.

    Bug fix: the ArgumentParser was previously constructed twice in a row,
    discarding the first instance; the duplicate has been removed.
    """
    parser = argparse.ArgumentParser(description="Process video and generate results.")
    parser.add_argument("--video_path", default="./video/Test3-1.mp4", help="Path to the input video file")
    parser.add_argument("--output_dir", default="output", help="Directory to save output images and results")
    args = parser.parse_args()

    # RTMPose configs and pretrained checkpoints (checkpoints are fetched
    # from the OpenMMLab model zoo on first use).
    det_config = "modules/rtmpose/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
    det_checkpoint = (
        "https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
    )
    pose_config = "modules/rtmpose/configs/body_2d_keypoint/rtmpose/body8/rtmpose-m_8xb256-420e_body8-256x192.py"
    pose_checkpoint = "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-body7_pt-body7_420e-256x192-e48f03d0_20230504.pth"

    os.makedirs(args.output_dir, exist_ok=True)

    frames_dir = os.path.join(args.output_dir, "frames")
    video_to_frames(args.video_path, frames_dir)

    if RTMPOSE_ENABLED:
        detector = init_detector(det_config, det_checkpoint, device="cuda:0")
        detector.cfg = adapt_mmdet_pipeline(detector.cfg)
        pose_estimator = init_pose_estimator(pose_config, pose_checkpoint, device="cuda:0")
        visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
        visualizer.set_dataset_meta(pose_estimator.dataset_meta, skeleton_style="mmpose")

        process_images(args, detector, pose_estimator, visualizer)
    else:
        process_images(args, None, None, None)


# Run the pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()