# Demo-Maker / main-cnn.py
import argparse
import csv
import os
import re

import cv2
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from PIL import Image
from torchvision import transforms

from modules.EARSForDL.EfficientNet import RegressionEfficientNet
from modules.EARSForDL.MobileNetV2 import RegressionMobileNetV2
from modules.EARSForDL.ResNet import RegressionResNet
from modules.EARSForDL.SqueezeNet import RegressionSqueezeNet

# RTMPose imports

# Load environment variables from a local .env file, if present.
load_dotenv()


def _env_color(var_name, default):
    """Parse an environment variable of the form "R,G,B" into an int 3-tuple."""
    return tuple(int(c) for c in os.getenv(var_name, default).split(","))


def _env_flag(var_name, default="True"):
    """Read a boolean environment variable ("true"/"false", case-insensitive)."""
    return os.getenv(var_name, default).lower() == "true"


# Marker color per model, used by the OpenCV drawing calls below.
# NOTE(review): cv2 draws in BGR order, but these defaults read like RGB
# values (255,165,0 is RGB orange) — confirm the intended channel order.
RESNET_COLOR = _env_color("RESNET_COLOR", "255,165,0")  # Orange for ResNet
EFFICIENTNET_COLOR = _env_color("EFFICIENTNET_COLOR", "0,0,255")  # Blue for EfficientNet
MOBILENET_COLOR = _env_color("MOBILENET_COLOR", "255,0,0")  # Red for MobileNet
SQUEEZENET_COLOR = _env_color("SQUEEZENET_COLOR", "128,0,128")  # Purple for SqueezeNet

# Toggles controlling which models are loaded and evaluated.
RESNET_ENABLED = _env_flag("RESNET_ENABLED")
EFFICIENTNET_ENABLED = _env_flag("EFFICIENTNET_ENABLED")
MOBILENET_ENABLED = _env_flag("MOBILENET_ENABLED")
SQUEEZENET_ENABLED = _env_flag("SQUEEZENET_ENABLED")

# Keypoint-normalization toggle (not referenced elsewhere in this view).
NORMALIZE_ENABLED = _env_flag("NORMALIZE_ENABLED", "False")


