"""
collector
二値画像をラベル別ディレクトリに保存するデータ収集モジュール
録画中にフレームごとに save() を呼び出すと,
指定ラベルのサブディレクトリへ連番 PNG として保存する
"""
from datetime import datetime
from pathlib import Path
import cv2
import numpy as np
# プロジェクトルート
_PROJECT_ROOT: Path = (
Path(__file__).resolve().parent.parent.parent.parent
)
# データディレクトリ
DATA_DIR: Path = _PROJECT_ROOT / "data"
RAW_DIR: Path = DATA_DIR / "raw"
CONFIRMED_DIR: Path = DATA_DIR / "confirmed"
# ラベル名
LABEL_INTERSECTION: str = "intersection"
LABEL_NORMAL: str = "normal"
class DataCollector:
"""二値画像をラベル付きで保存するコレクタ
Attributes:
raw_dir: 未確定データの保存先ルートディレクトリ
session_dir: 現在の録画セッションのディレクトリ
is_recording: 録画中かどうか
"""
def __init__(
self,
raw_dir: Path = RAW_DIR,
) -> None:
self._raw_dir = raw_dir
self._session_dir: Path | None = None
self._is_recording: bool = False
self._count_intersection: int = 0
self._count_normal: int = 0
@property
def is_recording(self) -> bool:
"""録画中かどうかを返す"""
return self._is_recording
@property
def count_intersection(self) -> int:
"""現在のセッションで保存した intersection 画像の枚数"""
return self._count_intersection
@property
def count_normal(self) -> int:
"""現在のセッションで保存した normal 画像の枚数"""
return self._count_normal
def start(self) -> Path:
"""録画を開始する
タイムスタンプ付きのセッションディレクトリを作成し,
その中に intersection/ と normal/ を用意する
Returns:
作成したセッションディレクトリのパス
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self._session_dir = self._raw_dir / timestamp
(self._session_dir / LABEL_INTERSECTION).mkdir(
parents=True, exist_ok=True,
)
(self._session_dir / LABEL_NORMAL).mkdir(
parents=True, exist_ok=True,
)
self._count_intersection = 0
self._count_normal = 0
self._is_recording = True
return self._session_dir
def stop(self) -> None:
"""録画を停止する"""
self._is_recording = False
def save(
self,
binary_image: np.ndarray,
is_intersection: bool,
) -> None:
"""二値画像をラベル付きで保存する
Args:
binary_image: 保存する二値画像(0/255)
is_intersection: True なら intersection,False なら normal
"""
if not self._is_recording or self._session_dir is None:
return
if is_intersection:
label = LABEL_INTERSECTION
self._count_intersection += 1
idx = self._count_intersection
else:
label = LABEL_NORMAL
self._count_normal += 1
idx = self._count_normal
path = (
self._session_dir / label / f"{idx:06d}.png"
)
cv2.imwrite(str(path), binary_image)