diff --git a/.gitignore b/.gitignore index b4966c7..1aad68b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ pd_params.json params/ +# 学習データ +/data/ + # 旧コード(参照用,Git 管理外) src_old/ diff --git "a/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" "b/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" index 178535d..aef2aba 100644 --- "a/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" +++ "b/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" @@ -54,6 +54,8 @@ │ └── main_window.py メインウィンドウ ├── comm/ 通信関連 │ └── zmq_client.py ZMQ 送受信 + ├── data/ 学習データ収集 + │ └── collector.py 二値画像のラベル付き保存 ├── steering/ 操舵量計算(独立モジュール) │ ├── base.py 共通インターフェース │ ├── pd_control.py PD 制御の実装 diff --git a/src/pc/data/__init__.py b/src/pc/data/__init__.py new file mode 100644 index 0000000..bdec126 --- /dev/null +++ b/src/pc/data/__init__.py @@ -0,0 +1,8 @@ +""" +data +学習データの収集・管理 +""" + +from pc.data.collector import DataCollector + +__all__ = ["DataCollector"] diff --git a/src/pc/data/collector.py b/src/pc/data/collector.py new file mode 100644 index 0000000..eb48bbc --- /dev/null +++ b/src/pc/data/collector.py @@ -0,0 +1,112 @@ +""" +collector +二値画像をラベル別ディレクトリに保存するデータ収集モジュール + +録画中にフレームごとに save() を呼び出すと, +指定ラベルのサブディレクトリへ連番 PNG として保存する +""" + +from datetime import datetime +from pathlib import Path + +import cv2 +import numpy as np + +# デフォルトの保存先(プロジェクトルート / data) +_PROJECT_ROOT: Path = ( + Path(__file__).resolve().parent.parent.parent.parent +) +DEFAULT_DATA_DIR: Path = _PROJECT_ROOT / "data" + +# ラベル名 +LABEL_INTERSECTION: str = "intersection" +LABEL_NORMAL: str = "normal" + + +class DataCollector: + """二値画像をラベル付きで保存するコレクタ + + Attributes: + data_dir: 保存先ルートディレクトリ + session_dir: 現在の録画セッションのディレクトリ + is_recording: 録画中かどうか + """ + + def __init__( + self, + data_dir: Path = DEFAULT_DATA_DIR, + ) -> None: + self._data_dir = data_dir + self._session_dir: Path | None = None + self._is_recording: bool = False + self._count_intersection: int = 0 + self._count_normal: int = 0 + + @property + def is_recording(self) -> bool: + """録画中かどうかを返す""" + return self._is_recording + + @property + def count_intersection(self) -> int: + """現在のセッションで保存した intersection 画像の枚数""" + return self._count_intersection + + @property + def count_normal(self) -> int: + """現在のセッションで保存した normal 画像の枚数""" + return self._count_normal + + def start(self) -> Path: + """録画を開始する + + タイムスタンプ付きのセッションディレクトリを作成し, + その中に intersection/ と normal/ を用意する + + Returns: + 作成したセッションディレクトリのパス + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self._session_dir = self._data_dir / timestamp + (self._session_dir / LABEL_INTERSECTION).mkdir( + parents=True, exist_ok=True, + ) + (self._session_dir / LABEL_NORMAL).mkdir( + parents=True, exist_ok=True, + ) + self._count_intersection = 0 + self._count_normal = 0 + self._is_recording = True + return self._session_dir + + def stop(self) -> None: + """録画を停止する""" + self._is_recording = False + + def save( + self, + binary_image: np.ndarray, + is_intersection: bool, + ) -> None: + """二値画像をラベル付きで保存する + + Args: + binary_image: 保存する二値画像(0/255) + is_intersection: True なら intersection,False なら normal + """ + if not self._is_recording or self._session_dir is None: + return + + if is_intersection: + label = LABEL_INTERSECTION + self._count_intersection += 1 + idx = self._count_intersection + else: + label = LABEL_NORMAL + self._count_normal += 1 + idx = self._count_normal + + path = ( + self._session_dir / label / f"{idx:06d}.png" + ) + cv2.imwrite(str(path), binary_image) diff --git a/src/pc/gui/main_window.py b/src/pc/gui/main_window.py index ec7c5cd..51ba5df 100644 --- a/src/pc/gui/main_window.py +++ b/src/pc/gui/main_window.py @@ -20,6 +20,7 @@ from common import config from pc.comm.zmq_client import PcZmqClient +from pc.data.collector import DataCollector from pc.gui.panels import ( ControlParamPanel, ImageParamPanel, @@ -128,6 +129,10 @@ None ) + # データ収集 + self._collector = DataCollector() + self._is_intersection_key: bool = False + self._setup_ui() self._setup_timers() @@ -255,6 +260,25 @@ ) control_layout.addWidget(self._recovery_panel) + # データ収集ボタン + self._record_btn = QPushButton("録画開始") + self._record_btn.setEnabled(False) + self._record_btn.clicked.connect( + self._toggle_recording, + ) + control_layout.addWidget(self._record_btn) + + # 録画ステータス表示 + self._record_label = QLabel("") + self._record_label.setAlignment( + Qt.AlignmentFlag.AlignLeft, + ) + self._record_label.setStyleSheet( + "font-size: 12px; font-family: monospace;" + " color: #f80; padding: 2px;" + ) + control_layout.addWidget(self._record_label) + # デバッグ表示パネル overlay_flags = load_overlay() self._overlay_panel = OverlayPanel(overlay_flags) @@ -269,7 +293,8 @@ "W/↑: 前進 S/↓: 後退\n" "A/←: 左 D/→: 右\n" "Space: 停止\n" - "Q: 自動操縦 切替" + "Q: 自動操縦 切替\n" + "I: 十字路ラベル(録画中,押下中)" ) guide.setAlignment(Qt.AlignmentFlag.AlignLeft) guide.setStyleSheet("font-size: 12px; color: #666;") @@ -372,6 +397,7 @@ self._is_connected = True self._connect_btn.setText("切断") self._auto_btn.setEnabled(True) + self._record_btn.setEnabled(True) self._status_label.setText("接続中 (手動操作)") self._frame_timer.start(FRAME_INTERVAL_MS) self._control_timer.start(CONTROL_INTERVAL_MS) @@ -380,12 +406,15 @@ """ZMQ 通信を停止する""" self._frame_timer.stop() self._control_timer.stop() + if self._collector.is_recording: + self._stop_recording() if self._is_auto: self._is_auto = False self._auto_btn.setText("自動操縦 ON") self._zmq_client.stop() self._is_connected = False self._auto_btn.setEnabled(False) + self._record_btn.setEnabled(False) self._pressed_keys.clear() self._throttle = 0.0 self._steer = 0.0 @@ -424,6 +453,50 @@ self._status_label.setText("接続中 (手動操作)") self._update_control_label() + # ── データ収集 ──────────────────────────────────────── + + def _toggle_recording(self) -> None: + """録画の開始 / 停止を切り替える""" + if self._collector.is_recording: + self._stop_recording() + else: + self._start_recording() + + def _start_recording(self) -> None: + """録画を開始する""" + session_dir = self._collector.start() + self._record_btn.setText("録画停止") + self._record_label.setText( + f"録画中: {session_dir.name}\n" + "intersection: 0 normal: 0" + ) + + def _stop_recording(self) -> None: + """録画を停止する""" + n_int = self._collector.count_intersection + n_norm = self._collector.count_normal + self._collector.stop() + self._is_intersection_key = False + self._record_btn.setText("録画開始") + self._record_label.setText( + f"録画完了: intersection={n_int}" + f" normal={n_norm}" + ) + + def _update_record_label(self) -> None: + """録画ステータスの枚数表示を更新する""" + if not self._collector.is_recording: + return + label = ( + "I押下中" if self._is_intersection_key + else "通常" + ) + self._record_label.setText( + f"録画中 [{label}]\n" + f"intersection: {self._collector.count_intersection}" + f" normal: {self._collector.count_normal}" + ) + # ── 映像更新 ────────────────────────────────────────── def _update_frame(self) -> None: @@ -464,6 +537,16 @@ self._pd_control.image_params, ) + # データ収集: 二値画像を保存 + if self._collector.is_recording: + det = self._last_detect_result + if det is not None and det.binary_image is not None: + self._collector.save( + det.binary_image, + self._is_intersection_key, + ) + self._update_record_label() + self._display_frame(frame) def _update_detect_info_label(self) -> None: @@ -620,6 +703,12 @@ self._toggle_auto() return + # I キーで十字路ラベル(録画中のみ) + if event.key() == Qt.Key.Key_I: + if self._collector.is_recording: + self._is_intersection_key = True + return + # 自動操縦中はキー操作を無視 if self._is_auto: return @@ -629,7 +718,15 @@ def keyReleaseEvent(self, event: QKeyEvent) -> None: """キー離上時に操舵量を更新する""" - if event.isAutoRepeat() or self._is_auto: + if event.isAutoRepeat(): + return + + # I キーの離上 + if event.key() == Qt.Key.Key_I: + self._is_intersection_key = False + return + + if self._is_auto: return self._pressed_keys.discard(event.key()) self._update_manual_control()