def normalize_quadrilateral_with_point(points, extra_point):
    """Center, de-rotate, and scale a quadrilateral plus one extra point.

    The quadrilateral's centroid is moved to the origin, all five points are
    rotated by the negative of the mean shoulder/hip edge angle, and every
    coordinate is divided by the longest quadrilateral edge length.

    Args:
        points: array-like holding the 4 corners (reshaped to (4, 2)).
        extra_point: one (x, y) point transformed along with the corners.

    Returns:
        (5, 2) ndarray of normalized coordinates (4 corners, then the point).
    """
    quad = np.asarray(points).reshape(-1, 2)
    stacked = np.vstack([quad, extra_point])
    # Translate so the quadrilateral's centroid (not including the extra
    # point) sits at the origin.
    centered = stacked - quad.mean(axis=0)

    # Average the top ("shoulder") and bottom ("hip") edge orientations.
    angle_top = calculate_rotation_angle(centered[0], centered[1])
    angle_bottom = calculate_rotation_angle(centered[2], centered[3])
    theta = -(angle_top + angle_bottom) / 2.0

    rotation = np.array(
        [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
    )
    aligned = centered @ rotation.T

    # Scale so the longest quadrilateral edge has unit length.
    corners = aligned[:4]
    edge_lengths = np.linalg.norm(np.roll(corners, -1, axis=0) - corners, axis=1)
    return aligned / edge_lengths.max()


def calculate_rotation_angle(point1, point2):
    """Return the angle in radians, in (-pi, pi], of the vector point1 -> point2."""
    dx, dy = point2 - point1
    return np.arctan2(dy, dx)


def video_to_frames(video_path, output_dir):
    """Decode every frame of *video_path* into PNGs named "<n>-frame.png".

    Frame numbering starts at 1 and files are written to *output_dir*
    (created if missing).

    Raises:
        IOError: if the video file cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    frame_index = 0
    success, frame = capture.read()
    while success:
        frame_index += 1
        cv2.imwrite(os.path.join(output_dir, f"{frame_index}-frame.png"), frame)
        success, frame = capture.read()

    capture.release()
    print(f"All frames saved to {output_dir}")


def preprocess_image(image_path):
    """Load an image file and convert it to a normalized (1, 3, 224, 224) tensor.

    Applies the standard ImageNet resize/ToTensor/Normalize pipeline and adds
    a leading batch dimension so the result can be fed straight to the models.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    image = Image.open(image_path).convert("RGB")
    return pipeline(image).unsqueeze(0)


def extract_keypoints_rtmpose(pose_results):
    """Pick the keypoints of the most-visible predicted pose instance.

    Scans every instance across all results and keeps the one whose mean
    keypoint visibility is highest (and strictly greater than zero).

    Args:
        pose_results: iterable of results exposing ``pred_instances``; each
            instance provides ``keypoints_visible`` and ``keypoints``.

    Returns:
        The first keypoint set of the best instance, or None when there are
        no results or no instance with positive mean visibility.
    """
    if not pose_results:
        print("No pose results found.")
        return None

    best_instance, best_score = None, 0
    for result in pose_results:
        for instance in result.pred_instances:
            score = np.mean(instance.keypoints_visible)
            if score > best_score:
                best_score, best_instance = score, instance

    if best_instance is None:
        print("No valid instances found.")
        return None

    return best_instance.keypoints[0]


def _load_models(device):
    """Instantiate, load weights for, and eval()-switch every enabled model.

    Returns a dict mapping model key ("resnet", ...) to the ready model,
    preserving the resnet -> efficientnet -> mobilenet -> squeezenet order.
    """
    # (key, enabled-flag, constructor, checkpoint path) for each model.
    specs = [
        ("resnet", RESNET_ENABLED, lambda: RegressionResNet(resnet_depth=18),
         "./models/best_model-resnet.pth"),
        ("efficientnet", EFFICIENTNET_ENABLED, lambda: RegressionEfficientNet("b1"),
         "./models/best_model-efficient.pth"),
        ("mobilenet", MOBILENET_ENABLED, lambda: RegressionMobileNetV2(),
         "./models/best_model-mobilenetV2.pth"),
        ("squeezenet", SQUEEZENET_ENABLED, lambda: RegressionSqueezeNet("1_1"),
         "./models/best_model-squeeze.pth"),
    ]
    models = {}
    for key, enabled, factory, weights_path in specs:
        if not enabled:
            continue
        model = factory()
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.to(device)
        model.eval()
        models[key] = model
    return models


def process_images(args):
    """Run every enabled CNN over all extracted frames and persist the results.

    Reads PNG frames from <output_dir>/frames (sorted by the first number in
    each file name), predicts a stethoscope (x, y) per model per frame,
    writes <output_dir>/results/results.csv, and triggers visualization
    generation.

    Args:
        args: argparse namespace with an ``output_dir`` attribute.
    """
    print("Starting process_images function...")
    base_dir = os.path.join(args.output_dir, "frames")
    results_dir = os.path.join(args.output_dir, "results")
    csv_path = os.path.join(results_dir, "results.csv")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    models = _load_models(device)

    os.makedirs(results_dir, exist_ok=True)

    # Numeric sort on the first digit run so "10-frame.png" follows "9-frame.png".
    png_files = sorted(
        (f for f in os.listdir(base_dir) if f.lower().endswith(".png")),
        key=lambda x: int(re.search(r"(\d+)", x).group(1)),
    )
    print(f"Found {len(png_files)} PNG files.")

    rows = []
    for image_file_name in png_files:
        print(f"Processing image: {image_file_name}")
        image_path = os.path.join(base_dir, image_file_name)
        # Sanity-load with OpenCV so unreadable frames are skipped early.
        if cv2.imread(image_path) is None:
            print(f"Failed to load image: {image_path}")
            continue

        processed_image = preprocess_image(image_path).to(device)
        row = {"image_file_name": image_file_name}

        # Inference only: no gradients needed.
        with torch.no_grad():
            for model_name, model in models.items():
                coords = model(processed_image)[0].cpu().numpy()
                row[f"{model_name}_stethoscope_x"] = int(coords[0])
                row[f"{model_name}_stethoscope_y"] = int(coords[1])

        rows.append(row)

    if not rows:
        print("No data to write to CSV.")
        return

    with open(csv_path, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)

    print(f"Processed and saved results to: {csv_path}")
    generate_visualizations(csv_path, base_dir, results_dir)


def generate_visualizations(csv_path, original_images_dir, results_dir):
    """Render per-model prediction overlays and assemble them into videos.

    For every enabled model and every CSV row, draws the predicted point on a
    copy of the body template image — once with the accumulated trajectory
    polyline and once without — then builds one MP4 per variant.

    Args:
        csv_path: path to the results CSV written by process_images.
        original_images_dir: directory of the source frames (kept for
            interface compatibility; overlays are drawn on the template).
        results_dir: directory receiving the overlay folders and videos.

    Raises:
        FileNotFoundError: if the body template image cannot be loaded.
    """
    df = pd.read_csv(csv_path)
    body_image = cv2.imread("./images/body/BodyF.png")
    if body_image is None:
        # Fail fast with a clear message instead of an opaque
        # AttributeError on body_image.copy() below.
        raise FileNotFoundError("Could not load body template: ./images/body/BodyF.png")

    # Output sub-directory name and drawing color for each enabled model.
    specs = [
        ("resnet", RESNET_ENABLED, RESNET_COLOR),
        ("efficientnet", EFFICIENTNET_ENABLED, EFFICIENTNET_COLOR),
        ("mobilenet", MOBILENET_ENABLED, MOBILENET_COLOR),
        ("squeezenet", SQUEEZENET_ENABLED, SQUEEZENET_COLOR),
    ]
    dirs = {key: key for key, enabled, _ in specs if enabled}
    colors = {key: color for key, enabled, color in specs if enabled}

    # Create the per-model output directories.
    for key in dirs:
        os.makedirs(
            os.path.join(results_dir, f"{dirs[key]}_with_trajectory"), exist_ok=True
        )
        os.makedirs(
            os.path.join(results_dir, f"{dirs[key]}_without_trajectory"), exist_ok=True
        )

    points = {key: [] for key in dirs}

    for _, row in df.iterrows():
        for key in points:
            x = int(row[f"{key}_stethoscope_x"])
            y = int(row[f"{key}_stethoscope_y"])
            points[key].append((x, y))

            # Variant 1: current marker plus the trajectory so far.
            image_with_trajectory = body_image.copy()
            if len(points[key]) > 1:
                cv2.polylines(
                    image_with_trajectory,
                    [np.array(points[key])],
                    False,
                    colors[key],
                    2,
                )
            cv2.circle(image_with_trajectory, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(
                    results_dir, f"{dirs[key]}_with_trajectory", row["image_file_name"]
                ),
                image_with_trajectory,
            )

            # Variant 2: current marker only.
            image_without_trajectory = body_image.copy()
            cv2.circle(image_without_trajectory, (x, y), 10, colors[key], -1)
            cv2.imwrite(
                os.path.join(
                    results_dir,
                    f"{dirs[key]}_without_trajectory",
                    row["image_file_name"],
                ),
                image_without_trajectory,
            )

    # Stitch each overlay folder into an MP4.
    for key in dirs:
        create_video_from_images(
            os.path.join(results_dir, f"{dirs[key]}_with_trajectory"),
            os.path.join(results_dir, f"{key}_video_with_trajectory.mp4"),
        )
        create_video_from_images(
            os.path.join(results_dir, f"{dirs[key]}_without_trajectory"),
            os.path.join(results_dir, f"{key}_video_without_trajectory.mp4"),
        )


def create_video_from_images(image_dir, output_path):
    """Stitch the numbered PNGs in *image_dir* into a 30 fps MP4.

    Frames are ordered by the first number in each file name (matching the
    sorting used in process_images). Prints a notice and returns without
    writing anything when the directory holds no PNG files.
    """
    # Case-insensitive extension check and group(1), consistent with
    # process_images' frame enumeration.
    images = sorted(
        (img for img in os.listdir(image_dir) if img.lower().endswith(".png")),
        key=lambda x: int(re.search(r"(\d+)", x).group(1)),
    )

    if not images:
        print(f"No images found in {image_dir}")
        return

    first_frame = cv2.imread(os.path.join(image_dir, images[0]))
    height, width, _ = first_frame.shape

    video = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    try:
        for image in images:
            frame = cv2.imread(os.path.join(image_dir, image))
            if frame is None:
                # Skip unreadable frames rather than crashing the writer.
                continue
            video.write(frame)
    finally:
        # Always finalize the container, even if a write fails mid-way.
        video.release()

    print(f"Created video: {output_path}")


def main():
    """Parse CLI arguments, split the input video into frames, run inference."""
    parser = argparse.ArgumentParser(description="Process video and generate results.")
    parser.add_argument(
        "--video_path",
        default="./video/Test3-1.mp4",
        help="Path to the input video file",
    )
    parser.add_argument(
        "--output_dir",
        default="output-cnn",
        help="Directory to save output images and results",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Stage 1: decode the video into frame PNGs; stage 2: run the CNNs on them.
    video_to_frames(args.video_path, os.path.join(args.output_dir, "frames"))
    process_images(args)


if __name__ == "__main__":
    main()