Newer
Older
Demo-Maker / output_posenet.py
import os
import re
import time

import cv2
import numpy as np

from util.ears_ai import EarsAI


def video_to_frames(video_path, output_dir):
    """Extract every frame of a video and save each one as a PNG image.

    Frames are written to ``output_dir`` as ``1-frame.png``, ``2-frame.png``,
    ... (1-based, in the order they are decoded).

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory for the extracted frames (created if missing).
            Existing files in it are not removed.

    Returns:
        The number of frames read from the video.

    Raises:
        IOError: If the video cannot be opened or a frame cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise IOError(f"Could not open video file: {video_path}")

        frame_num = 0
        while True:
            success, frame = video.read()
            if not success:
                # End of stream (or a decode failure): stop extracting.
                break
            frame_num += 1
            frame_path = os.path.join(output_dir, f"{frame_num}-frame.png")
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(frame_path, frame):
                raise IOError(f"Could not write frame: {frame_path}")
    finally:
        # Always free the capture handle, even if reading/writing raised.
        video.release()

    print(f"All frames saved to {output_dir}")
    return frame_num


def process_frames_with_posenet(frames_dir, output_dir):
    """Run PoseNet pose detection on every PNG frame and save annotated results.

    Frames are read from ``frames_dir`` in ascending order of the first number
    in their file name. Each frame is passed to ``EarsAI.pose_detect`` and
    stamped with its per-frame FPS. The result is written to ``output_dir``
    under the same file name.

    Args:
        frames_dir: Directory containing the input ``.png`` frames.
        output_dir: Directory for the annotated frames (created if missing).

    Returns:
        Tuple ``(avg_fps, avg_processing_time)``. ``avg_fps`` is the mean of
        the per-frame FPS values. ``avg_processing_time`` is the mean
        detection time in seconds. Both are ``0.0`` when no frame could be
        processed.
    """

    def _sort_key(file_name):
        # Numbered frames first, in numeric order ("2-frame" before
        # "10-frame"). Names without any digits sort last instead of
        # crashing the sort with an AttributeError.
        match = re.search(r"\d+", file_name)
        if match:
            return (0, int(match.group()), file_name)
        return (1, 0, file_name)

    ears_ai = EarsAI()
    os.makedirs(output_dir, exist_ok=True)

    png_files = sorted(
        (f for f in os.listdir(frames_dir) if f.lower().endswith(".png")),
        key=_sort_key,
    )

    processing_times = []
    fps_list = []

    for image_file_name in png_files:
        print(f"Processing image: {image_file_name}")

        frame = cv2.imread(os.path.join(frames_dir, image_file_name))
        if frame is None:
            # Unreadable or corrupt image: skip it rather than abort the run.
            continue

        # perf_counter is monotonic and high-resolution. time.time() can
        # return identical timestamps around a fast call, which would make
        # the FPS division below fail.
        start_time = time.perf_counter()
        pose_overlay_img, *_ = ears_ai.pose_detect(frame, None)
        processing_time = time.perf_counter() - start_time
        processing_times.append(processing_time)

        if processing_time > 0:
            fps = 1.0 / processing_time
            fps_list.append(fps)
            fps_text = f"FPS: {fps:.1f}"
        else:
            # Defensive: a zero interval has no meaningful FPS value.
            fps_text = "FPS: N/A"

        # Draw the FPS counter in the top-left corner of the frame.
        cv2.putText(
            pose_overlay_img,
            fps_text,
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            2,
        )

        # NOTE(review): this assumes pose_detect returns an RGB image
        # (cv2.imread produced BGR input). Confirm against EarsAI, otherwise
        # the saved colours are swapped.
        cv2.imwrite(
            os.path.join(output_dir, image_file_name),
            cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR),
        )

    # np.mean of an empty list yields nan plus a RuntimeWarning; report 0.0.
    avg_fps = np.mean(fps_list) if fps_list else 0.0
    avg_processing_time = np.mean(processing_times) if processing_times else 0.0

    print("\n処理統計:")
    print(f"平均FPS: {avg_fps:.1f}")
    print(f"平均推論時間: {avg_processing_time*1000:.1f}ms")
    print(f"総フレーム数: {len(png_files)}")

    return avg_fps, avg_processing_time


def create_video_from_images(image_dir, output_path, fps=30):
    """Encode the PNG images in a directory into an MP4 video.

    Images are ordered by the first number in their file name. The frame size
    is taken from the first image. Later images with a different size are
    resized to match, and unreadable images are skipped.

    Args:
        image_dir: Directory containing the ``.png`` frames.
        output_path: Path of the video file to create.
        fps: Frame rate of the output video.

    Returns:
        None. If the directory contains no PNG images, a message is printed
        and no video is written.

    Raises:
        IOError: If the first image cannot be read or the video writer cannot
            be opened.
    """

    def _sort_key(file_name):
        # Numbered frames first, in numeric order. Names without digits sort
        # last instead of crashing the sort.
        match = re.search(r"\d+", file_name)
        if match:
            return (0, int(match.group()), file_name)
        return (1, 0, file_name)

    images = sorted(
        (img for img in os.listdir(image_dir) if img.lower().endswith(".png")),
        key=_sort_key,
    )

    if not images:
        print(f"No images found in {image_dir}")
        return

    first_path = os.path.join(image_dir, images[0])
    first_frame = cv2.imread(first_path)
    if first_frame is None:
        raise IOError(f"Could not read image: {first_path}")
    height, width = first_frame.shape[:2]

    video = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    if not video.isOpened():
        raise IOError(f"Could not open video writer: {output_path}")

    try:
        for image in images:
            img = cv2.imread(os.path.join(image_dir, image))
            if img is None:
                print(f"Skipping unreadable image: {image}")
                continue
            if img.shape[:2] != (height, width):
                # The writer was opened for (width, height). Frames of another
                # size may be dropped silently by the backend, so normalise them.
                img = cv2.resize(img, (width, height))
            video.write(img)
    finally:
        # Finalise the container even if a write fails part-way through.
        video.release()

    print(f"Created video: {output_path}")


def main(video_path="./video/Test3-1.mp4", output_base_dir="output-posenet"):
    """Run the full pipeline: extract frames, detect poses, encode the result.

    Args:
        video_path: Input video to process.
        output_base_dir: Root directory for all outputs. It receives
            ``frames/`` (raw frames), ``pose_frames/`` (annotated frames) and
            ``pose_detection.mp4`` (the annotated video).
    """
    frames_dir = os.path.join(output_base_dir, "frames")  # extracted frames
    pose_frames_dir = os.path.join(output_base_dir, "pose_frames")  # annotated frames
    output_video_path = os.path.join(output_base_dir, "pose_detection.mp4")

    os.makedirs(output_base_dir, exist_ok=True)

    print("Starting video processing...")
    print("1. Extracting frames from video...")
    # NOTE(review): frames_dir is not cleared beforehand. Leftover frames
    # from a previous, longer video would also be processed and encoded.
    video_to_frames(video_path, frames_dir)

    print("2. Processing frames with PoseNet...")
    avg_fps, avg_processing_time = process_frames_with_posenet(
        frames_dir, pose_frames_dir
    )

    # Encode at the source frame rate so playback speed matches the input.
    # Fall back to 30 fps (the previous fixed value) when the container does
    # not report a rate; CAP_PROP_FPS reads as 0 in that case.
    capture = cv2.VideoCapture(video_path)
    try:
        source_fps = capture.get(cv2.CAP_PROP_FPS)
    finally:
        capture.release()
    output_fps = source_fps if source_fps > 0 else 30

    print("3. Creating output video...")
    create_video_from_images(pose_frames_dir, output_video_path, fps=output_fps)

    print("\n最終結果:")
    print(f"平均FPS: {avg_fps:.1f}")
    print(f"平均推論時間: {avg_processing_time*1000:.1f}ms")
    print("Processing complete!")


if __name__ == "__main__":
    # Run the pipeline only when this file is executed as a script,
    # not when it is imported as a module.
    main()