# Demo-Maker / util / ears_ai.py
# (imported from repository; original commit by mikado-4410, 10 Oct 2024 — "最初のコミット" / initial commit)
import cv2
import numpy as np
import torch
import modules.posenet as posenet
from modules.PytorchSSD.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor
from modules.util import const


class EarsAI:
    """Detection pipeline combining a MobileNetV1-SSD object detector
    (stethoscope localization) with PoseNet pose estimation (torso keypoints).

    Model and label file locations come from ``modules.util.const``.
    """

    def __init__(self):
        # Paths are project configuration constants.
        self.model_path = const.MODEL_PATH
        self.label_path = const.LABEL_PATH

        self.setup_ssd_model()
        self.setup_posenet()

    def setup_ssd_model(self):
        """Set up the SSD model.

        Reads the class-label file (one label per line), builds a
        MobileNetV1-SSD network sized to that label set, loads the trained
        weights, and creates a predictor for inference.
        """
        # Use a context manager so the label file is closed deterministically
        # (the previous version leaked the open file handle).
        with open(self.label_path) as label_file:
            class_names = [name.strip() for name in label_file]
        net = create_mobilenetv1_ssd(len(class_names), is_test=True)
        net.load(self.model_path)
        self.predictor = create_mobilenetv1_ssd_predictor(net, candidate_size=200)
        self.class_names = class_names

    def setup_posenet(self):
        """Set up PoseNet (model variant 101) on the GPU.

        NOTE(review): assumes CUDA is available — ``.cuda()`` raises if no
        GPU is present. Confirm deployment environment before changing.
        """
        self.posenet_model = posenet.load_model(101).cuda()
        self.output_stride = self.posenet_model.output_stride

    @staticmethod
    def _torso_keypoint(keypoint_coords, part_name):
        """Return the named keypoint of the top-scored pose as int32 coords."""
        return keypoint_coords[0, posenet.PART_NAMES.index(part_name), :].astype(np.int32)

    def pose_detect(self, frame, vid):
        """Run PoseNet pose detection on one frame.

        Args:
            frame: BGR image (OpenCV convention) as a numpy array.
            vid: unused here; kept for call-site compatibility.

        Returns:
            Tuple of ``(overlay_image, left_shoulder, right_shoulder,
            left_hip, right_hip)`` where the overlay has the skeleton and
            keypoints drawn, and each keypoint is an int32 coordinate pair
            taken from the highest-scored pose (index 0).

        Raises:
            ValueError: if ``frame`` is None.
        """
        if frame is None:
            raise ValueError("Input frame is None")

        print(f"Pose detect - Input frame shape: {frame.shape}")

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_image, draw_image, output_scale = posenet.read_imgfile(
            frame, scale_factor=0.7125, output_stride=self.output_stride
        )
        # Inference only — no gradients needed.
        with torch.no_grad():
            input_image = torch.Tensor(input_image).cuda()
            heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = self.posenet_model(
                input_image
            )
            pose_scores, keypoint_scores, keypoint_coords = posenet.decode_multiple_poses(
                heatmaps_result.squeeze(0),
                offsets_result.squeeze(0),
                displacement_fwd_result.squeeze(0),
                displacement_bwd_result.squeeze(0),
                output_stride=self.output_stride,
                max_pose_detections=10,
                min_pose_score=0.05,
            )

        # Scale keypoints back from network resolution to image resolution.
        keypoint_coords *= output_scale
        overlay_image = posenet.draw_skel_and_kp(
            frame, pose_scores, keypoint_scores, keypoint_coords, min_pose_score=0.15, min_part_score=0.1
        )

        # Torso keypoints of the first (highest-score) pose.
        left_shoulder = self._torso_keypoint(keypoint_coords, "leftShoulder")
        right_shoulder = self._torso_keypoint(keypoint_coords, "rightShoulder")
        left_hip = self._torso_keypoint(keypoint_coords, "leftHip")
        right_hip = self._torso_keypoint(keypoint_coords, "rightHip")

        return overlay_image, left_shoulder, right_shoulder, left_hip, right_hip

    def ssd_detect(self, frame, vid):
        """Run SSD detection on one frame and annotate the best hit.

        Args:
            frame: RGB image as a numpy array (converted to BGR for the
                predictor, then back to RGB for the returned overlay).
            vid: unused here; kept for call-site compatibility.

        Returns:
            Tuple of ``(overlay_image, stethoscope_x, stethoscope_y)`` —
            the annotated RGB image and the center of the highest-confidence
            detection box, or ``(0, 0)`` when nothing is detected.
        """
        overlay_image = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # top_k=1, confidence threshold 0.20
        boxes, labels, probs = self.predictor.predict(overlay_image, 1, 0.20)
        overlay_image = cv2.cvtColor(overlay_image, cv2.COLOR_BGR2RGB)

        stethoscope_x, stethoscope_y = 0, 0
        if len(probs) != 0:
            # Keep only the single most confident detection.
            max_index = np.argmax(probs)
            box = boxes[max_index, :]
            cv2.rectangle(overlay_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 255), 2)
            # Box center as the reported stethoscope position.
            stethoscope_x = int((box[0] + box[2]) / 2)
            stethoscope_y = int((box[1] + box[3]) / 2)
            label = f"{self.class_names[labels[max_index]]}: {probs[max_index]:.2f}"
            cv2.putText(
                overlay_image,
                label,
                (int(box[0]) + 20, int(box[1]) + 40),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 0, 255),
                1,
            )

        return overlay_image, stethoscope_x, stethoscope_y