Newer
Older
RobotCar / src / pc / data / collector.py
"""
collector
二値画像をラベル別ディレクトリに保存するデータ収集モジュール

録画中にフレームごとに save() を呼び出すと,
指定ラベルのサブディレクトリへ連番 PNG として保存する
"""

from datetime import datetime
from pathlib import Path

import cv2
import numpy as np

# デフォルトの保存先(プロジェクトルート / data)
_PROJECT_ROOT: Path = (
    Path(__file__).resolve().parent.parent.parent.parent
)
DEFAULT_DATA_DIR: Path = _PROJECT_ROOT / "data"

# ラベル名
LABEL_INTERSECTION: str = "intersection"
LABEL_NORMAL: str = "normal"


class DataCollector:
    """二値画像をラベル付きで保存するコレクタ

    Attributes:
        data_dir: 保存先ルートディレクトリ
        session_dir: 現在の録画セッションのディレクトリ
        is_recording: 録画中かどうか
    """

    def __init__(
        self,
        data_dir: Path = DEFAULT_DATA_DIR,
    ) -> None:
        self._data_dir = data_dir
        self._session_dir: Path | None = None
        self._is_recording: bool = False
        self._count_intersection: int = 0
        self._count_normal: int = 0

    @property
    def is_recording(self) -> bool:
        """録画中かどうかを返す"""
        return self._is_recording

    @property
    def count_intersection(self) -> int:
        """現在のセッションで保存した intersection 画像の枚数"""
        return self._count_intersection

    @property
    def count_normal(self) -> int:
        """現在のセッションで保存した normal 画像の枚数"""
        return self._count_normal

    def start(self) -> Path:
        """録画を開始する

        タイムスタンプ付きのセッションディレクトリを作成し,
        その中に intersection/ と normal/ を用意する

        Returns:
            作成したセッションディレクトリのパス
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self._session_dir = self._data_dir / timestamp
        (self._session_dir / LABEL_INTERSECTION).mkdir(
            parents=True, exist_ok=True,
        )
        (self._session_dir / LABEL_NORMAL).mkdir(
            parents=True, exist_ok=True,
        )
        self._count_intersection = 0
        self._count_normal = 0
        self._is_recording = True
        return self._session_dir

    def stop(self) -> None:
        """録画を停止する"""
        self._is_recording = False

    def save(
        self,
        binary_image: np.ndarray,
        is_intersection: bool,
    ) -> None:
        """二値画像をラベル付きで保存する

        Args:
            binary_image: 保存する二値画像(0/255)
            is_intersection: True なら intersection,False なら normal
        """
        if not self._is_recording or self._session_dir is None:
            return

        if is_intersection:
            label = LABEL_INTERSECTION
            self._count_intersection += 1
            idx = self._count_intersection
        else:
            label = LABEL_NORMAL
            self._count_normal += 1
            idx = self._count_normal

        path = (
            self._session_dir / label / f"{idx:06d}.png"
        )
        cv2.imwrite(str(path), binary_image)