Newer
Older
Demo-Maker / util / ears_ai.py
import cv2
import numpy as np
import torch

import modules.posenet as posenet
from modules.PytorchSSD.ssd.mobilenetv1_ssd import (
    create_mobilenetv1_ssd,
    create_mobilenetv1_ssd_predictor,
)
from util import const


class EarsAI:
    """Bundle a MobileNetV1-SSD object detector and a PoseNet pose estimator.

    Both models are loaded eagerly in ``__init__``.  The PoseNet model is
    placed on the GPU when CUDA is available and on the CPU otherwise.
    """

    def __init__(self):
        """Read the model/label paths from ``const`` and load both models."""
        self.model_path = const.MODEL_PATH
        self.label_path = const.LABEL_PATH

        self.setup_ssd_model()
        self.setup_posenet()

    def setup_ssd_model(self) -> None:
        """Set up the SSD model.

        Reads one class name per line from ``self.label_path``, builds a
        MobileNetV1-SSD network sized to that many classes, loads the weights
        from ``self.model_path`` and stores the resulting predictor in
        ``self.predictor`` and the names in ``self.class_names``.
        """
        # Use a context manager so the label file handle is always closed.
        with open(self.label_path) as label_file:
            class_names = [line.strip() for line in label_file]

        net = create_mobilenetv1_ssd(len(class_names), is_test=True)
        net.load(self.model_path)
        self.predictor = create_mobilenetv1_ssd_predictor(net, candidate_size=200)
        self.class_names = class_names

    def setup_posenet(self) -> None:
        """Set up PoseNet.

        Loads PoseNet model ``101`` onto ``self.device`` (CUDA when available,
        otherwise CPU) and caches its ``output_stride``.
        """
        # Fall back to the CPU instead of failing on machines without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.posenet_model = posenet.load_model(101).to(self.device)
        self.output_stride = self.posenet_model.output_stride

    def pose_detect(self, frame, vid):
        """Run pose detection on a single frame.

        Args:
            frame: Image array; its channels are swapped with
                ``cv2.COLOR_BGR2RGB`` before inference.
            vid: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(overlay_image, left_shoulder, right_shoulder,
            left_hip, right_hip)``.  ``overlay_image`` is the value returned
            by ``posenet.draw_skel_and_kp`` for the converted frame; each
            keypoint is an ``np.int32`` array taken from the first detected
            pose (presumably ``(y, x)`` order as in PoseNet -- confirm
            against the ``posenet`` module).

        Raises:
            ValueError: If ``frame`` is ``None``.
        """
        if frame is None:
            raise ValueError("Input frame is None")

        print(f"Pose detect - Input frame shape: {frame.shape}")

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # NOTE(review): read_imgfile receives an in-memory frame here, so the
        # project's posenet module presumably accepts arrays -- confirm.
        input_image, _draw_image, output_scale = posenet.read_imgfile(
            frame, scale_factor=0.7125, output_stride=self.output_stride
        )

        with torch.no_grad():
            input_image = torch.Tensor(input_image).to(self.device)
            (
                heatmaps_result,
                offsets_result,
                displacement_fwd_result,
                displacement_bwd_result,
            ) = self.posenet_model(input_image)
            # Only one pose is requested, and no score filtering is applied.
            pose_scores, keypoint_scores, keypoint_coords = (
                posenet.decode_multiple_poses(
                    heatmaps_result.squeeze(0),
                    offsets_result.squeeze(0),
                    displacement_fwd_result.squeeze(0),
                    displacement_bwd_result.squeeze(0),
                    output_stride=self.output_stride,
                    max_pose_detections=1,
                    min_pose_score=0.0,
                )
            )

        # Scale the decoded coordinates by the factor read_imgfile reported.
        keypoint_coords *= output_scale
        overlay_image = posenet.draw_skel_and_kp(
            frame,
            pose_scores,
            keypoint_scores,
            keypoint_coords,
            min_pose_score=0.0,
            min_part_score=0.0,
        )

        # Extract the four torso keypoints of the first (only) pose.
        torso_parts = ("leftShoulder", "rightShoulder", "leftHip", "rightHip")
        left_shoulder, right_shoulder, left_hip, right_hip = (
            keypoint_coords[0, posenet.PART_NAMES.index(part), :].astype(np.int32)
            for part in torso_parts
        )

        return overlay_image, left_shoulder, right_shoulder, left_hip, right_hip

    def ssd_detect(self, frame, vid):
        """Run SSD detection on a single frame and mark the best hit.

        Args:
            frame: Image array; its channel order is swapped before
                prediction and swapped back before drawing.
            vid: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(overlay_image, stethoscope_x, stethoscope_y)``.  The
            coordinates are the integer centre of the highest-probability
            box, or ``(0, 0)`` when nothing was detected.  When a detection
            exists, its box and a ``"<class>: <prob>"`` label are drawn on
            ``overlay_image``.

        Raises:
            ValueError: If ``frame`` is ``None`` (consistent with
                ``pose_detect``).
        """
        if frame is None:
            raise ValueError("Input frame is None")

        overlay_image = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # Positional arguments 1 and 0.20 -- presumably top_k and the
        # probability threshold of the predictor; confirm against PytorchSSD.
        boxes, labels, probs = self.predictor.predict(overlay_image, 1, 0.20)
        overlay_image = cv2.cvtColor(overlay_image, cv2.COLOR_BGR2RGB)

        stethoscope_x, stethoscope_y = 0, 0
        # len() rather than truthiness: probs may be a multi-element array.
        if len(probs) > 0:
            max_index = np.argmax(probs)
            box = boxes[max_index, :]
            left, top = int(box[0]), int(box[1])
            right, bottom = int(box[2]), int(box[3])

            cv2.rectangle(overlay_image, (left, top), (right, bottom), (0, 255, 255), 2)

            # Centre of the best box.
            stethoscope_x = int((box[0] + box[2]) / 2)
            stethoscope_y = int((box[1] + box[3]) / 2)

            label = f"{self.class_names[labels[max_index]]}: {probs[max_index]:.2f}"
            cv2.putText(
                overlay_image,
                label,
                (left + 20, top + 40),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 0, 255),
                1,
            )

        return overlay_image, stethoscope_x, stethoscope_y