diff --git "a/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" "b/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" index d3d77e5..439af03 100644 --- "a/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" +++ "b/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" @@ -197,6 +197,7 @@ 3. 案B: 二重正規化型 ・背景ブラーカーネルサイズ(bg_blur_ksize): 101 + ・固定閾値(global_thresh): 0(0 で無効,適応的閾値との AND) ・適応的閾値ブロックサイズ(adaptive_block): 51 ・適応的閾値定数 C(adaptive_c): 10 ・等方クロージングサイズ(iso_close_size): 15 @@ -230,3 +231,52 @@ タイトル・メモ付きで JSON ファイルに保存できる. ・GUI のコンボボックスで保存済みパラメータを選択・読み込み可能. ・保存ファイル: pd_params.json(.gitignore に登録済み) + + +7. 2点パシュート制御 (Two-Point Pursuit Control) +------------------------------------------------------------------------ + + 7-1. 概要 + + PD 制御の代替として,多項式曲線上の近い点と遠い点の2箇所を + 目標点として操舵量を計算する方式.微分・曲率の計算が不要で, + パラメータ調整が直感的である.GUI でPD 制御と切り替えて使用可能 + + 7-2. アルゴリズム + + (1) 二値画像の最大連結領域を抽出する + (2) 各行の左右端から中心 x 座標(row_centers)を求める + (3) 近い行と遠い行の y 座標を決定する + near_y = 画像高さ × near_ratio(手前側) + far_y = 画像高さ × far_ratio(奥側) + (4) row_centers から直接 x 座標を取得する + (該当行が NaN の場合は近傍の有効な行を探索) + (5) 各点の偏差を計算する + near_err = (画像中心x - near_x) / 画像中心x + far_err = (画像中心x - far_x) / 画像中心x + (4) 操舵量を算出する + steer = K_near × near_err + K_far × far_err + (5) レートリミッターで急な操舵変化を制限する + (6) 速度を算出する + curve = |near_x - far_x| / 画像中心x + throttle = max_throttle - speed_k × curve + + 7-3. PD 制御との比較 + + ・PD 制御: 画像下端の1点で微分・曲率を計算 → 不安定になりやすい + ・2点パシュート: 2点の位置だけで判断 → 直感的で安定 + + 7-4. パラメータ一覧(GUI で調整可能) + + ・near_ratio(デフォルト: 0.8): 近い目標点の位置(0.0=上端,1.0=下端) + ・far_ratio(デフォルト: 0.3): 遠い目標点の位置 + ・k_near(デフォルト: 0.5): 近い目標点の操舵ゲイン + ・k_far(デフォルト: 0.3): 遠い目標点の操舵ゲイン + ・max_steer_rate(デフォルト: 0.1): 1フレームあたりの最大操舵変化量 + ・max_throttle(デフォルト: 0.4): 直線での最大速度 + ・speed_k(デフォルト: 2.0): カーブ減速係数 + + 7-5. 実装ファイル + + ・src/pc/steering/pursuit_control.py: PursuitControl クラス + ・src/pc/gui/main_window.py: 制御手法の切替 UI diff --git "a/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" "b/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" index 3793ab1..4117323 100644 --- "a/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" +++ "b/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" @@ -15,27 +15,30 @@ 1-1. 基本方針 ・オーバーレイはカメラ映像に重ねて描画する. - ・手動操作中でもオーバーレイが有効なら線検出を実行する. + ・線検出は接続中は常に実行し,検出情報は映像下のラベルに表示する. ・自動操縦中は操舵量計算で実行済みの検出結果を再利用する. 2. 表示項目 (Overlay Items) ------------------------------------------------------------------------ - 2-1. 一覧 + 2-1. 画像オーバーレイ(チェックボックスで切替) ・二値化画像: 二値化結果を赤色の半透明で重ねる(不透明度 0.4) ・検出領域: 検出対象領域の枠を青色で表示 ・フィッティング曲線: 多項式の曲線を緑色で描画 ・中心線: 画像の中心 x に垂直線を描画(黄色) - ・検出情報: 位置偏差・傾き・曲率の数値を画像左上に表示 - 2-2. 描画色 (BGR) + 2-2. 検出情報ラベル(常時表示) + + 映像の下に配置されたラベルに,位置偏差・傾き・曲率の数値を + テキストで表示する.線が未検出の場合は「---」を表示する. + + 2-3. 描画色 (BGR) ・フィッティング曲線: (0, 255, 0) 緑 ・中心線: (0, 255, 255) 黄 ・検出領域: (255, 0, 0) 青 - ・テキスト: (255, 255, 255) 白 ・二値化オーバーレイ: 赤チャンネルに二値化画像を割り当て @@ -50,6 +53,6 @@ 3-2. 動作モードとの関係 - ・手動操作中: オーバーレイが 1 つでも ON なら線検出を実行する + ・手動操作中: 線検出を常に実行し,検出情報ラベルを更新する ・自動操縦中: 操舵量計算の線検出結果をそのまま使用する ・未接続時: オーバーレイは表示されない(映像がないため) diff --git "a/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" "b/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" index afea42c..f715361 100644 --- "a/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" +++ "b/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" @@ -591,6 +591,31 @@ ・備考: S 字カーブ等,2次多項式では表現できない 複雑な形状に対応できる.ただし過学習のリスクがある + 7-7. ロバストフィッティング前処理(全手法共通) + + 行ごと中心抽出の後,多項式フィッティングの前に適用する + 外れ値除去パイプライン.全検出手法(案A〜D)で共通に使用する + + (1) 移動メディアンフィルタ(median_ksize) + 中心点列の x 座標に1次元メディアンフィルタを適用し, + スパイク状の外れ値を平滑化する + (2) 近傍外れ値除去(neighbor_thresh) + 各点の x 座標を近傍(前後3行)の中央値と比較し, + 閾値を超えて逸脱する点を除去する. + グローバルなモデルに依存しないローカルな判定 + (3) 重み付き最小二乗 + 案Dでは谷の深度(コントラスト)を重みとして使用し, + 深度が浅い(=信頼度の低い)点の影響を低減する + + ・パラメータ: + - median_ksize: メディアンフィルタのカーネルサイズ + (デフォルト: 7,0 で無効) + - neighbor_thresh: 近傍除去の閾値 px + (デフォルト: 10.0,0 で無効) + ・実装: _clean_and_fit() 関数(line_detector.py) + ・備考: RANSAC と併用可能.RANSAC が有効な場合は + メディアン → 近傍除去 → RANSAC の順に適用される + 8. Stage 0: 撮影条件の最適化 (Camera Settings) ------------------------------------------------------------------------ @@ -785,3 +810,89 @@ ・評価環境: 通常照明,強照明,照明ムラありの3条件を推奨 ・計測項目: 周回数,コースアウト回数,position_error の分散 + + +11. 案D: 谷検出+追跡型 (Valley Detection + Tracking) +------------------------------------------------------------------------ + + 11-1. 概要 + + 案A〜Cはいずれも二値化(固定閾値または適応的閾値)を経由するが, + 光の反射や影によって二値化の閾値設定が困難になる場面がある. + 案Dは二値化を完全に排除し,各行の輝度信号から直接「谷」(暗い + 領域)を検出することで,照明変動に対する根本的な耐性を得る. + + さらに,時系列追跡(トラッキング)により検出の安定性を確保し, + 一時的な検出失敗にも対応できる構成とする. + + 11-2. アルゴリズム + + ■ 谷検出(行ごと,フレーム単位) + + (1) 画像全体を水平方向にガウシアンブラーする(カーネルサイズ: + valley_gauss_ksize) + (2) 各行の輝度信号から極小値(谷)を検出する + (3) 各谷について以下を計算する: + - 谷の中心 x 座標: 左右の肩の中点 + - 谷の深度: 左右の肩の平均輝度 - 谷底の輝度 + - 谷の幅: 左の肩から右の肩までのピクセル数 + (4) 以下のフィルタで不正な谷を除去する: + - 最小深度フィルタ(valley_min_depth) + - 透視補正付き幅フィルタ(width_near, width_far, width_tolerance) + - 予測偏差フィルタ(valley_max_deviation,後述の追跡と連動) + (5) 残った谷のうち最もスコアの高いものを採用する + スコア = 深度 + 予測位置への近さ + + ■ 時系列追跡(フレーム間) + + (1) 前フレームの多項式係数から各行の x 座標を予測する + (2) 予測から大きく外れた谷候補を棄却する(予測偏差フィルタ) + (3) 検出成功時: 多項式係数を指数移動平均(EMA)で平滑化する + smoothed = alpha × current + (1 - alpha) × prev + (4) 検出失敗時: 前フレームの予測で補間する(コースティング) + valley_coast_frames フレームまで継続し,超えると未検出を返す + + 11-3. 案A〜Cとの違い + + ・二値化が不要: 輝度の相対的な谷を検出するため,閾値設定の問題 + が根本的に解消される + ・時系列情報を活用: 案A〜Cは各フレームを独立に処理するが, + 案Dは前フレームの結果を予測・平滑化に活用する + ・検出失敗への耐性: コースティングにより数フレームの検出失敗を + 許容できる(案A〜Cは即座に未検出となる) + + 11-4. パラメータ一覧 + + ・valley_gauss_ksize (デフォルト: 15) + 行ごとのガウシアン平滑化カーネルサイズ. + 大きくするとノイズに強くなるが,細い線を検出しにくくなる + ・valley_min_depth (デフォルト: 15) + 谷として認識する最小深度(輝度差). + 小さくすると感度が上がるが誤検出も増える + ・valley_max_deviation (デフォルト: 40) + 追跡予測からの最大許容偏差(px). + 小さくすると追跡が安定するが急カーブへの追従が遅れる + ・valley_coast_frames (デフォルト: 3) + 検出失敗時に予測を継続するフレーム数. + 大きくすると一時的な見失いに強くなるが, + 誤った予測で走行し続けるリスクも増える + ・valley_ema_alpha (デフォルト: 0.7) + 多項式係数の EMA 係数.1.0 で平滑化なし, + 小さくすると安定するがラグが増える + ・ransac_thresh / ransac_iter(案Cと共通) + 有効にすると谷検出結果に RANSAC を適用する + ・width_near / width_far / width_tolerance(共通) + 透視補正付き幅フィルタ.谷の幅の検証に使用する + + 11-5. 実装ファイル + + ・src/pc/vision/line_detector.py + - ValleyTracker クラス: 時系列追跡の状態管理 + - _find_row_valley(): 1行の谷検出 + - _detect_valley(): 案Dのメイン処理 + - _build_valley_binary(): デバッグ用二値画像生成 + - reset_valley_tracker(): 追跡状態のリセット + ・src/pc/gui/main_window.py + - 案D用パラメータの GUI コントロール + ・src/pc/steering/pd_control.py + - reset() 時に追跡状態もリセット diff --git "a/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" "b/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" index 1aa7eaf..9f72519 100644 --- "a/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" +++ "b/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" @@ -39,10 +39,12 @@ 2-2. src/common/ common/ - └── config.py + ├── config.py + └── json_utils.py JSON 読み書き共通ユーティリティ - ・PC・Pi 間で共有する設定値を定義する. - ・内容: ネットワーク設定,画像フォーマット,通信設定等. + ・PC・Pi 間で共有する設定値・ユーティリティを定義する. + ・config.py: ネットワーク設定,画像フォーマット,通信設定等. + ・json_utils.py: JSON ファイル読み書きとパラメータディレクトリの定義. 2-3. src/pc/ @@ -55,10 +57,20 @@ ├── steering/ 操舵量計算(独立モジュール) │ ├── base.py 共通インターフェース │ ├── pd_control.py PD 制御の実装 - │ └── param_store.py パラメータ保存・読み込み + │ ├── pursuit_control.py 2点パシュート制御の実装 + │ ├── param_store.py プリセット保存・読み込み + │ └── auto_params.py パラメータ自動保存・復元 └── vision/ 画像処理 - ├── line_detector.py 線検出(多項式フィッティング) - └── overlay.py デバッグオーバーレイ描画 + ├── line_detector.py 線検出 API(データクラス・手法ディスパッチ) + ├── fitting.py 直線・曲線近似(Theil-Sen・RANSAC・外れ値除去) + ├── morphology.py 形態学的処理ユーティリティ + ├── overlay.py デバッグオーバーレイ描画 + └── detectors/ 検出手法の実装 + ├── current.py 現行(CLAHE + 固定閾値) + ├── blackhat.py 案A(Black-hat 中心) + ├── dual_norm.py 案B(二重正規化) + ├── robust.py 案C(最高ロバスト) + └── valley.py 案D(谷検出+追跡) 2-4. src/pi/ diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..442bb94 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +pythonpath = src +testpaths = tests diff --git a/requirements_pc.txt b/requirements_pc.txt index 506720a..89bfd92 100644 --- a/requirements_pc.txt +++ b/requirements_pc.txt @@ -3,3 +3,4 @@ pyzmq==27.1.0 numpy==2.4.3 python-dotenv==1.2.2 +pytest==9.0.2 diff --git a/src/common/config.py b/src/common/config.py index 8973b3b..4a9e106 100644 --- a/src/common/config.py +++ b/src/common/config.py @@ -31,11 +31,17 @@ # ── 画像設定 ────────────────────────────────────────────── -# カメラ画像の幅 (px) -FRAME_WIDTH: int = 320 +# カメラ撮影時の幅 (px) +CAPTURE_WIDTH: int = 320 -# カメラ画像の高さ (px) -FRAME_HEIGHT: int = 240 +# カメラ撮影時の高さ (px) +CAPTURE_HEIGHT: int = 240 + +# 処理・送信時の幅 (px)(撮影後に縮小) +FRAME_WIDTH: int = 40 + +# 処理・送信時の高さ (px)(撮影後に縮小) +FRAME_HEIGHT: int = 30 # JPEG 圧縮品質 (0-100) JPEG_QUALITY: int = 55 diff --git a/src/common/json_utils.py b/src/common/json_utils.py new file mode 100644 index 0000000..0ce705a --- /dev/null +++ b/src/common/json_utils.py @@ -0,0 +1,41 @@ +""" +json_utils +JSON ファイルの読み書きとパラメータディレクトリの共通定義 +""" + +import json +from pathlib import Path + +# プロジェクトルートの params/ ディレクトリ +PARAMS_DIR: Path = ( + Path(__file__).resolve().parent.parent.parent / "params" +) + + +def read_json(path: Path) -> dict: + """JSON ファイルを読み込む + + Args: + path: 読み込む JSON ファイルのパス + + Returns: + 読み込んだデータの辞書 + """ + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def write_json(path: Path, data: dict | list) -> None: + """JSON ファイルに書き込む + + 親ディレクトリが存在しない場合は自動作成する + + Args: + path: 書き込み先の JSON ファイルのパス + data: 書き込むデータ + """ + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump( + data, f, ensure_ascii=False, indent=2, + ) diff --git a/src/pc/gui/main_window.py b/src/pc/gui/main_window.py index 10101f4..08de403 100644 --- a/src/pc/gui/main_window.py +++ b/src/pc/gui/main_window.py @@ -9,49 +9,39 @@ from PySide6.QtCore import Qt, QTimer from PySide6.QtGui import QImage, QKeyEvent, QPixmap from PySide6.QtWidgets import ( - QCheckBox, - QComboBox, - QDoubleSpinBox, - QFormLayout, - QGroupBox, QHBoxLayout, - QInputDialog, QLabel, QMainWindow, - QMessageBox, QPushButton, QScrollArea, - QSpinBox, QVBoxLayout, QWidget, ) from common import config from pc.comm.zmq_client import PcZmqClient +from pc.gui.panels import ( + ControlParamPanel, + ImageParamPanel, + OverlayPanel, +) from pc.steering.auto_params import ( load_control, load_detect_params, save_control, - save_detect_params, ) +from pc.steering.base import SteeringBase from pc.steering.pd_control import PdControl, PdParams -from pc.steering.param_store import ( - ImagePreset, - PdPreset, - add_image_preset, - add_pd_preset, - delete_image_preset, - delete_pd_preset, - load_image_presets, - load_pd_presets, +from pc.steering.pursuit_control import ( + PursuitControl, + PursuitParams, ) from pc.vision.line_detector import ( - DETECT_METHODS, ImageParams, LineDetectResult, detect_line, ) -from pc.vision.overlay import OverlayFlags, draw_overlay +from pc.vision.overlay import draw_overlay # 映像更新間隔 (ms) FRAME_INTERVAL_MS: int = 33 @@ -61,8 +51,8 @@ 1000 / config.CONTROL_PUBLISH_HZ ) -# 映像表示のスケール倍率 -DISPLAY_SCALE: float = 2.0 +# 映像表示のスケール倍率(40x30 → 640x480 相当) +DISPLAY_SCALE: float = 16.0 # 手動操作の throttle / steer 量 MANUAL_THROTTLE: float = 0.5 @@ -83,9 +73,6 @@ self._throttle: float = 0.0 self._steer: float = 0.0 - # 自動保存の制御フラグ - self._auto_save_enabled = False - # 前回のパラメータを復元 pd_params, last_method = load_control() image_params = load_detect_params(last_method) @@ -93,28 +80,35 @@ params=pd_params, image_params=image_params, ) + self._pursuit_control = PursuitControl( + image_params=image_params, + ) + + # 現在の制御手法("pd" or "pursuit") + self._steering_method: str = "pd" # 最新フレームの保持(自動操縦で使用) self._latest_frame: np.ndarray | None = None - # オーバーレイ - self._overlay_flags = OverlayFlags() - self._last_detect_result: LineDetectResult | None = None + # 検出結果の保持 + self._last_detect_result: LineDetectResult | None = ( + None + ) self._setup_ui() self._setup_timers() - self._auto_save_enabled = True def _setup_ui(self) -> None: """UI を構築する""" self.setWindowTitle("RobotCar Controller") - # 中央ウィジェット central = QWidget() self.setCentralWidget(central) root_layout = QHBoxLayout(central) - # 左側: 映像表示 + # 左側: 映像表示 + 検出情報 + left_layout = QVBoxLayout() + self._video_label = QLabel("カメラ映像待機中...") self._video_label.setAlignment( Qt.AlignmentFlag.AlignCenter, @@ -127,7 +121,21 @@ "background-color: #222;" " color: #aaa; font-size: 16px;" ) - root_layout.addWidget(self._video_label, stretch=3) + left_layout.addWidget(self._video_label) + + self._detect_info_label = QLabel( + "pos: --- head: --- curv: ---" + ) + self._detect_info_label.setAlignment( + Qt.AlignmentFlag.AlignLeft, + ) + self._detect_info_label.setStyleSheet( + "font-size: 14px; font-family: monospace;" + " color: #0f0; background-color: #222;" + " padding: 4px;" + ) + left_layout.addWidget(self._detect_info_label) + root_layout.addLayout(left_layout, stretch=3) # 右側: スクロール可能なコントロールパネル scroll = QScrollArea() @@ -150,9 +158,7 @@ # 自動操縦ボタン self._auto_btn = QPushButton("自動操縦 ON") self._auto_btn.setEnabled(False) - self._auto_btn.clicked.connect( - self._toggle_auto, - ) + self._auto_btn.clicked.connect(self._toggle_auto) control_layout.addWidget(self._auto_btn) # ステータス表示 @@ -169,17 +175,42 @@ self._control_label.setAlignment( Qt.AlignmentFlag.AlignCenter, ) - self._control_label.setStyleSheet("font-size: 14px;") + self._control_label.setStyleSheet( + "font-size: 14px;", + ) control_layout.addWidget(self._control_label) - # 画像処理パラメータ調整(Stage 1〜4) - self._setup_image_param_ui(control_layout) + # 画像処理パラメータパネル + self._image_panel = ImageParamPanel( + self._pd_control.image_params, + ) + self._image_panel.image_params_changed.connect( + self._on_image_params_changed, + ) + self._image_panel.method_changed.connect( + self._on_method_changed, + ) + control_layout.addWidget(self._image_panel) - # PD 制御パラメータ(操舵量計算) - self._setup_param_ui(control_layout) + # 制御パラメータパネル + self._control_panel = ControlParamPanel( + self._pd_control.params, + self._pursuit_control.params, + ) + self._control_panel.pd_params_changed.connect( + self._on_pd_params_changed, + ) + self._control_panel.pursuit_params_changed.connect( + self._on_pursuit_params_changed, + ) + self._control_panel.steering_method_changed.connect( + self._on_steering_method_changed, + ) + control_layout.addWidget(self._control_panel) - # デバッグ表示 - self._setup_overlay_ui(control_layout) + # デバッグ表示パネル + self._overlay_panel = OverlayPanel() + control_layout.addWidget(self._overlay_panel) # 操作ガイド guide = QLabel( @@ -193,728 +224,57 @@ guide.setStyleSheet("font-size: 12px; color: #666;") control_layout.addWidget(guide) - # 余白を下に詰める control_layout.addStretch() - def _setup_param_ui( - self, parent_layout: QVBoxLayout, - ) -> None: - """PD パラメータ調整 UI を構築する""" - group = QGroupBox("PD 制御パラメータ") - layout = QVBoxLayout() - group.setLayout(layout) - - form = QFormLayout() - layout.addLayout(form) - - params = self._pd_control.params - - self._spin_kp = self._create_spin( - params.kp, 0.0, 5.0, 0.05, - ) - form.addRow("Kp (位置):", self._spin_kp) - - self._spin_kh = self._create_spin( - params.kh, 0.0, 5.0, 0.05, - ) - form.addRow("Kh (傾き):", self._spin_kh) - - self._spin_kd = self._create_spin( - params.kd, 0.0, 5.0, 0.05, - ) - form.addRow("Kd (微分):", self._spin_kd) - - self._spin_max_steer_rate = self._create_spin( - params.max_steer_rate, 0.01, 1.0, 0.01, - ) - form.addRow("操舵制限:", self._spin_max_steer_rate) - - self._spin_max_throttle = self._create_spin( - params.max_throttle, 0.0, 1.0, 0.05, - ) - form.addRow("最大速度:", self._spin_max_throttle) - - self._spin_speed_k = self._create_spin( - params.speed_k, 0.0, 5.0, 0.05, - ) - form.addRow("減速係数:", self._spin_speed_k) - - # --- プリセット管理 --- - self._pd_preset_combo = QComboBox() - self._pd_preset_combo.setPlaceholderText( - "プリセット", - ) - layout.addWidget(self._pd_preset_combo) - - self._pd_preset_memo = QLabel("") - self._pd_preset_memo.setWordWrap(True) - self._pd_preset_memo.setStyleSheet( - "font-size: 11px; color: #888;", - ) - layout.addWidget(self._pd_preset_memo) - - btn_layout = QHBoxLayout() - load_btn = QPushButton("読込") - load_btn.clicked.connect( - self._on_load_pd_preset, - ) - btn_layout.addWidget(load_btn) - save_btn = QPushButton("保存") - save_btn.clicked.connect( - self._on_save_pd_preset, - ) - btn_layout.addWidget(save_btn) - del_btn = QPushButton("削除") - del_btn.clicked.connect( - self._on_delete_pd_preset, - ) - btn_layout.addWidget(del_btn) - layout.addLayout(btn_layout) - - # コールバック接続 - for spin in [ - self._spin_kp, self._spin_kh, - self._spin_kd, self._spin_max_steer_rate, - self._spin_max_throttle, self._spin_speed_k, - ]: - spin.valueChanged.connect( - self._on_param_changed, - ) - self._pd_preset_combo.currentIndexChanged \ - .connect(self._on_pd_preset_selected) - - self._pd_presets: list[PdPreset] = [] - self._refresh_pd_presets() - - parent_layout.addWidget(group) - - def _setup_image_param_ui( - self, parent_layout: QVBoxLayout, - ) -> None: - """画像処理パラメータ調整 UI を構築する""" - group = QGroupBox("画像処理パラメータ") - layout = QVBoxLayout() - group.setLayout(layout) - - ip = self._pd_control.image_params - - # 検出手法の選択コンボボックス - self._method_combo = QComboBox() - for key, label in DETECT_METHODS.items(): - self._method_combo.addItem(label, key) - idx = self._method_combo.findData(ip.method) - if idx >= 0: - self._method_combo.setCurrentIndex(idx) - layout.addWidget(self._method_combo) - - # パラメータフォーム - self._image_form = QFormLayout() - layout.addLayout(self._image_form) - - # 各パラメータの可視性マッピング - # (widget, 表示する手法の集合) - self._image_param_vis: list[ - tuple[QWidget, set[str]] - ] = [] - - # --- 現行手法パラメータ --- - self._spin_clahe_clip = self._create_spin( - ip.clahe_clip, 0.5, 10.0, 0.5, - ) - self._add_image_row( - "CLAHE強度:", self._spin_clahe_clip, - {"current"}, - ) - - self._spin_binary_thresh = QSpinBox() - self._spin_binary_thresh.setRange(10, 200) - self._spin_binary_thresh.setValue( - ip.binary_thresh, - ) - self._add_image_row( - "二値化閾値:", self._spin_binary_thresh, - {"current", "blackhat"}, - ) - - self._spin_open_size = QSpinBox() - self._spin_open_size.setRange(1, 31) - self._spin_open_size.setSingleStep(2) - self._spin_open_size.setValue(ip.open_size) - self._add_image_row( - "ノイズ除去:", self._spin_open_size, - {"current"}, - ) - - self._spin_close_width = QSpinBox() - self._spin_close_width.setRange(1, 51) - self._spin_close_width.setSingleStep(2) - self._spin_close_width.setValue(ip.close_width) - self._add_image_row( - "途切れ補間:", self._spin_close_width, - {"current"}, - ) - - # --- 案A/C: Black-hat --- - self._spin_blackhat_ksize = QSpinBox() - self._spin_blackhat_ksize.setRange(11, 101) - self._spin_blackhat_ksize.setSingleStep(2) - self._spin_blackhat_ksize.setValue( - ip.blackhat_ksize, - ) - self._add_image_row( - "BHカーネル:", self._spin_blackhat_ksize, - {"blackhat", "robust"}, - ) - - # --- 案B: 背景除算 --- - self._spin_bg_blur_ksize = QSpinBox() - self._spin_bg_blur_ksize.setRange(31, 201) - self._spin_bg_blur_ksize.setSingleStep(2) - self._spin_bg_blur_ksize.setValue( - ip.bg_blur_ksize, - ) - self._add_image_row( - "背景ブラー:", self._spin_bg_blur_ksize, - {"dual_norm"}, - ) - - # --- 案B/C: 適応的閾値 --- - self._spin_adaptive_block = QSpinBox() - self._spin_adaptive_block.setRange(11, 101) - self._spin_adaptive_block.setSingleStep(2) - self._spin_adaptive_block.setValue( - ip.adaptive_block, - ) - self._add_image_row( - "適応ブロック:", self._spin_adaptive_block, - {"dual_norm", "robust"}, - ) - - self._spin_adaptive_c = QSpinBox() - self._spin_adaptive_c.setRange(1, 30) - self._spin_adaptive_c.setValue(ip.adaptive_c) - self._add_image_row( - "適応定数C:", self._spin_adaptive_c, - {"dual_norm", "robust"}, - ) - - # --- 案A/B/C: 後処理 --- - self._spin_iso_close = QSpinBox() - self._spin_iso_close.setRange(1, 51) - self._spin_iso_close.setSingleStep(2) - self._spin_iso_close.setValue( - ip.iso_close_size, - ) - self._add_image_row( - "穴埋め:", self._spin_iso_close, - {"blackhat", "dual_norm", "robust"}, - ) - - self._spin_dist_thresh = self._create_spin( - ip.dist_thresh, 0.0, 10.0, 0.5, - ) - self._add_image_row( - "距離閾値:", self._spin_dist_thresh, - {"blackhat", "dual_norm", "robust"}, - ) - - self._spin_min_line_width = QSpinBox() - self._spin_min_line_width.setRange(1, 20) - self._spin_min_line_width.setValue( - ip.min_line_width, - ) - self._add_image_row( - "最小線幅:", self._spin_min_line_width, - {"blackhat", "dual_norm", "robust"}, - ) - - # --- 案C: RANSAC --- - self._spin_ransac_thresh = self._create_spin( - ip.ransac_thresh, 1.0, 30.0, 1.0, - ) - self._add_image_row( - "RANSAC閾値:", self._spin_ransac_thresh, - {"robust"}, - ) - - # --- 幅フィルタ(透視補正) --- - self._spin_width_near = QSpinBox() - self._spin_width_near.setRange(0, 200) - self._spin_width_near.setValue(ip.width_near) - self._spin_width_near.setSpecialValueText("無効") - self._add_image_row( - "線幅(近)px:", self._spin_width_near, - {"blackhat", "dual_norm", "robust"}, - ) - - self._spin_width_far = QSpinBox() - self._spin_width_far.setRange(0, 200) - self._spin_width_far.setValue(ip.width_far) - self._spin_width_far.setSpecialValueText("無効") - self._add_image_row( - "線幅(遠)px:", self._spin_width_far, - {"blackhat", "dual_norm", "robust"}, - ) - - self._spin_width_tolerance = self._create_spin( - ip.width_tolerance, 1.0, 5.0, 0.1, - ) - self._add_image_row( - "幅フィルタ倍率:", self._spin_width_tolerance, - {"blackhat", "dual_norm", "robust"}, - ) - - # --- プリセット管理 --- - self._image_preset_combo = QComboBox() - self._image_preset_combo.setPlaceholderText( - "プリセット", - ) - layout.addWidget(self._image_preset_combo) - - self._image_preset_memo = QLabel("") - self._image_preset_memo.setWordWrap(True) - self._image_preset_memo.setStyleSheet( - "font-size: 11px; color: #888;", - ) - layout.addWidget(self._image_preset_memo) - - btn_layout = QHBoxLayout() - load_btn = QPushButton("読込") - load_btn.clicked.connect( - self._on_load_image_preset, - ) - btn_layout.addWidget(load_btn) - save_btn = QPushButton("保存") - save_btn.clicked.connect( - self._on_save_image_preset, - ) - btn_layout.addWidget(save_btn) - del_btn = QPushButton("削除") - del_btn.clicked.connect( - self._on_delete_image_preset, - ) - btn_layout.addWidget(del_btn) - layout.addLayout(btn_layout) - - # コールバック接続 - self._method_combo.currentIndexChanged.connect( - self._on_method_changed, - ) - for widget, _ in self._image_param_vis: - widget.valueChanged.connect( - self._on_image_param_changed, - ) - self._image_preset_combo.currentIndexChanged \ - .connect(self._on_image_preset_selected) - - self._image_presets: list[ImagePreset] = [] - self._image_filtered: list[int] = [] - - parent_layout.addWidget(group) - - # 初期表示の更新 - self._on_method_changed() - - def _add_image_row( - self, - label: str, - widget: QWidget, - methods: set[str], - ) -> None: - """画像処理パラメータの行を追加する""" - self._image_form.addRow(label, widget) - self._image_param_vis.append( - (widget, methods), - ) - - def _on_method_changed(self) -> None: - """検出手法の変更を反映する""" - method = self._method_combo.currentData() - - # 旧手法のパラメータを保存 - if self._auto_save_enabled: - ip = self._pd_control.image_params - save_detect_params(ip.method, ip) - - # 新手法のパラメータを読み込み - if self._auto_save_enabled: - new_ip = load_detect_params(method) - self._pd_control.image_params = new_ip - self._sync_image_spinboxes() - save_control( - self._pd_control.params, method, - ) - else: - self._pd_control.image_params.method = ( - method - ) - - # パラメータの表示/非表示を更新 - for widget, methods in self._image_param_vis: - visible = method in methods - widget.setVisible(visible) - label = self._image_form.labelForField( - widget, - ) - if label: - label.setVisible(visible) - - # 保存済みプリセットを手法でフィルタ - if hasattr(self, "_image_preset_combo"): - self._refresh_image_presets() - - def _sync_image_spinboxes(self) -> None: - """画像処理パラメータの SpinBox を現在値に同期する""" - self._auto_save_enabled = False - try: - ip = self._pd_control.image_params - self._spin_clahe_clip.setValue( - ip.clahe_clip, - ) - self._spin_binary_thresh.setValue( - ip.binary_thresh, - ) - self._spin_open_size.setValue(ip.open_size) - self._spin_close_width.setValue( - ip.close_width, - ) - self._spin_blackhat_ksize.setValue( - ip.blackhat_ksize, - ) - self._spin_bg_blur_ksize.setValue( - ip.bg_blur_ksize, - ) - self._spin_adaptive_block.setValue( - ip.adaptive_block, - ) - self._spin_adaptive_c.setValue( - ip.adaptive_c, - ) - self._spin_iso_close.setValue( - ip.iso_close_size, - ) - self._spin_dist_thresh.setValue( - ip.dist_thresh, - ) - self._spin_min_line_width.setValue( - ip.min_line_width, - ) - self._spin_ransac_thresh.setValue( - ip.ransac_thresh, - ) - self._spin_width_near.setValue(ip.width_near) - self._spin_width_far.setValue(ip.width_far) - self._spin_width_tolerance.setValue( - ip.width_tolerance, - ) - finally: - self._auto_save_enabled = True - - def _on_image_param_changed(self) -> None: - """画像処理パラメータの変更を反映する""" - ip = self._pd_control.image_params - # 現行手法 - ip.clahe_clip = self._spin_clahe_clip.value() - ip.binary_thresh = ( - self._spin_binary_thresh.value() - ) - ip.open_size = self._spin_open_size.value() - ip.close_width = self._spin_close_width.value() - # 案A/C: Black-hat - ip.blackhat_ksize = ( - self._spin_blackhat_ksize.value() - ) - # 案B: 背景除算 - ip.bg_blur_ksize = ( - self._spin_bg_blur_ksize.value() - ) - # 案B/C: 適応的閾値 - ip.adaptive_block = ( - self._spin_adaptive_block.value() - ) - ip.adaptive_c = self._spin_adaptive_c.value() - # 案A/B/C: 後処理 - ip.iso_close_size = ( - self._spin_iso_close.value() - ) - ip.dist_thresh = ( - self._spin_dist_thresh.value() - ) - ip.min_line_width = ( - self._spin_min_line_width.value() - ) - # 案C: RANSAC - ip.ransac_thresh = ( - self._spin_ransac_thresh.value() - ) - # 幅フィルタ(透視補正) - ip.width_near = self._spin_width_near.value() - ip.width_far = self._spin_width_far.value() - ip.width_tolerance = ( - self._spin_width_tolerance.value() - ) - - if self._auto_save_enabled: - save_detect_params(ip.method, ip) - - # ── 画像処理プリセット ────────────────────────── - - def _refresh_image_presets(self) -> None: - """選択中の手法のプリセットだけ表示する""" - self._image_presets = load_image_presets() - method = self._method_combo.currentData() - self._image_preset_combo.clear() - self._image_filtered = [] - for i, p in enumerate(self._image_presets): - if p.image_params.method == method: - self._image_preset_combo.addItem( - p.title, - ) - self._image_filtered.append(i) - self._image_preset_memo.setText("") - - def _on_image_preset_selected(self) -> None: - """画像処理プリセット選択時にメモを表示""" - idx = self._get_image_preset_idx() - if idx >= 0: - self._image_preset_memo.setText( - self._image_presets[idx].memo, - ) - else: - self._image_preset_memo.setText("") - - def _get_image_preset_idx(self) -> int: - """コンボの選択を全体インデックスに変換""" - ci = self._image_preset_combo.currentIndex() - if ci < 0 or ci >= len(self._image_filtered): - return -1 - return self._image_filtered[ci] - - def _on_load_image_preset(self) -> None: - """画像処理プリセットを読み込む""" - idx = self._get_image_preset_idx() - if idx < 0: - return - self._auto_save_enabled = False - try: - ip = self._image_presets[idx].image_params - self._pd_control.image_params = ip - self._sync_image_spinboxes() - finally: - self._auto_save_enabled = True - save_detect_params(ip.method, ip) - - def _on_save_image_preset(self) -> None: - """画像処理プリセットを保存する""" - title, ok = QInputDialog.getText( - self, "画像処理プリセット保存", "タイトル:", - ) - if not ok or not title.strip(): - return - memo, ok = QInputDialog.getText( - self, "画像処理プリセット保存", "メモ:", - ) - if not ok: - return - ip = self._pd_control.image_params - add_image_preset(ImagePreset( - title=title.strip(), - memo=memo.strip(), - image_params=ImageParams(**{ - f.name: getattr(ip, f.name) - for f in ip.__dataclass_fields__.values() - }), - )) - self._refresh_image_presets() - self._image_preset_combo.setCurrentIndex( - self._image_preset_combo.count() - 1, - ) - - def _on_delete_image_preset(self) -> None: - """画像処理プリセットを削除する""" - idx = self._get_image_preset_idx() - if idx < 0: - return - title = self._image_presets[idx].title - reply = QMessageBox.question( - self, "削除確認", - f"「{title}」を削除しますか?", - QMessageBox.StandardButton.Yes - | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - if reply == QMessageBox.StandardButton.Yes: - delete_image_preset(idx) - self._refresh_image_presets() - - # ── PD 制御プリセット ───────────────────────── - - def _refresh_pd_presets(self) -> None: - """PD 制御プリセット一覧を更新する""" - self._pd_presets = load_pd_presets() - self._pd_preset_combo.clear() - for p in self._pd_presets: - self._pd_preset_combo.addItem(p.title) - self._pd_preset_memo.setText("") - - def _on_pd_preset_selected(self) -> None: - """PD 制御プリセット選択時にメモを表示""" - idx = self._pd_preset_combo.currentIndex() - if 0 <= idx < len(self._pd_presets): - self._pd_preset_memo.setText( - self._pd_presets[idx].memo, - ) - else: - self._pd_preset_memo.setText("") - - def _on_load_pd_preset(self) -> None: - """PD 制御プリセットを読み込む""" - idx = self._pd_preset_combo.currentIndex() - if idx < 0 or idx >= len(self._pd_presets): - return - self._auto_save_enabled = False - try: - p = self._pd_presets[idx].params - self._spin_kp.setValue(p.kp) - self._spin_kh.setValue(p.kh) - self._spin_kd.setValue(p.kd) - self._spin_max_steer_rate.setValue( - p.max_steer_rate, - ) - self._spin_max_throttle.setValue( - p.max_throttle, - ) - self._spin_speed_k.setValue(p.speed_k) - self._pd_control.params = p - finally: - self._auto_save_enabled = True - save_control( - p, - self._pd_control.image_params.method, - ) - - def _on_save_pd_preset(self) -> None: - """PD 制御プリセットを保存する""" - title, ok = QInputDialog.getText( - self, "PD プリセット保存", "タイトル:", - ) - if not ok or not title.strip(): - return - memo, ok = QInputDialog.getText( - self, "PD プリセット保存", "メモ:", - ) - if not ok: - return - p = self._pd_control.params - add_pd_preset(PdPreset( - title=title.strip(), - memo=memo.strip(), - params=PdParams( - kp=p.kp, kh=p.kh, kd=p.kd, - max_steer_rate=p.max_steer_rate, - max_throttle=p.max_throttle, - speed_k=p.speed_k, - ), - )) - self._refresh_pd_presets() - self._pd_preset_combo.setCurrentIndex( - self._pd_preset_combo.count() - 1, - ) - - def _on_delete_pd_preset(self) -> None: - """PD 制御プリセットを削除する""" - idx = self._pd_preset_combo.currentIndex() - if idx < 0 or idx >= len(self._pd_presets): - return - title = self._pd_presets[idx].title - reply = QMessageBox.question( - self, "削除確認", - f"「{title}」を削除しますか?", - QMessageBox.StandardButton.Yes - | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - if reply == QMessageBox.StandardButton.Yes: - delete_pd_preset(idx) - self._refresh_pd_presets() - - def _setup_overlay_ui( - self, parent_layout: QVBoxLayout, - ) -> None: - """デバッグ表示のチェックボックス UI を構築する""" - group = QGroupBox("デバッグ表示") - layout = QVBoxLayout() - group.setLayout(layout) - - items = [ - ("二値化画像", "binary"), - ("検出領域", "detect_region"), - ("フィッティング曲線", "poly_curve"), - ("中心線", "center_line"), - ("検出情報", "info_text"), - ] - for label, attr in items: - cb = QCheckBox(label) - cb.toggled.connect( - lambda checked, a=attr: - setattr(self._overlay_flags, a, checked), - ) - layout.addWidget(cb) - - parent_layout.addWidget(group) - - def _has_any_overlay(self) -> bool: - """いずれかのオーバーレイが有効かを返す""" - f = self._overlay_flags - return ( - f.binary or f.detect_region - or f.poly_curve or f.center_line - or f.info_text - ) - - @staticmethod - def _create_spin( - value: float, min_val: float, - max_val: float, step: float, - ) -> QDoubleSpinBox: - """パラメータ用の SpinBox を作成する""" - spin = QDoubleSpinBox() - spin.setRange(min_val, max_val) - spin.setSingleStep(step) - spin.setDecimals(3) - spin.setValue(value) - return spin - - def _on_param_changed(self) -> None: - """パラメータ SpinBox の値が変更されたときに反映する""" - p = self._pd_control.params - p.kp = self._spin_kp.value() - p.kh = self._spin_kh.value() - p.kd = self._spin_kd.value() - p.max_steer_rate = ( - self._spin_max_steer_rate.value() - ) - p.max_throttle = self._spin_max_throttle.value() - p.speed_k = self._spin_speed_k.value() - - if self._auto_save_enabled: - save_control( - p, - self._pd_control.image_params.method, - ) + @property + def _active_control(self) -> SteeringBase: + """現在選択中の制御クラスを返す""" + if self._steering_method == "pursuit": + return self._pursuit_control + return self._pd_control def _setup_timers(self) -> None: """タイマーを設定する""" - # 映像更新用 self._frame_timer = QTimer(self) self._frame_timer.timeout.connect(self._update_frame) - # 操舵量送信用 self._control_timer = QTimer(self) self._control_timer.timeout.connect( self._send_control, ) + # ── パネルシグナルのスロット ─────────────────────────── + + def _on_image_params_changed( + self, ip: ImageParams, + ) -> None: + """画像処理パラメータの変更を両制御クラスに反映する""" + self._pd_control.image_params = ip + self._pursuit_control.image_params = ip + + def _on_method_changed(self, method: str) -> None: + """検出手法の変更に合わせて制御設定を保存する""" + save_control(self._pd_control.params, method) + + def _on_pd_params_changed(self, p: PdParams) -> None: + """PD パラメータの変更を制御クラスに反映して保存する""" + self._pd_control.params = p + save_control( + p, self._pd_control.image_params.method, + ) + + def _on_pursuit_params_changed( + self, p: PursuitParams, + ) -> None: + """Pursuit パラメータの変更を制御クラスに反映する""" + self._pursuit_control.params = p + + def _on_steering_method_changed( + self, method: str, + ) -> None: + """制御手法の切替を反映する""" + self._steering_method = method + # ── 接続 ────────────────────────────────────────────── def _toggle_connection(self) -> None: @@ -967,7 +327,7 @@ def _enable_auto(self) -> None: """自動操縦を開始する""" self._is_auto = True - self._pd_control.reset() + self._active_control.reset() self._pressed_keys.clear() self._auto_btn.setText("自動操縦 OFF") self._status_label.setText("接続中 (自動操縦)") @@ -990,26 +350,38 @@ return self._latest_frame = frame - # 自動操縦時は操舵量を計算 + # 線検出は常に実行(検出情報ラベル表示のため) if self._is_auto: - output = self._pd_control.compute(frame) + ctrl = self._active_control + output = ctrl.compute(frame) self._throttle = output.throttle self._steer = output.steer self._update_control_label() self._last_detect_result = ( - self._pd_control.last_detect_result + ctrl.last_detect_result ) - elif self._has_any_overlay(): - # 手動操作中でもオーバーレイ用に線検出を実行 + else: self._last_detect_result = detect_line( frame, self._pd_control.image_params, ) - else: - self._last_detect_result = None self._display_frame(frame) + def _update_detect_info_label(self) -> None: + """検出情報ラベルを更新する""" + r = self._last_detect_result + if r is None or not r.detected: + self._detect_info_label.setText( + "pos: --- head: --- curv: ---" + ) + return + self._detect_info_label.setText( + f"pos: {r.position_error:+.3f}" + f" head: {r.heading:+.4f}" + f" curv: {r.curvature:+.6f}" + ) + def _display_frame(self, frame: np.ndarray) -> None: """NumPy 配列の画像を QLabel に表示する @@ -1022,9 +394,12 @@ # オーバーレイ描画 bgr = draw_overlay( bgr, self._last_detect_result, - self._overlay_flags, + self._overlay_panel.get_flags(), ) + # 検出情報をラベルに表示 + self._update_detect_info_label() + # BGR → RGB 変換 rgb = bgr[:, :, ::-1].copy() h, w, ch = rgb.shape @@ -1032,12 +407,11 @@ rgb.data, w, h, ch * w, QImage.Format.Format_RGB888, ) - # スケーリングして表示 - scaled_w = int(w * DISPLAY_SCALE) - scaled_h = int(h * DISPLAY_SCALE) + disp_w = int(config.FRAME_WIDTH * DISPLAY_SCALE) + disp_h = int(config.FRAME_HEIGHT * DISPLAY_SCALE) pixmap = QPixmap.fromImage(image).scaled( - scaled_w, - scaled_h, + disp_w, + disp_h, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation, ) @@ -1079,7 +453,6 @@ self._throttle = 0.0 self._steer = 0.0 self._pressed_keys.clear() - # 自動操縦中なら停止 if self._is_auto: self._disable_auto() self._update_control_label() @@ -1087,8 +460,7 @@ # throttle: W/↑ で前進,S/↓ で後退 forward = ( - Qt.Key.Key_W in keys - or Qt.Key.Key_Up in keys + Qt.Key.Key_W in keys or Qt.Key.Key_Up in keys ) backward = ( Qt.Key.Key_S in keys diff --git a/src/pc/gui/panels/__init__.py b/src/pc/gui/panels/__init__.py new file mode 100644 index 0000000..7988084 --- /dev/null +++ b/src/pc/gui/panels/__init__.py @@ -0,0 +1,14 @@ +""" +panels +GUI パネルウィジェット群 +""" + +from pc.gui.panels.control_param_panel import ControlParamPanel +from pc.gui.panels.image_param_panel import ImageParamPanel +from pc.gui.panels.overlay_panel import OverlayPanel + +__all__ = [ + "ControlParamPanel", + "ImageParamPanel", + "OverlayPanel", +] diff --git a/src/pc/gui/panels/control_param_panel.py b/src/pc/gui/panels/control_param_panel.py new file mode 100644 index 0000000..864288d --- /dev/null +++ b/src/pc/gui/panels/control_param_panel.py @@ -0,0 +1,382 @@ +""" +control_param_panel +PD / 2点パシュート制御パラメータ調整 UI パネル +""" + +from PySide6.QtCore import Signal +from PySide6.QtWidgets import ( + QComboBox, + QDoubleSpinBox, + QFormLayout, + QGroupBox, + QHBoxLayout, + QInputDialog, + QLabel, + QMessageBox, + QPushButton, + QVBoxLayout, + QWidget, +) + +from pc.gui.panels.image_param_panel import _create_preset_ui +from pc.steering.param_store import ( + PdPreset, + add_pd_preset, + delete_pd_preset, + load_pd_presets, +) +from pc.steering.pd_control import PdParams +from pc.steering.pursuit_control import PursuitParams + + +class ControlParamPanel(QGroupBox): + """PD / 2点パシュート制御パラメータ調整 UI""" + + # PD パラメータが変更されたときに emit する + pd_params_changed = Signal(object) + # Pursuit パラメータが変更されたときに emit する + pursuit_params_changed = Signal(object) + # 制御手法が変更されたときに emit する("pd" or "pursuit") + steering_method_changed = Signal(str) + + def __init__( + self, + pd_params: PdParams, + pursuit_params: PursuitParams, + ) -> None: + super().__init__("制御パラメータ") + self._pd_params = pd_params + self._pursuit_params = pursuit_params + self._auto_save_enabled = False + self._pd_presets: list[PdPreset] = [] + + self._setup_ui() + self._auto_save_enabled = True + + def get_pd_params(self) -> PdParams: + """現在の PD パラメータを返す""" + return self._pd_params + + def get_pursuit_params(self) -> PursuitParams: + """現在の Pursuit パラメータを返す""" + return self._pursuit_params + + def _setup_ui(self) -> None: + """UI を構築する""" + layout = QVBoxLayout() + self.setLayout(layout) + + # 制御手法の選択 + self._steering_combo = QComboBox() + self._steering_combo.addItem("PD 制御", "pd") + self._steering_combo.addItem( + "2点パシュート", "pursuit", + ) + layout.addWidget(self._steering_combo) + + # --- PD パラメータ --- + self._pd_param_form = QFormLayout() + layout.addLayout(self._pd_param_form) + + p = self._pd_params + + self._spin_kp = _create_spin(p.kp, 0.0, 0.05) + self._pd_param_form.addRow("Kp (位置):", self._spin_kp) + + self._spin_kh = _create_spin(p.kh, 0.0, 0.05) + self._pd_param_form.addRow("Kh (傾き):", self._spin_kh) + + self._spin_kd = _create_spin(p.kd, 0.0, 0.05) + self._pd_param_form.addRow("Kd (微分):", self._spin_kd) + + self._spin_max_steer_rate = _create_spin( + p.max_steer_rate, 0.01, 0.01, + ) + self._pd_param_form.addRow( + "操舵制限:", self._spin_max_steer_rate, + ) + + self._spin_max_throttle = _create_spin( + p.max_throttle, 0.0, 0.05, + ) + self._pd_param_form.addRow( + "最大速度:", self._spin_max_throttle, + ) + + self._spin_speed_k = _create_spin( + p.speed_k, 0.0, 0.05, + ) + self._pd_param_form.addRow( + "減速係数:", self._spin_speed_k, + ) + + # PD 固有ウィジェットリスト(表示切替用) + self._pd_widgets: list[QWidget] = [ + self._spin_kp, + self._spin_kh, + self._spin_kd, + ] + + # --- Pursuit パラメータ --- + self._pursuit_param_form = QFormLayout() + layout.addLayout(self._pursuit_param_form) + + pp = self._pursuit_params + + self._spin_near_ratio = _create_spin( + pp.near_ratio, 0.0, 0.05, + ) + self._pursuit_param_form.addRow( + "近目標(比率):", self._spin_near_ratio, + ) + + self._spin_far_ratio = _create_spin( + pp.far_ratio, 0.0, 0.05, + ) + self._pursuit_param_form.addRow( + "遠目標(比率):", self._spin_far_ratio, + ) + + self._spin_k_near = _create_spin( + pp.k_near, 0.0, 0.05, + ) + self._pursuit_param_form.addRow( + "K_near:", self._spin_k_near, + ) + + self._spin_k_far = _create_spin( + pp.k_far, 0.0, 0.05, + ) + self._pursuit_param_form.addRow( + "K_far:", self._spin_k_far, + ) + + self._spin_pursuit_steer_rate = _create_spin( + pp.max_steer_rate, 0.01, 0.01, + ) + self._pursuit_param_form.addRow( + "操舵制限:", self._spin_pursuit_steer_rate, + ) + + self._spin_pursuit_throttle = _create_spin( + pp.max_throttle, 0.0, 0.05, + ) + self._pursuit_param_form.addRow( + "最大速度:", self._spin_pursuit_throttle, + ) + + self._spin_pursuit_speed_k = _create_spin( + pp.speed_k, 0.0, 0.1, + ) + self._pursuit_param_form.addRow( + "減速係数:", self._spin_pursuit_speed_k, + ) + + # Pursuit 固有ウィジェットリスト(表示切替用) + self._pursuit_widgets: list[QWidget] = [ + self._spin_near_ratio, + self._spin_far_ratio, + self._spin_k_near, + self._spin_k_far, + self._spin_pursuit_steer_rate, + self._spin_pursuit_throttle, + self._spin_pursuit_speed_k, + ] + + # --- プリセット管理 --- + self._pd_preset_combo, self._pd_preset_memo = ( + _create_preset_ui( + layout, + self._on_load_pd_preset, + self._on_save_pd_preset, + self._on_delete_pd_preset, + self._on_pd_preset_selected, + ) + ) + + # コールバック接続 + self._steering_combo.currentIndexChanged.connect( + self._on_steering_method_changed, + ) + for spin in [ + self._spin_kp, self._spin_kh, + self._spin_kd, self._spin_max_steer_rate, + self._spin_max_throttle, self._spin_speed_k, + ]: + spin.valueChanged.connect(self._on_pd_changed) + for spin in self._pursuit_widgets: + spin.valueChanged.connect( + self._on_pursuit_changed, + ) + + self._refresh_pd_presets() + + # 初期表示の更新 + self._on_steering_method_changed() + + def _on_pd_changed(self) -> None: + """PD パラメータ SpinBox の値が変更されたときに反映する""" + p = self._pd_params + p.kp = self._spin_kp.value() + p.kh = self._spin_kh.value() + p.kd = self._spin_kd.value() + p.max_steer_rate = self._spin_max_steer_rate.value() + p.max_throttle = self._spin_max_throttle.value() + p.speed_k = self._spin_speed_k.value() + + if self._auto_save_enabled: + self.pd_params_changed.emit(p) + + def _on_pursuit_changed(self) -> None: + """Pursuit パラメータの変更を反映する""" + p = self._pursuit_params + p.near_ratio = self._spin_near_ratio.value() + p.far_ratio = self._spin_far_ratio.value() + p.k_near = self._spin_k_near.value() + p.k_far = self._spin_k_far.value() + p.max_steer_rate = ( + self._spin_pursuit_steer_rate.value() + ) + p.max_throttle = ( + self._spin_pursuit_throttle.value() + ) + p.speed_k = self._spin_pursuit_speed_k.value() + + if self._auto_save_enabled: + self.pursuit_params_changed.emit(p) + + def _on_steering_method_changed(self) -> None: + """制御手法の変更を反映する""" + method = self._steering_combo.currentData() + is_pd = method == "pd" + + # PD 固有ウィジェットの表示切替 + for w in self._pd_widgets: + w.setVisible(is_pd) + label = self._pd_param_form.labelForField(w) + if label: + label.setVisible(is_pd) + + # 共通ウィジェット(操舵制限/最大速度/減速係数) + for w in [ + self._spin_max_steer_rate, + self._spin_max_throttle, + self._spin_speed_k, + ]: + w.setVisible(is_pd) + label = self._pd_param_form.labelForField(w) + if label: + label.setVisible(is_pd) + + # Pursuit ウィジェットの表示切替 + for w in self._pursuit_widgets: + w.setVisible(not is_pd) + label = ( + self._pursuit_param_form.labelForField(w) + ) + if label: + label.setVisible(not is_pd) + + if self._auto_save_enabled: + self.steering_method_changed.emit(method) + + # ── PD プリセット管理 ────────────────────────────────── + + def _refresh_pd_presets(self) -> None: + """PD 制御プリセット一覧を更新する""" + self._pd_presets = load_pd_presets() + self._pd_preset_combo.clear() + for p in self._pd_presets: + self._pd_preset_combo.addItem(p.title) + self._pd_preset_memo.setText("") + + def _on_pd_preset_selected(self) -> None: + """PD 制御プリセット選択時にメモを表示する""" + idx = self._pd_preset_combo.currentIndex() + if 0 <= idx < len(self._pd_presets): + self._pd_preset_memo.setText( + self._pd_presets[idx].memo, + ) + else: + self._pd_preset_memo.setText("") + + def _on_load_pd_preset(self) -> None: + """PD 制御プリセットを読み込む""" + idx = self._pd_preset_combo.currentIndex() + if idx < 0 or idx >= len(self._pd_presets): + return + self._auto_save_enabled = False + try: + p = self._pd_presets[idx].params + self._spin_kp.setValue(p.kp) + self._spin_kh.setValue(p.kh) + self._spin_kd.setValue(p.kd) + self._spin_max_steer_rate.setValue( + p.max_steer_rate, + ) + self._spin_max_throttle.setValue(p.max_throttle) + self._spin_speed_k.setValue(p.speed_k) + self._pd_params = p + finally: + self._auto_save_enabled = True + self.pd_params_changed.emit(p) + + def _on_save_pd_preset(self) -> None: + """PD 制御プリセットを保存する""" + title, ok = QInputDialog.getText( + self, "PD プリセット保存", "タイトル:", + ) + if not ok or not title.strip(): + return + memo, ok = QInputDialog.getText( + self, "PD プリセット保存", "メモ:", + ) + if not ok: + return + p = self._pd_params + add_pd_preset(PdPreset( + title=title.strip(), + memo=memo.strip(), + params=PdParams( + kp=p.kp, kh=p.kh, kd=p.kd, + max_steer_rate=p.max_steer_rate, + max_throttle=p.max_throttle, + speed_k=p.speed_k, + ), + )) + self._refresh_pd_presets() + self._pd_preset_combo.setCurrentIndex( + self._pd_preset_combo.count() - 1, + ) + + def _on_delete_pd_preset(self) -> None: + """PD 制御プリセットを削除する""" + idx = self._pd_preset_combo.currentIndex() + if idx < 0 or idx >= len(self._pd_presets): + return + title = self._pd_presets[idx].title + reply = QMessageBox.question( + self, "削除確認", + f"「{title}」を削除しますか?", + QMessageBox.StandardButton.Yes + | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + delete_pd_preset(idx) + self._refresh_pd_presets() + + +def _create_spin( + value: float, min_val: float, step: float, +) -> QDoubleSpinBox: + """パラメータ用の QDoubleSpinBox を作成する + + 直接入力にも対応するため,上限は広めに設定する + """ + spin = QDoubleSpinBox() + spin.setRange(min_val, 99999.0) + spin.setSingleStep(step) + spin.setDecimals(3) + spin.setValue(value) + return spin diff --git a/src/pc/gui/panels/image_param_panel.py b/src/pc/gui/panels/image_param_panel.py new file mode 100644 index 0000000..1e7a1b7 --- /dev/null +++ b/src/pc/gui/panels/image_param_panel.py @@ -0,0 +1,584 @@ +""" +image_param_panel +画像処理パラメータ調整 UI パネル +""" + +from PySide6.QtCore import Signal +from PySide6.QtWidgets import ( + QComboBox, + QDoubleSpinBox, + QFormLayout, + QGroupBox, + QHBoxLayout, + QInputDialog, + QLabel, + QMessageBox, + QPushButton, + QSpinBox, + QVBoxLayout, + QWidget, +) + +from pc.steering.auto_params import ( + load_detect_params, + save_detect_params, +) +from pc.steering.param_store import ( + ImagePreset, + add_image_preset, + delete_image_preset, + load_image_presets, +) +from pc.vision.line_detector import ( + DETECT_METHODS, + ImageParams, + reset_valley_tracker, +) + + +class ImageParamPanel(QGroupBox): + """画像処理パラメータ調整 UI""" + + # 画像処理パラメータが変更されたときに emit する + image_params_changed = Signal(object) + # 検出手法が変更されたときに emit する(新手法キー) + method_changed = Signal(str) + + def __init__(self, image_params: ImageParams) -> None: + super().__init__("画像処理パラメータ") + self._image_params = image_params + self._auto_save_enabled = False + self._image_presets: list[ImagePreset] = [] + self._image_filtered: list[int] = [] + self._image_param_vis: list[ + tuple[QWidget, set[str], str] + ] = [] + + self._setup_ui() + self._auto_save_enabled = True + + def get_image_params(self) -> ImageParams: + """現在の画像処理パラメータを返す""" + return self._image_params + + def _setup_ui(self) -> None: + """UI を構築する""" + layout = QVBoxLayout() + self.setLayout(layout) + + ip = self._image_params + + # 検出手法の選択コンボボックス + self._method_combo = QComboBox() + for key, label in DETECT_METHODS.items(): + self._method_combo.addItem(label, key) + idx = self._method_combo.findData(ip.method) + if idx >= 0: + self._method_combo.setCurrentIndex(idx) + layout.addWidget(self._method_combo) + + # パラメータフォーム + self._image_form = QFormLayout() + layout.addLayout(self._image_form) + + # --- 現行手法パラメータ --- + self._spin_clahe_clip = self._create_spin( + ip.clahe_clip, 0.5, 0.5, + ) + self._add_row( + "CLAHE強度:", self._spin_clahe_clip, + {"current"}, "clahe_clip", + ) + + self._spin_binary_thresh = QSpinBox() + self._spin_binary_thresh.setRange(0, 255) + self._spin_binary_thresh.setValue(ip.binary_thresh) + self._add_row( + "二値化閾値:", self._spin_binary_thresh, + {"current", "blackhat"}, "binary_thresh", + ) + + self._spin_open_size = QSpinBox() + self._spin_open_size.setRange(1, 999) + self._spin_open_size.setSingleStep(2) + self._spin_open_size.setValue(ip.open_size) + self._add_row( + "ノイズ除去:", self._spin_open_size, + {"current"}, "open_size", + ) + + self._spin_close_width = QSpinBox() + self._spin_close_width.setRange(1, 999) + self._spin_close_width.setSingleStep(2) + self._spin_close_width.setValue(ip.close_width) + self._add_row( + "途切れ補間:", self._spin_close_width, + {"current"}, "close_width", + ) + + # --- 案A/C: Black-hat --- + self._spin_blackhat_ksize = QSpinBox() + self._spin_blackhat_ksize.setRange(1, 999) + self._spin_blackhat_ksize.setSingleStep(2) + self._spin_blackhat_ksize.setValue(ip.blackhat_ksize) + self._add_row( + "BHカーネル:", self._spin_blackhat_ksize, + {"blackhat", "robust"}, "blackhat_ksize", + ) + + # --- 案B: 背景除算 --- + self._spin_bg_blur_ksize = QSpinBox() + self._spin_bg_blur_ksize.setRange(1, 999) + self._spin_bg_blur_ksize.setSingleStep(2) + self._spin_bg_blur_ksize.setValue(ip.bg_blur_ksize) + self._add_row( + "背景ブラー:", self._spin_bg_blur_ksize, + {"dual_norm"}, "bg_blur_ksize", + ) + + self._spin_global_thresh = QSpinBox() + self._spin_global_thresh.setRange(0, 255) + self._spin_global_thresh.setValue(ip.global_thresh) + self._spin_global_thresh.setSpecialValueText("無効") + self._add_row( + "固定閾値:", self._spin_global_thresh, + {"dual_norm"}, "global_thresh", + ) + + # --- 案B/C: 適応的閾値 --- + self._spin_adaptive_block = QSpinBox() + self._spin_adaptive_block.setRange(3, 999) + self._spin_adaptive_block.setSingleStep(2) + self._spin_adaptive_block.setValue(ip.adaptive_block) + self._add_row( + "適応ブロック:", self._spin_adaptive_block, + {"dual_norm", "robust"}, "adaptive_block", + ) + + self._spin_adaptive_c = QSpinBox() + self._spin_adaptive_c.setRange(0, 255) + self._spin_adaptive_c.setValue(ip.adaptive_c) + self._add_row( + "適応定数C:", self._spin_adaptive_c, + {"dual_norm", "robust"}, "adaptive_c", + ) + + # --- 案A/B/C: 後処理 --- + self._spin_iso_close = QSpinBox() + self._spin_iso_close.setRange(1, 999) + self._spin_iso_close.setSingleStep(2) + self._spin_iso_close.setValue(ip.iso_close_size) + self._add_row( + "穴埋め:", self._spin_iso_close, + {"blackhat", "dual_norm", "robust"}, + "iso_close_size", + ) + + self._spin_dist_thresh = self._create_spin( + ip.dist_thresh, 0.0, 0.5, + ) + self._add_row( + "距離閾値:", self._spin_dist_thresh, + {"blackhat", "dual_norm", "robust"}, + "dist_thresh", + ) + + self._spin_min_line_width = QSpinBox() + self._spin_min_line_width.setRange(1, 999) + self._spin_min_line_width.setValue(ip.min_line_width) + self._add_row( + "最小線幅:", self._spin_min_line_width, + {"blackhat", "dual_norm", "robust"}, + "min_line_width", + ) + + # --- 案B: 段階クロージング --- + self._spin_stage_close_small = QSpinBox() + self._spin_stage_close_small.setRange(1, 999) + self._spin_stage_close_small.setSingleStep(2) + self._spin_stage_close_small.setValue( + ip.stage_close_small, + ) + self._add_row( + "段階穴埋め(小):", + self._spin_stage_close_small, + {"dual_norm"}, "stage_close_small", + ) + + self._spin_stage_min_area = QSpinBox() + self._spin_stage_min_area.setRange(0, 99999) + self._spin_stage_min_area.setValue(ip.stage_min_area) + self._spin_stage_min_area.setSpecialValueText("無効") + self._add_row( + "孤立除去面積:", self._spin_stage_min_area, + {"dual_norm"}, "stage_min_area", + ) + + self._spin_stage_close_large = QSpinBox() + self._spin_stage_close_large.setRange(0, 999) + self._spin_stage_close_large.setSingleStep(2) + self._spin_stage_close_large.setValue( + ip.stage_close_large, + ) + self._spin_stage_close_large.setSpecialValueText( + "無効", + ) + self._add_row( + "段階穴埋め(大):", + self._spin_stage_close_large, + {"dual_norm"}, "stage_close_large", + ) + + # --- ロバストフィッティング(全手法共通) --- + all_methods = { + "blackhat", "dual_norm", "robust", "valley", + } + + self._spin_median_ksize = QSpinBox() + self._spin_median_ksize.setRange(0, 999) + self._spin_median_ksize.setSingleStep(2) + self._spin_median_ksize.setValue(ip.median_ksize) + self._spin_median_ksize.setSpecialValueText("無効") + self._add_row( + "メディアン:", self._spin_median_ksize, + all_methods, "median_ksize", + ) + + self._spin_neighbor_thresh = self._create_spin( + ip.neighbor_thresh, 0.0, 1.0, + ) + self._add_row( + "近傍除去:", self._spin_neighbor_thresh, + all_methods, "neighbor_thresh", + ) + + self._spin_residual_thresh = self._create_spin( + ip.residual_thresh, 0.0, 1.0, + ) + self._add_row( + "残差除去:", self._spin_residual_thresh, + all_methods, "residual_thresh", + ) + + # --- 案C/D: RANSAC --- + self._spin_ransac_thresh = self._create_spin( + ip.ransac_thresh, 1.0, 1.0, + ) + self._add_row( + "RANSAC閾値:", self._spin_ransac_thresh, + {"robust", "valley"}, "ransac_thresh", + ) + + # --- 幅フィルタ(透視補正) --- + _width_methods = { + "blackhat", "dual_norm", "robust", "valley", + } + + self._spin_width_near = QSpinBox() + self._spin_width_near.setRange(0, 9999) + self._spin_width_near.setValue(ip.width_near) + self._spin_width_near.setSpecialValueText("無効") + self._add_row( + "線幅(近)px:", self._spin_width_near, + _width_methods, "width_near", + ) + + self._spin_width_far = QSpinBox() + self._spin_width_far.setRange(0, 9999) + self._spin_width_far.setValue(ip.width_far) + self._spin_width_far.setSpecialValueText("無効") + self._add_row( + "線幅(遠)px:", self._spin_width_far, + _width_methods, "width_far", + ) + + self._spin_width_tolerance = self._create_spin( + ip.width_tolerance, 1.0, 0.1, + ) + self._add_row( + "幅フィルタ倍率:", self._spin_width_tolerance, + _width_methods, "width_tolerance", + ) + + # --- 案D: 谷検出+追跡 --- + self._spin_valley_gauss = QSpinBox() + self._spin_valley_gauss.setRange(3, 999) + self._spin_valley_gauss.setSingleStep(2) + self._spin_valley_gauss.setValue( + ip.valley_gauss_ksize, + ) + self._add_row( + "谷ガウス:", self._spin_valley_gauss, + {"valley"}, "valley_gauss_ksize", + ) + + self._spin_valley_min_depth = QSpinBox() + self._spin_valley_min_depth.setRange(1, 255) + self._spin_valley_min_depth.setValue( + ip.valley_min_depth, + ) + self._add_row( + "最小谷深度:", self._spin_valley_min_depth, + {"valley"}, "valley_min_depth", + ) + + self._spin_valley_max_dev = QSpinBox() + self._spin_valley_max_dev.setRange(1, 9999) + self._spin_valley_max_dev.setValue( + ip.valley_max_deviation, + ) + self._add_row( + "最大偏差:", self._spin_valley_max_dev, + {"valley"}, "valley_max_deviation", + ) + + self._spin_valley_coast = QSpinBox() + self._spin_valley_coast.setRange(0, 999) + self._spin_valley_coast.setValue( + ip.valley_coast_frames, + ) + self._add_row( + "予測継続:", self._spin_valley_coast, + {"valley"}, "valley_coast_frames", + ) + + self._spin_valley_ema = self._create_spin( + ip.valley_ema_alpha, 0.0, 0.05, + ) + self._add_row( + "EMA係数:", self._spin_valley_ema, + {"valley"}, "valley_ema_alpha", + ) + + # --- プリセット管理 --- + combo, memo = _create_preset_ui( + layout, + self._on_load_preset, + self._on_save_preset, + self._on_delete_preset, + self._on_preset_selected, + ) + self._preset_combo = combo + self._preset_memo = memo + + # コールバック接続 + self._method_combo.currentIndexChanged.connect( + self._on_method_changed, + ) + for widget, _, _ in self._image_param_vis: + widget.valueChanged.connect( + self._on_image_param_changed, + ) + + # 初期表示 + self._on_method_changed() + + def _add_row( + self, + label: str, + widget: QWidget, + methods: set[str], + field: str, + ) -> None: + """画像処理パラメータの行を追加する + + Args: + label: フォームラベル + widget: SpinBox ウィジェット + methods: 表示対象の検出手法集合 + field: ImageParams のフィールド名 + """ + self._image_form.addRow(label, widget) + self._image_param_vis.append((widget, methods, field)) + + def _on_method_changed(self) -> None: + """検出手法の変更を反映する""" + method = self._method_combo.currentData() + + # 谷検出の追跡状態をリセット + reset_valley_tracker() + + if self._auto_save_enabled: + # 旧手法のパラメータを保存 + ip = self._image_params + save_detect_params(ip.method, ip) + # 新手法のパラメータを読み込み + new_ip = load_detect_params(method) + self._image_params = new_ip + self._sync_spinboxes() + self.image_params_changed.emit(new_ip) + self.method_changed.emit(method) + else: + self._image_params.method = method + + # パラメータの表示/非表示を更新 + for widget, methods, _ in self._image_param_vis: + visible = method in methods + widget.setVisible(visible) + label = self._image_form.labelForField(widget) + if label: + label.setVisible(visible) + + # 保存済みプリセットを手法でフィルタ + if hasattr(self, "_preset_combo"): + self._refresh_presets() + + def _sync_spinboxes(self) -> None: + """SpinBox を現在の image_params に同期する""" + self._auto_save_enabled = False + try: + ip = self._image_params + for widget, _, field in self._image_param_vis: + widget.setValue(getattr(ip, field)) + finally: + self._auto_save_enabled = True + + def _on_image_param_changed(self) -> None: + """SpinBox の値が変更されたときに反映する""" + ip = self._image_params + for widget, _, field in self._image_param_vis: + setattr(ip, field, widget.value()) + + if self._auto_save_enabled: + save_detect_params(ip.method, ip) + self.image_params_changed.emit(ip) + + # ── プリセット管理 ───────────────────────────────────── + + def _refresh_presets(self) -> None: + """選択中の手法のプリセットだけ表示する""" + self._image_presets = load_image_presets() + method = self._method_combo.currentData() + self._preset_combo.clear() + self._image_filtered = [] + for i, p in enumerate(self._image_presets): + if p.image_params.method == method: + self._preset_combo.addItem(p.title) + self._image_filtered.append(i) + self._preset_memo.setText("") + + def _on_preset_selected(self) -> None: + """プリセット選択時にメモを表示する""" + idx = self._get_preset_idx() + if idx >= 0: + self._preset_memo.setText( + self._image_presets[idx].memo, + ) + else: + self._preset_memo.setText("") + + def _get_preset_idx(self) -> int: + """コンボの選択を全体インデックスに変換する""" + ci = self._preset_combo.currentIndex() + if ci < 0 or ci >= len(self._image_filtered): + return -1 + return self._image_filtered[ci] + + def _on_load_preset(self) -> None: + """プリセットを読み込む""" + idx = self._get_preset_idx() + if idx < 0: + return + self._auto_save_enabled = False + try: + ip = self._image_presets[idx].image_params + self._image_params = ip + self._sync_spinboxes() + finally: + self._auto_save_enabled = True + save_detect_params(ip.method, ip) + self.image_params_changed.emit(ip) + + def _on_save_preset(self) -> None: + """プリセットを保存する""" + title, ok = QInputDialog.getText( + self, "画像処理プリセット保存", "タイトル:", + ) + if not ok or not title.strip(): + return + memo, ok = QInputDialog.getText( + self, "画像処理プリセット保存", "メモ:", + ) + if not ok: + return + ip = self._image_params + add_image_preset(ImagePreset( + title=title.strip(), + memo=memo.strip(), + image_params=ImageParams(**{ + f.name: getattr(ip, f.name) + for f in ip.__dataclass_fields__.values() + }), + )) + self._refresh_presets() + self._preset_combo.setCurrentIndex( + self._preset_combo.count() - 1, + ) + + def _on_delete_preset(self) -> None: + """プリセットを削除する""" + idx = self._get_preset_idx() + if idx < 0: + return + title = self._image_presets[idx].title + reply = QMessageBox.question( + self, "削除確認", + f"「{title}」を削除しますか?", + QMessageBox.StandardButton.Yes + | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + delete_image_preset(idx) + self._refresh_presets() + + @staticmethod + def _create_spin( + value: float, min_val: float, step: float, + ) -> QDoubleSpinBox: + """パラメータ用の QDoubleSpinBox を作成する + + 直接入力にも対応するため,上限は広めに設定する + """ + spin = QDoubleSpinBox() + spin.setRange(min_val, 99999.0) + spin.setSingleStep(step) + spin.setDecimals(3) + spin.setValue(value) + return spin + + +def _create_preset_ui( + layout: QVBoxLayout, + on_load, + on_save, + on_delete, + on_selected, +) -> tuple[QComboBox, QLabel]: + """プリセット管理 UI(ComboBox + メモ + ボタン)を作成する + + Returns: + (コンボボックス, メモラベル) のタプル + """ + combo = QComboBox() + combo.setPlaceholderText("プリセット") + layout.addWidget(combo) + + memo = QLabel("") + memo.setWordWrap(True) + memo.setStyleSheet("font-size: 11px; color: #888;") + layout.addWidget(memo) + + btn_layout = QHBoxLayout() + for text, callback in [ + ("読込", on_load), + ("保存", on_save), + ("削除", on_delete), + ]: + btn = QPushButton(text) + btn.clicked.connect(callback) + btn_layout.addWidget(btn) + layout.addLayout(btn_layout) + + combo.currentIndexChanged.connect(on_selected) + return combo, memo diff --git a/src/pc/gui/panels/overlay_panel.py b/src/pc/gui/panels/overlay_panel.py new file mode 100644 index 0000000..1a86ffa --- /dev/null +++ b/src/pc/gui/panels/overlay_panel.py @@ -0,0 +1,43 @@ +""" +overlay_panel +デバッグ表示の切替チェックボックスを提供するパネル +""" + +from PySide6.QtWidgets import QCheckBox, QGroupBox, QVBoxLayout + +from pc.vision.overlay import OverlayFlags + + +class OverlayPanel(QGroupBox): + """デバッグ表示の切替チェックボックス群""" + + def __init__(self) -> None: + super().__init__("デバッグ表示") + self._flags = OverlayFlags() + self._setup_ui() + + def _setup_ui(self) -> None: + """UI を構築する""" + layout = QVBoxLayout() + self.setLayout(layout) + + items = [ + ("二値化画像", "binary"), + ("検出領域", "detect_region"), + ("フィッティング曲線", "poly_curve"), + ("行中心点", "row_centers"), + ("Theil-Sen直線", "theil_sen"), + ("中心線", "center_line"), + ] + for label, attr in items: + cb = QCheckBox(label) + cb.toggled.connect( + lambda checked, a=attr: setattr( + self._flags, a, checked, + ), + ) + layout.addWidget(cb) + + def get_flags(self) -> OverlayFlags: + """現在のオーバーレイフラグを返す""" + return self._flags diff --git a/src/pc/steering/auto_params.py b/src/pc/steering/auto_params.py index fb3167d..24be64e 100644 --- a/src/pc/steering/auto_params.py +++ b/src/pc/steering/auto_params.py @@ -12,21 +12,14 @@ └── detect_robust.json 案C の画像処理パラメータ """ -import json from dataclasses import asdict -from pathlib import Path +from common.json_utils import PARAMS_DIR, read_json, write_json from pc.steering.pd_control import PdParams from pc.vision.line_detector import ImageParams -# パラメータ保存ディレクトリ -_PARAMS_DIR: Path = ( - Path(__file__).resolve().parent.parent.parent.parent - / "params" -) - # PD 制御パラメータファイル -_CONTROL_FILE: Path = _PARAMS_DIR / "control.json" +_CONTROL_FILE = PARAMS_DIR / "control.json" # 検出手法ごとのファイル名 _DETECT_FILES: dict[str, str] = { @@ -34,6 +27,7 @@ "blackhat": "detect_blackhat.json", "dual_norm": "detect_dual_norm.json", "robust": "detect_robust.json", + "valley": "detect_valley.json", } @@ -46,10 +40,9 @@ params: PD 制御パラメータ method: 最後に使用した検出手法の識別子 """ - _PARAMS_DIR.mkdir(exist_ok=True) data = asdict(params) data["last_method"] = method - _write_json(_CONTROL_FILE, data) + write_json(_CONTROL_FILE, data) def load_control() -> tuple[PdParams, str]: @@ -61,7 +54,7 @@ if not _CONTROL_FILE.exists(): return PdParams(), "current" - data = _read_json(_CONTROL_FILE) + data = read_json(_CONTROL_FILE) method = data.pop("last_method", "current") known = PdParams.__dataclass_fields__ filtered = { @@ -83,10 +76,9 @@ filename = _DETECT_FILES.get(method) if filename is None: return - _PARAMS_DIR.mkdir(exist_ok=True) data = asdict(params) data["method"] = method - _write_json(_PARAMS_DIR / filename, data) + write_json(PARAMS_DIR / filename, data) def load_detect_params(method: str) -> ImageParams: @@ -102,11 +94,11 @@ if filename is None: return ImageParams(method=method) - path = _PARAMS_DIR / filename + path = PARAMS_DIR / filename if not path.exists(): return ImageParams(method=method) - data = _read_json(path) + data = read_json(path) known = ImageParams.__dataclass_fields__ filtered = { k: v for k, v in data.items() @@ -116,15 +108,3 @@ return ImageParams(**filtered) -def _write_json(path: Path, data: dict) -> None: - """JSON ファイルに書き込む""" - with open(path, "w", encoding="utf-8") as f: - json.dump( - data, f, ensure_ascii=False, indent=2, - ) - - -def _read_json(path: Path) -> dict: - """JSON ファイルを読み込む""" - with open(path, "r", encoding="utf-8") as f: - return json.load(f) diff --git a/src/pc/steering/param_store.py b/src/pc/steering/param_store.py index c51ee3f..5a9cc58 100644 --- a/src/pc/steering/param_store.py +++ b/src/pc/steering/param_store.py @@ -4,21 +4,14 @@ 画像処理パラメータと PD 制御パラメータを独立して管理する """ -import json from dataclasses import asdict, dataclass -from pathlib import Path +from common.json_utils import PARAMS_DIR, read_json, write_json from pc.steering.pd_control import PdParams from pc.vision.line_detector import ImageParams -# プリセット保存ディレクトリ -_PARAMS_DIR: Path = ( - Path(__file__).resolve().parent.parent.parent.parent - / "params" -) - -_PD_FILE: Path = _PARAMS_DIR / "presets_pd.json" -_IMAGE_FILE: Path = _PARAMS_DIR / "presets_image.json" +_PD_FILE = PARAMS_DIR / "presets_pd.json" +_IMAGE_FILE = PARAMS_DIR / "presets_image.json" # ── PD 制御プリセット ───────────────────────── @@ -118,8 +111,7 @@ if not path.exists(): return [] - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) + data = read_json(path) known = params_cls.__dataclass_fields__ presets = [] @@ -144,7 +136,6 @@ def _save_presets(path, presets, params_key): """プリセットファイルに保存する""" - path.parent.mkdir(exist_ok=True) data = [] for preset in presets: data.append({ @@ -155,5 +146,4 @@ ), }) - with open(path, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) + write_json(path, data) diff --git a/src/pc/steering/pd_control.py b/src/pc/steering/pd_control.py index 1f3a7fc..753c8e6 100644 --- a/src/pc/steering/pd_control.py +++ b/src/pc/steering/pd_control.py @@ -10,7 +10,11 @@ import numpy as np from pc.steering.base import SteeringBase, SteeringOutput -from pc.vision.line_detector import ImageParams, detect_line +from pc.vision.line_detector import ( + ImageParams, + detect_line, + reset_valley_tracker, +) @dataclass @@ -121,6 +125,7 @@ self._prev_time = 0.0 self._prev_steer = 0.0 self._last_result = None + reset_valley_tracker() @property def last_detect_result(self): diff --git a/src/pc/steering/pursuit_control.py b/src/pc/steering/pursuit_control.py new file mode 100644 index 0000000..833edd4 --- /dev/null +++ b/src/pc/steering/pursuit_control.py @@ -0,0 +1,144 @@ +""" +pursuit_control +2点パシュートによる操舵量計算モジュール +行中心点に Theil-Sen 直線近似を適用し,外れ値に強い操舵量を算出する +""" + +from dataclasses import dataclass + +import numpy as np + +from common import config +from pc.steering.base import SteeringBase, SteeringOutput +from pc.vision.fitting import theil_sen_fit +from pc.vision.line_detector import ( + ImageParams, + detect_line, + reset_valley_tracker, +) + + +@dataclass +class PursuitParams: + """2点パシュート制御のパラメータ + + Attributes: + near_ratio: 近い目標点の位置(0.0=上端,1.0=下端) + far_ratio: 遠い目標点の位置(0.0=上端,1.0=下端) + k_near: 近い目標点の操舵ゲイン + k_far: 遠い目標点の操舵ゲイン + max_steer_rate: 1フレームあたりの最大操舵変化量 + max_throttle: 直線での最大速度 + speed_k: カーブ減速係数(2点の差に対する係数) + """ + near_ratio: float = 0.8 + far_ratio: float = 0.3 + k_near: float = 0.5 + k_far: float = 0.3 + max_steer_rate: float = 0.1 + max_throttle: float = 0.4 + speed_k: float = 2.0 + + +class PursuitControl(SteeringBase): + """2点パシュートによる操舵量計算クラス + + 行中心点から Theil-Sen 直線近似を行い, + 直線上の近い点と遠い点の偏差から操舵量を計算する + """ + + def __init__( + self, + params: PursuitParams | None = None, + image_params: ImageParams | None = None, + ) -> None: + self.params: PursuitParams = ( + params or PursuitParams() + ) + self.image_params: ImageParams = ( + image_params or ImageParams() + ) + self._prev_steer: float = 0.0 + self._last_result = None + + def compute( + self, frame: np.ndarray, + ) -> SteeringOutput: + """カメラ画像から2点パシュートで操舵量を計算する + + Args: + frame: グレースケールのカメラ画像 + + Returns: + 計算された操舵量 + """ + p = self.params + + # 線検出 + result = detect_line(frame, self.image_params) + self._last_result = result + + if not result.detected or result.row_centers is None: + return SteeringOutput( + throttle=0.0, steer=0.0, + ) + + centers = result.row_centers + + # 有効な点(NaN でない行)を抽出 + valid = ~np.isnan(centers) + ys = np.where(valid)[0].astype(float) + xs = centers[valid] + + if len(ys) < 2: + return SteeringOutput( + throttle=0.0, steer=0.0, + ) + + # Theil-Sen 直線近似 + slope, intercept = theil_sen_fit(ys, xs) + + center_x = config.FRAME_WIDTH / 2.0 + h = len(centers) + + # 直線上の 2 点の x 座標を取得 + near_y = h * p.near_ratio + far_y = h * p.far_ratio + near_x = slope * near_y + intercept + far_x = slope * far_y + intercept + + # 各点の偏差(正: 線が左にある → 右に曲がる) + near_err = (center_x - near_x) / center_x + far_err = (center_x - far_x) / center_x + + # 操舵量 + steer = p.k_near * near_err + p.k_far * far_err + steer = max(-1.0, min(1.0, steer)) + + # レートリミッター + delta = steer - self._prev_steer + max_delta = p.max_steer_rate + delta = max(-max_delta, min(max_delta, delta)) + steer = self._prev_steer + delta + + # 速度制御(2点の x 差でカーブ度合いを判定) + curve = abs(near_x - far_x) / center_x + throttle = p.max_throttle - p.speed_k * curve + throttle = max(0.0, throttle) + + self._prev_steer = steer + + return SteeringOutput( + throttle=throttle, steer=steer, + ) + + def reset(self) -> None: + """内部状態をリセットする""" + self._prev_steer = 0.0 + self._last_result = None + reset_valley_tracker() + + @property + def last_detect_result(self): + """直近の線検出結果を取得する""" + return self._last_result diff --git a/src/pc/vision/detectors/__init__.py b/src/pc/vision/detectors/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/pc/vision/detectors/__init__.py diff --git a/src/pc/vision/detectors/blackhat.py b/src/pc/vision/detectors/blackhat.py new file mode 100644 index 0000000..bc80920 --- /dev/null +++ b/src/pc/vision/detectors/blackhat.py @@ -0,0 +1,66 @@ +""" +blackhat +案A: Black-hat 中心型の線検出 +Black-hat 変換で背景より暗い構造を直接抽出し, +固定閾値 + 距離変換 + 行ごと中心抽出で検出する +""" + +import cv2 +import numpy as np + +from pc.vision.line_detector import ImageParams, LineDetectResult +from pc.vision.line_detector import fit_row_centers +from pc.vision.morphology import ( + apply_dist_mask, + apply_iso_closing, + apply_width_filter, +) + + +def detect_blackhat( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """案A: Black-hat 中心型""" + # Black-hat 変換(暗い構造の抽出) + bh_k = params.blackhat_ksize | 1 + bh_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (bh_k, bh_k), + ) + blackhat = cv2.morphologyEx( + frame, cv2.MORPH_BLACKHAT, bh_kernel, + ) + + # ガウシアンブラー + blur_k = params.blur_size | 1 + blurred = cv2.GaussianBlur( + blackhat, (blur_k, blur_k), 0, + ) + + # 固定閾値(Black-hat 後は線が白) + _, binary = cv2.threshold( + blurred, params.binary_thresh, 255, + cv2.THRESH_BINARY, + ) + + # 等方クロージング + 距離変換マスク + 幅フィルタ + binary = apply_iso_closing( + binary, params.iso_close_size, + ) + binary = apply_dist_mask( + binary, params.dist_thresh, + ) + if params.width_near > 0 and params.width_far > 0: + binary = apply_width_filter( + binary, + params.width_near, + params.width_far, + params.width_tolerance, + ) + + # 行ごと中心抽出 + フィッティング + return fit_row_centers( + binary, params.min_line_width, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + ) diff --git a/src/pc/vision/detectors/current.py b/src/pc/vision/detectors/current.py new file mode 100644 index 0000000..12dacba --- /dev/null +++ b/src/pc/vision/detectors/current.py @@ -0,0 +1,76 @@ +""" +current +現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング +""" + +import cv2 +import numpy as np + +from pc.vision.line_detector import ( + DETECT_Y_END, + DETECT_Y_START, + MIN_FIT_PIXELS, + ImageParams, + LineDetectResult, + build_result, + no_detection, +) + + +def detect_current( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング""" + # CLAHE でコントラスト強調 + clahe = cv2.createCLAHE( + clipLimit=params.clahe_clip, + tileGridSize=( + params.clahe_grid, + params.clahe_grid, + ), + ) + enhanced = clahe.apply(frame) + + # ガウシアンブラー + blur_k = params.blur_size | 1 + blurred = cv2.GaussianBlur( + enhanced, (blur_k, blur_k), 0, + ) + + # 固定閾値で二値化(黒線を白に反転) + _, binary = cv2.threshold( + blurred, params.binary_thresh, 255, + cv2.THRESH_BINARY_INV, + ) + + # オープニング(孤立ノイズ除去) + if params.open_size >= 3: + open_k = params.open_size | 1 + open_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (open_k, open_k), + ) + binary = cv2.morphologyEx( + binary, cv2.MORPH_OPEN, open_kernel, + ) + + # 横方向クロージング(途切れ補間) + if params.close_width >= 3: + close_h = max(params.close_height | 1, 1) + close_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, + (params.close_width, close_h), + ) + binary = cv2.morphologyEx( + binary, cv2.MORPH_CLOSE, close_kernel, + ) + + # 全ピクセルフィッティング + region = binary[DETECT_Y_START:DETECT_Y_END, :] + ys_local, xs = np.where(region > 0) + + if len(xs) < MIN_FIT_PIXELS: + return no_detection(binary) + + ys = ys_local + DETECT_Y_START + coeffs = np.polyfit(ys, xs, 2) + return build_result(coeffs, binary) diff --git a/src/pc/vision/detectors/dual_norm.py b/src/pc/vision/detectors/dual_norm.py new file mode 100644 index 0000000..46a9a6d --- /dev/null +++ b/src/pc/vision/detectors/dual_norm.py @@ -0,0 +1,86 @@ +""" +dual_norm +案B: 二重正規化型の線検出 +背景除算で照明勾配を除去し, +適応的閾値で局所ムラにも対応する二重防壁構成 +""" + +import cv2 +import numpy as np + +from pc.vision.line_detector import ImageParams, LineDetectResult +from pc.vision.line_detector import fit_row_centers +from pc.vision.morphology import ( + apply_dist_mask, + apply_iso_closing, + apply_staged_closing, + apply_width_filter, +) + + +def detect_dual_norm( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """案B: 二重正規化型""" + # 背景除算正規化 + bg_k = params.bg_blur_ksize | 1 + bg = cv2.GaussianBlur( + frame, (bg_k, bg_k), 0, + ) + normalized = ( + frame.astype(np.float32) * 255.0 + / (bg.astype(np.float32) + 1.0) + ) + normalized = np.clip( + normalized, 0, 255, + ).astype(np.uint8) + + # 適応的閾値(ガウシアン,BINARY_INV) + block = max(params.adaptive_block | 1, 3) + binary = cv2.adaptiveThreshold( + normalized, 255, + cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY_INV, + block, params.adaptive_c, + ) + + # 固定閾値との AND(有効時のみ) + if params.global_thresh > 0: + _, global_mask = cv2.threshold( + normalized, params.global_thresh, + 255, cv2.THRESH_BINARY_INV, + ) + binary = cv2.bitwise_and(binary, global_mask) + + # 段階クロージング or 等方クロージング + if params.stage_min_area > 0: + binary = apply_staged_closing( + binary, + params.stage_close_small, + params.stage_min_area, + params.stage_close_large, + ) + else: + binary = apply_iso_closing( + binary, params.iso_close_size, + ) + + # 距離変換マスク + 幅フィルタ + binary = apply_dist_mask( + binary, params.dist_thresh, + ) + if params.width_near > 0 and params.width_far > 0: + binary = apply_width_filter( + binary, + params.width_near, + params.width_far, + params.width_tolerance, + ) + + # 行ごと中心抽出 + フィッティング + return fit_row_centers( + binary, params.min_line_width, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + ) diff --git a/src/pc/vision/detectors/robust.py b/src/pc/vision/detectors/robust.py new file mode 100644 index 0000000..ac55d35 --- /dev/null +++ b/src/pc/vision/detectors/robust.py @@ -0,0 +1,66 @@ +""" +robust +案C: 最高ロバスト型の線検出 +Black-hat + 適応的閾値の二重正規化に加え, +RANSAC で外れ値を除去する最もロバストな構成 +""" + +import cv2 +import numpy as np + +from pc.vision.line_detector import ImageParams, LineDetectResult +from pc.vision.line_detector import fit_row_centers +from pc.vision.morphology import ( + apply_dist_mask, + apply_iso_closing, + apply_width_filter, +) + + +def detect_robust( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """案C: 最高ロバスト型""" + # Black-hat 変換 + bh_k = params.blackhat_ksize | 1 + bh_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (bh_k, bh_k), + ) + blackhat = cv2.morphologyEx( + frame, cv2.MORPH_BLACKHAT, bh_kernel, + ) + + # 適応的閾値(BINARY: Black-hat 後は線が白) + block = max(params.adaptive_block | 1, 3) + binary = cv2.adaptiveThreshold( + blackhat, 255, + cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY, + block, -params.adaptive_c, + ) + + # 等方クロージング + 距離変換マスク + 幅フィルタ + binary = apply_iso_closing( + binary, params.iso_close_size, + ) + binary = apply_dist_mask( + binary, params.dist_thresh, + ) + if params.width_near > 0 and params.width_far > 0: + binary = apply_width_filter( + binary, + params.width_near, + params.width_far, + params.width_tolerance, + ) + + # 行ごと中央値抽出 + RANSAC フィッティング + return fit_row_centers( + binary, params.min_line_width, + use_median=True, + ransac_thresh=params.ransac_thresh, + ransac_iter=params.ransac_iter, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + ) diff --git a/src/pc/vision/detectors/valley.py b/src/pc/vision/detectors/valley.py new file mode 100644 index 0000000..ce14b52 --- /dev/null +++ b/src/pc/vision/detectors/valley.py @@ -0,0 +1,328 @@ +""" +valley +案D: 谷検出+追跡型の線検出 +各行の輝度信号から谷(暗い領域)を直接検出し, +時系列追跡で安定性を確保する.二値化を使用しない +""" + +import cv2 +import numpy as np + +from common import config +from pc.vision.fitting import clean_and_fit +from pc.vision.line_detector import ( + DETECT_Y_END, + DETECT_Y_START, + MIN_FIT_ROWS, + ImageParams, + LineDetectResult, + build_result, + no_detection, +) + + +class ValleyTracker: + """谷検出の時系列追跡を管理するクラス + + 前フレームの多項式係数を保持し,予測・平滑化・ + 検出失敗時のコースティングを提供する + """ + + def __init__(self) -> None: + self._prev_coeffs: np.ndarray | None = None + self._smoothed_coeffs: np.ndarray | None = None + self._frames_lost: int = 0 + + def predict_x(self, y: float) -> float | None: + """前フレームの多項式から x 座標を予測する + + Args: + y: 画像の y 座標 + + Returns: + 予測 x 座標(履歴なしの場合は None) + """ + if self._smoothed_coeffs is None: + return None + return float(np.poly1d(self._smoothed_coeffs)(y)) + + def update( + self, + coeffs: np.ndarray, + alpha: float, + ) -> np.ndarray: + """検出成功時に状態を更新する + + EMA で多項式係数を平滑化し,更新後の係数を返す + + Args: + coeffs: 今フレームのフィッティング係数 + alpha: EMA 係数(1.0 で平滑化なし) + + Returns: + 平滑化後の多項式係数 + """ + self._frames_lost = 0 + if self._smoothed_coeffs is None: + self._smoothed_coeffs = coeffs.copy() + else: + self._smoothed_coeffs = ( + alpha * coeffs + + (1.0 - alpha) * self._smoothed_coeffs + ) + self._prev_coeffs = self._smoothed_coeffs.copy() + return self._smoothed_coeffs + + def coast( + self, max_frames: int, + ) -> LineDetectResult | None: + """検出失敗時に予測結果を返す + + Args: + max_frames: 予測を継続する最大フレーム数 + + Returns: + 予測による結果(継続不可の場合は None) + """ + if self._smoothed_coeffs is None: + return None + self._frames_lost += 1 + if self._frames_lost > max_frames: + return None + # 予測でデバッグ用二値画像は空にする + h = config.FRAME_HEIGHT + w = config.FRAME_WIDTH + blank = np.zeros((h, w), dtype=np.uint8) + return build_result(self._smoothed_coeffs, blank) + + def reset(self) -> None: + """追跡状態をリセットする""" + self._prev_coeffs = None + self._smoothed_coeffs = None + self._frames_lost = 0 + + +_valley_tracker = ValleyTracker() + + +def reset_valley_tracker() -> None: + """谷検出の追跡状態をリセットする""" + _valley_tracker.reset() + + +def _find_row_valley( + row: np.ndarray, + min_depth: int, + expected_width: float, + width_tolerance: float, + predicted_x: float | None, + max_deviation: int, +) -> tuple[float, float] | None: + """1行の輝度信号から最適な谷を検出する + + Args: + row: スムージング済みの1行輝度信号 + min_depth: 最小谷深度 + expected_width: 期待線幅(px,0 で幅フィルタ無効) + width_tolerance: 幅フィルタの上限倍率 + predicted_x: 追跡による予測 x 座標(None で無効) + max_deviation: 予測からの最大許容偏差 + + Returns: + (谷の中心x, 谷の深度) または None + """ + n = len(row) + if n < 5: + return None + + signal = row.astype(np.float32) + + # 極小値を検出(前後より小さい点) + left = signal[:-2] + center = signal[1:-1] + right = signal[2:] + minima_mask = (center <= left) & (center <= right) + minima_indices = np.where(minima_mask)[0] + 1 + + if len(minima_indices) == 0: + return None + + best: tuple[float, float] | None = None + best_score = -1.0 + + for idx in minima_indices: + val = signal[idx] + + # 左の肩を探す + left_shoulder = idx + for i in range(idx - 1, -1, -1): + if signal[i] < signal[i + 1]: + break + left_shoulder = i + # 右の肩を探す + right_shoulder = idx + for i in range(idx + 1, n): + if signal[i] < signal[i - 1]: + break + right_shoulder = i + + # 谷の深度(肩の平均 - 谷底) + shoulder_avg = ( + signal[left_shoulder] + signal[right_shoulder] + ) / 2.0 + depth = shoulder_avg - val + if depth < min_depth: + continue + + # 谷の幅 + width = right_shoulder - left_shoulder + center_x = (left_shoulder + right_shoulder) / 2.0 + + # 幅フィルタ + if expected_width > 0: + max_w = expected_width * width_tolerance + min_w = expected_width / width_tolerance + if width > max_w or width < min_w: + continue + + # 予測との偏差チェック + if predicted_x is not None: + if abs(center_x - predicted_x) > max_deviation: + continue + + # スコア: 深度優先,予測がある場合は近さも考慮 + score = float(depth) + if predicted_x is not None: + dist = abs(center_x - predicted_x) + score += max(0.0, max_deviation - dist) + + if score > best_score: + best_score = score + best = (center_x, float(depth)) + + return best + + +def _build_valley_binary( + shape: tuple[int, int], + centers_y: list[int], + centers_x: list[float], +) -> np.ndarray: + """谷検出結果からデバッグ用二値画像を生成する + + Args: + shape: 出力画像の (高さ, 幅) + centers_y: 検出行の y 座標リスト + centers_x: 検出行の中心 x 座標リスト + + Returns: + デバッグ用二値画像 + """ + binary = np.zeros(shape, dtype=np.uint8) + half_w = 3 + w = shape[1] + for y, cx in zip(centers_y, centers_x): + x0 = max(0, int(cx) - half_w) + x1 = min(w, int(cx) + half_w + 1) + binary[y, x0:x1] = 255 + return binary + + +def detect_valley( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """案D: 谷検出+追跡型""" + h, w = frame.shape[:2] + + # 行ごとにガウシアン平滑化するため画像全体をブラー + gauss_k = params.valley_gauss_ksize | 1 + blurred = cv2.GaussianBlur( + frame, (gauss_k, 1), 0, + ) + + # 透視補正の期待幅を計算するための準備 + use_width = ( + params.width_near > 0 and params.width_far > 0 + ) + detect_h = DETECT_Y_END - DETECT_Y_START + denom = max(detect_h - 1, 1) + + centers_y: list[int] = [] + centers_x: list[float] = [] + depths: list[float] = [] + + for y in range(DETECT_Y_START, DETECT_Y_END): + row = blurred[y] + + # 期待幅の計算 + if use_width: + t = (DETECT_Y_END - 1 - y) / denom + expected_w = float(params.width_far) + ( + float(params.width_near) + - float(params.width_far) + ) * t + else: + expected_w = 0.0 + + # 予測 x 座標 + predicted_x = _valley_tracker.predict_x( + float(y), + ) + + result = _find_row_valley( + row, + params.valley_min_depth, + expected_w, + params.width_tolerance, + predicted_x, + params.valley_max_deviation, + ) + if result is not None: + centers_y.append(y) + centers_x.append(result[0]) + depths.append(result[1]) + + # デバッグ用二値画像 + debug_binary = _build_valley_binary( + (h, w), centers_y, centers_x, + ) + + if len(centers_y) < MIN_FIT_ROWS: + # 検出失敗 → コースティングを試みる + coasted = _valley_tracker.coast( + params.valley_coast_frames, + ) + if coasted is not None: + coasted.binary_image = debug_binary + return coasted + return no_detection(debug_binary) + + cy = np.array(centers_y, dtype=np.float64) + cx = np.array(centers_x, dtype=np.float64) + w_arr = np.array(depths, dtype=np.float64) + + # ロバストフィッティング(深度を重みに使用) + coeffs = clean_and_fit( + cy, cx, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + weights=w_arr, + ransac_thresh=params.ransac_thresh, + ransac_iter=params.ransac_iter, + ) + if coeffs is None: + coasted = _valley_tracker.coast( + params.valley_coast_frames, + ) + if coasted is not None: + coasted.binary_image = debug_binary + return coasted + return no_detection(debug_binary) + + # EMA で平滑化 + smoothed = _valley_tracker.update( + coeffs, params.valley_ema_alpha, + ) + + return build_result(smoothed, debug_binary) diff --git a/src/pc/vision/fitting.py b/src/pc/vision/fitting.py new file mode 100644 index 0000000..34813d5 --- /dev/null +++ b/src/pc/vision/fitting.py @@ -0,0 +1,209 @@ +""" +fitting +直線・曲線近似の共通ユーティリティモジュール +Theil-Sen 推定,RANSAC,外れ値除去付きフィッティングを提供する +""" + +import numpy as np + +# フィッティングに必要な最小行数 +MIN_FIT_ROWS: int = 10 + +# 近傍外れ値除去の設定 +NEIGHBOR_HALF_WINDOW: int = 3 +NEIGHBOR_FILTER_PASSES: int = 3 + +# 残差ベース反復除去の最大回数 +RESIDUAL_REMOVAL_ITERATIONS: int = 5 + + +def theil_sen_fit( + y: np.ndarray, + x: np.ndarray, +) -> tuple[float, float]: + """Theil-Sen 推定で直線 x = slope * y + intercept を求める + + 全ペアの傾きの中央値を使い,外れ値に強い直線近似を行う + + Args: + y: y 座標の配列(行番号) + x: x 座標の配列(各行の中心) + + Returns: + (slope, intercept) のタプル + """ + n = len(y) + slopes = [] + for i in range(n): + for j in range(i + 1, n): + dy = y[j] - y[i] + if dy != 0: + slopes.append((x[j] - x[i]) / dy) + + if len(slopes) == 0: + return 0.0, float(np.median(x)) + + slope = float(np.median(slopes)) + intercept = float(np.median(x - slope * y)) + return slope, intercept + + +def ransac_polyfit( + ys: np.ndarray, xs: np.ndarray, + degree: int, n_iter: int, thresh: float, +) -> np.ndarray | None: + """RANSAC で外れ値を除去して多項式フィッティング + + Args: + ys: y 座標配列 + xs: x 座標配列 + degree: 多項式の次数 + n_iter: 反復回数 + thresh: 外れ値判定閾値(ピクセル) + + Returns: + 多項式係数(フィッティング失敗時は None) + """ + n = len(ys) + sample_size = degree + 1 + if n < sample_size: + return None + + best_coeffs: np.ndarray | None = None + best_inliers = 0 + rng = np.random.default_rng() + + for _ in range(n_iter): + idx = rng.choice(n, sample_size, replace=False) + coeffs = np.polyfit(ys[idx], xs[idx], degree) + poly = np.poly1d(coeffs) + residuals = np.abs(xs - poly(ys)) + n_inliers = int(np.sum(residuals < thresh)) + if n_inliers > best_inliers: + best_inliers = n_inliers + best_coeffs = coeffs + + # インライアで再フィッティング + if best_coeffs is not None: + poly = np.poly1d(best_coeffs) + inlier_mask = np.abs(xs - poly(ys)) < thresh + if np.sum(inlier_mask) >= sample_size: + best_coeffs = np.polyfit( + ys[inlier_mask], + xs[inlier_mask], + degree, + ) + + return best_coeffs + + +def clean_and_fit( + cy: np.ndarray, + cx: np.ndarray, + median_ksize: int, + neighbor_thresh: float, + residual_thresh: float = 0.0, + weights: np.ndarray | None = None, + ransac_thresh: float = 0.0, + ransac_iter: int = 0, +) -> np.ndarray | None: + """外れ値除去+重み付きフィッティングを行う + + 全検出手法で共通に使えるロバストなフィッティング + (1) 移動メディアンフィルタでスパイクを平滑化 + (2) 近傍中央値からの偏差で外れ値を除去(複数パス) + (3) 重み付き最小二乗(または RANSAC)でフィッティング + (4) 残差ベースの反復除去で外れ値を最終除去 + + Args: + cy: 中心点の y 座標配列 + cx: 中心点の x 座標配列 + median_ksize: 移動メディアンのカーネルサイズ(0 で無効) + neighbor_thresh: 近傍外れ値除去の閾値 px(0 で無効) + residual_thresh: 残差除去の閾値 px(0 で無効) + weights: 各点の信頼度(None で均等) + ransac_thresh: RANSAC 閾値(0 以下で無効) + ransac_iter: RANSAC 反復回数 + + Returns: + 多項式係数(フィッティング失敗時は None) + """ + if len(cy) < MIN_FIT_ROWS: + return None + + cx_clean = cx.copy() + mask = np.ones(len(cy), dtype=bool) + + # (1) 移動メディアンフィルタ + if median_ksize >= 3: + k = median_ksize | 1 + half = k // 2 + for i in range(len(cx_clean)): + lo = max(0, i - half) + hi = min(len(cx_clean), i + half + 1) + cx_clean[i] = float(np.median(cx[lo:hi])) + + # (2) 近傍外れ値除去(複数パス) + if neighbor_thresh > 0: + half_n = NEIGHBOR_HALF_WINDOW + for _ in range(NEIGHBOR_FILTER_PASSES): + new_mask = np.ones(len(cx_clean), dtype=bool) + for i in range(len(cx_clean)): + if not mask[i]: + continue + lo = max(0, i - half_n) + hi = min(len(cx_clean), i + half_n + 1) + neighbors = cx_clean[lo:hi][mask[lo:hi]] + if len(neighbors) == 0: + new_mask[i] = False + continue + local_med = float(np.median(neighbors)) + if abs(cx_clean[i] - local_med) > neighbor_thresh: + new_mask[i] = False + if np.array_equal(mask, mask & new_mask): + break + mask = mask & new_mask + + cy = cy[mask] + cx_clean = cx_clean[mask] + if weights is not None: + weights = weights[mask] + + if len(cy) < MIN_FIT_ROWS: + return None + + # (3) フィッティング + if ransac_thresh > 0 and ransac_iter > 0: + coeffs = ransac_polyfit( + cy, cx_clean, 2, ransac_iter, ransac_thresh, + ) + elif weights is not None: + coeffs = np.polyfit(cy, cx_clean, 2, w=weights) + else: + coeffs = np.polyfit(cy, cx_clean, 2) + + if coeffs is None: + return None + + # (4) 残差ベースの反復除去 + if residual_thresh > 0: + for _ in range(RESIDUAL_REMOVAL_ITERATIONS): + poly = np.poly1d(coeffs) + residuals = np.abs(cx_clean - poly(cy)) + inlier = residuals < residual_thresh + if np.all(inlier): + break + if np.sum(inlier) < MIN_FIT_ROWS: + break + cy = cy[inlier] + cx_clean = cx_clean[inlier] + if weights is not None: + weights = weights[inlier] + if weights is not None: + coeffs = np.polyfit( + cy, cx_clean, 2, w=weights, + ) + else: + coeffs = np.polyfit(cy, cx_clean, 2) + + return coeffs diff --git a/src/pc/vision/line_detector.py b/src/pc/vision/line_detector.py index dde4275..09115bf 100644 --- a/src/pc/vision/line_detector.py +++ b/src/pc/vision/line_detector.py @@ -2,6 +2,10 @@ line_detector カメラ画像から黒線の位置を検出するモジュール 複数の検出手法を切り替えて使用できる + +公開 API: + ImageParams, LineDetectResult, detect_line, + reset_valley_tracker, DETECT_METHODS """ from dataclasses import dataclass @@ -10,6 +14,7 @@ import numpy as np from common import config +from pc.vision.fitting import clean_and_fit # 検出領域の y 範囲(画像全体) DETECT_Y_START: int = 0 @@ -25,6 +30,7 @@ "blackhat": "案A(Black-hat 中心)", "dual_norm": "案B(二重正規化)", "robust": "案C(最高ロバスト)", + "valley": "案D(谷検出+追跡)", } @@ -43,16 +49,28 @@ close_height: クロージングの高さ blackhat_ksize: Black-hat のカーネルサイズ bg_blur_ksize: 背景除算のブラーカーネルサイズ + global_thresh: 固定閾値(0 で無効,適応的閾値との AND) adaptive_block: 適応的閾値のブロックサイズ adaptive_c: 適応的閾値の定数 C iso_close_size: 等方クロージングのカーネルサイズ dist_thresh: 距離変換の閾値 min_line_width: 行ごと中心抽出の最小線幅 + stage_close_small: 段階クロージング第1段のサイズ + stage_min_area: 孤立除去の最小面積(0 で無効) + stage_close_large: 段階クロージング第2段のサイズ(0 で無効) ransac_thresh: RANSAC の外れ値判定閾値 ransac_iter: RANSAC の反復回数 width_near: 画像下端での期待線幅(px,0 で無効) width_far: 画像上端での期待線幅(px,0 で無効) width_tolerance: 幅フィルタの上限倍率 + median_ksize: 中心点列の移動メディアンフィルタサイズ(0 で無効) + neighbor_thresh: 近傍外れ値除去の閾値(px,0 で無効) + residual_thresh: 残差反復除去の閾値(px,0 で無効) + valley_gauss_ksize: 谷検出の行ごとガウシアンカーネルサイズ + valley_min_depth: 谷として認識する最小深度 + valley_max_deviation: 追跡予測からの最大許容偏差(px) + valley_coast_frames: 検出失敗時の予測継続フレーム数 + valley_ema_alpha: 多項式係数の指数移動平均係数 """ # 検出手法 @@ -72,6 +90,7 @@ # 案B: 背景除算 bg_blur_ksize: int = 101 + global_thresh: int = 0 # 固定閾値(0 で無効) # 案B/C: 適応的閾値 adaptive_block: int = 51 @@ -82,15 +101,32 @@ dist_thresh: float = 3.0 min_line_width: int = 3 + # 案B: 段階クロージング + stage_close_small: int = 5 # 第1段: 小クロージングサイズ + stage_min_area: int = 0 # 孤立除去の最小面積(0 で無効) + stage_close_large: int = 0 # 第2段: 大クロージングサイズ(0 で無効) + # 案C: RANSAC ransac_thresh: float = 5.0 ransac_iter: int = 50 + # ロバストフィッティング(全手法共通) + median_ksize: int = 7 + neighbor_thresh: float = 10.0 + residual_thresh: float = 8.0 + # 透視補正付き幅フィルタ(0 で無効) width_near: int = 0 width_far: int = 0 width_tolerance: float = 1.8 + # 案D: 谷検出+追跡 + valley_gauss_ksize: int = 15 + valley_min_depth: int = 15 + valley_max_deviation: int = 40 + valley_coast_frames: int = 3 + valley_ema_alpha: float = 0.7 + @dataclass class LineDetectResult: @@ -102,6 +138,8 @@ heading: 線の傾き(dx/dy,画像下端での値) curvature: 線の曲率(d²x/dy²) poly_coeffs: 多項式の係数(描画用,未検出時は None) + row_centers: 各行の線中心 x 座標(index=行番号, + NaN=その行に線なし,未検出時は None) binary_image: 二値化後の画像(デバッグ用) """ @@ -110,9 +148,13 @@ heading: float curvature: float poly_coeffs: np.ndarray | None + row_centers: np.ndarray | None binary_image: np.ndarray | None +# ── 公開 API ────────────────────────────────────── + + def detect_line( frame: np.ndarray, params: ImageParams | None = None, @@ -133,345 +175,149 @@ method = params.method if method == "blackhat": - return _detect_blackhat(frame, params) + from pc.vision.detectors.blackhat import ( + detect_blackhat, + ) + return detect_blackhat(frame, params) if method == "dual_norm": - return _detect_dual_norm(frame, params) + from pc.vision.detectors.dual_norm import ( + detect_dual_norm, + ) + return detect_dual_norm(frame, params) if method == "robust": - return _detect_robust(frame, params) - return _detect_current(frame, params) - - -# ── 検出手法の実装 ───────────────────────────── - - -def _detect_current( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング""" - # CLAHE でコントラスト強調 - clahe = cv2.createCLAHE( - clipLimit=params.clahe_clip, - tileGridSize=( - params.clahe_grid, - params.clahe_grid, - ), - ) - enhanced = clahe.apply(frame) - - # ガウシアンブラー - blur_k = params.blur_size | 1 - blurred = cv2.GaussianBlur( - enhanced, (blur_k, blur_k), 0, - ) - - # 固定閾値で二値化(黒線を白に反転) - _, binary = cv2.threshold( - blurred, params.binary_thresh, 255, - cv2.THRESH_BINARY_INV, - ) - - # オープニング(孤立ノイズ除去) - if params.open_size >= 3: - open_k = params.open_size | 1 - open_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (open_k, open_k), + from pc.vision.detectors.robust import ( + detect_robust, ) - binary = cv2.morphologyEx( - binary, cv2.MORPH_OPEN, open_kernel, + return detect_robust(frame, params) + if method == "valley": + from pc.vision.detectors.valley import ( + detect_valley, ) + return detect_valley(frame, params) - # 横方向クロージング(途切れ補間) - if params.close_width >= 3: - close_h = max(params.close_height | 1, 1) - close_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, - (params.close_width, close_h), - ) - binary = cv2.morphologyEx( - binary, cv2.MORPH_CLOSE, close_kernel, - ) - - # 全ピクセルフィッティング(従来方式) - return _fit_all_pixels(binary) - - -def _detect_blackhat( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """案A: Black-hat 中心型 - - Black-hat 変換で背景より暗い構造を直接抽出し, - 固定閾値 + 距離変換 + 行ごと中心抽出で検出する - """ - # Black-hat 変換(暗い構造の抽出) - bh_k = params.blackhat_ksize | 1 - bh_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (bh_k, bh_k), + from pc.vision.detectors.current import ( + detect_current, ) - blackhat = cv2.morphologyEx( - frame, cv2.MORPH_BLACKHAT, bh_kernel, + return detect_current(frame, params) + + +def reset_valley_tracker() -> None: + """谷検出の追跡状態をリセットする""" + from pc.vision.detectors.valley import ( + reset_valley_tracker as _reset, ) - - # ガウシアンブラー - blur_k = params.blur_size | 1 - blurred = cv2.GaussianBlur( - blackhat, (blur_k, blur_k), 0, - ) - - # 固定閾値(Black-hat 後は線が白) - _, binary = cv2.threshold( - blurred, params.binary_thresh, 255, - cv2.THRESH_BINARY, - ) - - # 等方クロージング + 距離変換マスク + 幅フィルタ - binary = _apply_iso_closing( - binary, params.iso_close_size, - ) - binary = _apply_dist_mask( - binary, params.dist_thresh, - ) - if params.width_near > 0 and params.width_far > 0: - binary = _apply_width_filter( - binary, - params.width_near, - params.width_far, - params.width_tolerance, - ) - - # 行ごと中心抽出 + フィッティング - return _fit_row_centers( - binary, params.min_line_width, - ) + _reset() -def _detect_dual_norm( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """案B: 二重正規化型 - - 背景除算で照明勾配を除去し, - 適応的閾値で局所ムラにも対応する二重防壁構成 - """ - # 背景除算正規化 - bg_k = params.bg_blur_ksize | 1 - bg = cv2.GaussianBlur( - frame, (bg_k, bg_k), 0, - ) - normalized = ( - frame.astype(np.float32) * 255.0 - / (bg.astype(np.float32) + 1.0) - ) - normalized = np.clip( - normalized, 0, 255, - ).astype(np.uint8) - - # 適応的閾値(ガウシアン,BINARY_INV) - block = max(params.adaptive_block | 1, 3) - binary = cv2.adaptiveThreshold( - normalized, 255, - cv2.ADAPTIVE_THRESH_GAUSSIAN_C, - cv2.THRESH_BINARY_INV, - block, params.adaptive_c, - ) - - # 等方クロージング + 距離変換マスク + 幅フィルタ - binary = _apply_iso_closing( - binary, params.iso_close_size, - ) - binary = _apply_dist_mask( - binary, params.dist_thresh, - ) - if params.width_near > 0 and params.width_far > 0: - binary = _apply_width_filter( - binary, - params.width_near, - params.width_far, - params.width_tolerance, - ) - - # 行ごと中心抽出 + フィッティング - return _fit_row_centers( - binary, params.min_line_width, - ) +# ── 共通結果構築(各検出器から使用) ────────────── -def _detect_robust( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """案C: 最高ロバスト型 - - Black-hat + 適応的閾値の二重正規化に加え, - RANSAC で外れ値を除去する最もロバストな構成 - """ - # Black-hat 変換 - bh_k = params.blackhat_ksize | 1 - bh_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (bh_k, bh_k), - ) - blackhat = cv2.morphologyEx( - frame, cv2.MORPH_BLACKHAT, bh_kernel, - ) - - # 適応的閾値(BINARY: Black-hat 後は線が白) - block = max(params.adaptive_block | 1, 3) - binary = cv2.adaptiveThreshold( - blackhat, 255, - cv2.ADAPTIVE_THRESH_GAUSSIAN_C, - cv2.THRESH_BINARY, - block, -params.adaptive_c, - ) - - # 等方クロージング + 距離変換マスク + 幅フィルタ - binary = _apply_iso_closing( - binary, params.iso_close_size, - ) - binary = _apply_dist_mask( - binary, params.dist_thresh, - ) - if params.width_near > 0 and params.width_far > 0: - binary = _apply_width_filter( - binary, - params.width_near, - params.width_far, - params.width_tolerance, - ) - - # 行ごと中央値抽出 + RANSAC フィッティング - return _fit_row_centers( - binary, params.min_line_width, - use_median=True, - ransac_thresh=params.ransac_thresh, - ransac_iter=params.ransac_iter, - ) - - -# ── 共通処理 ─────────────────────────────────── - - -def _apply_iso_closing( - binary: np.ndarray, size: int, -) -> np.ndarray: - """等方クロージングで穴を埋める - - Args: - binary: 二値画像 - size: カーネルサイズ - - Returns: - クロージング後の二値画像 - """ - if size < 3: - return binary - k = size | 1 - kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (k, k), - ) - return cv2.morphologyEx( - binary, cv2.MORPH_CLOSE, kernel, - ) - - -def _apply_width_filter( - binary: np.ndarray, - width_near: int, - width_far: int, - tolerance: float, -) -> np.ndarray: - """透視補正付き幅フィルタで広がりすぎた行を除外する - - 各行の期待線幅を線形補間で算出し, - 実際の幅が上限(期待幅 × tolerance)を超える行をマスクする - - Args: - binary: 二値画像 - width_near: 画像下端での期待線幅(px) - width_far: 画像上端での期待線幅(px) - tolerance: 上限倍率 - - Returns: - 幅フィルタ適用後の二値画像 - """ - result = binary.copy() - h = binary.shape[0] - denom = max(h - 1, 1) - - for y_local in range(h): - xs = np.where(binary[y_local] > 0)[0] - if len(xs) == 0: - continue - # 画像下端(近距離)ほど t=1,上端(遠距離)ほど t=0 - t = (h - 1 - y_local) / denom - expected = float(width_far) + ( - float(width_near) - float(width_far) - ) * t - max_w = expected * tolerance - actual_w = int(xs[-1]) - int(xs[0]) + 1 - if actual_w > max_w: - result[y_local] = 0 - - return result - - -def _apply_dist_mask( - binary: np.ndarray, thresh: float, -) -> np.ndarray: - """距離変換で中心部のみを残す - - Args: - binary: 二値画像 - thresh: 距離の閾値(ピクセル) - - Returns: - 中心部のみの二値画像 - """ - if thresh <= 0: - return binary - dist = cv2.distanceTransform( - binary, cv2.DIST_L2, 5, - ) - _, mask = cv2.threshold( - dist, thresh, 255, cv2.THRESH_BINARY, - ) - return mask.astype(np.uint8) - - -def _fit_all_pixels( +def no_detection( binary: np.ndarray, ) -> LineDetectResult: - """全白ピクセルに多項式をフィッティングする + """未検出の結果を返す""" + return LineDetectResult( + detected=False, + position_error=0.0, + heading=0.0, + curvature=0.0, + poly_coeffs=None, + row_centers=None, + binary_image=binary, + ) - 従来方式.全ピクセルを等しく扱うため, - 陰で幅が広がった行がフィッティングを支配する弱点がある + +def _extract_row_centers( + binary: np.ndarray, +) -> np.ndarray | None: + """二値画像の最大連結領域から各行の線中心を求める Args: binary: 二値画像 Returns: - 線検出の結果 + 各行の中心 x 座標(NaN=その行に線なし), + 最大領域が見つからない場合は None """ - region = binary[DETECT_Y_START:DETECT_Y_END, :] - ys_local, xs = np.where(region > 0) + h, w = binary.shape[:2] + num_labels, labels, stats, _ = ( + cv2.connectedComponentsWithStats(binary) + ) - if len(xs) < MIN_FIT_PIXELS: - return _no_detection(binary) + if num_labels <= 1: + return None - ys = ys_local + DETECT_Y_START - coeffs = np.polyfit(ys, xs, 2) - return _build_result(coeffs, binary) + # 背景(ラベル 0)を除いた最大領域を取得 + areas = stats[1:, cv2.CC_STAT_AREA] + largest_label = int(np.argmax(areas)) + 1 + + # 最大領域のマスク + mask = (labels == largest_label).astype(np.uint8) + + # 各行の左右端から中心を計算 + centers = np.full(h, np.nan) + for y in range(h): + row = mask[y] + cols = np.where(row > 0)[0] + if len(cols) > 0: + centers[y] = (cols[0] + cols[-1]) / 2.0 + + return centers -def _fit_row_centers( +def build_result( + coeffs: np.ndarray, + binary: np.ndarray, + row_centers: np.ndarray | None = None, +) -> LineDetectResult: + """多項式係数から LineDetectResult を構築する + + row_centers が None の場合は binary から自動抽出する + """ + poly = np.poly1d(coeffs) + center_x = config.FRAME_WIDTH / 2.0 + + # 画像下端での位置偏差 + x_bottom = poly(DETECT_Y_END) + position_error = (center_x - x_bottom) / center_x + + # 傾き: dx/dy(画像下端での値) + poly_deriv = poly.deriv() + heading = float(poly_deriv(DETECT_Y_END)) + + # 曲率: d²x/dy² + poly_deriv2 = poly_deriv.deriv() + curvature = float(poly_deriv2(DETECT_Y_END)) + + # row_centers が未提供なら binary から抽出 + if row_centers is None: + row_centers = _extract_row_centers(binary) + + return LineDetectResult( + detected=True, + position_error=position_error, + heading=heading, + curvature=curvature, + poly_coeffs=coeffs, + row_centers=row_centers, + binary_image=binary, + ) + + +def fit_row_centers( binary: np.ndarray, min_width: int, use_median: bool = False, ransac_thresh: float = 0.0, ransac_iter: int = 0, + median_ksize: int = 0, + neighbor_thresh: float = 0.0, + residual_thresh: float = 0.0, ) -> LineDetectResult: """行ごとの中心点に多項式をフィッティングする 各行の白ピクセルの中心(平均または中央値)を1点抽出し, - 中心点列に対してフィッティングする. + ロバスト前処理の後にフィッティングする. 幅の変動に強く,各行が等しく寄与する Args: @@ -480,6 +326,9 @@ use_median: True の場合は中央値を使用 ransac_thresh: RANSAC 閾値(0 以下で無効) ransac_iter: RANSAC 反復回数 + median_ksize: 移動メディアンのカーネルサイズ + neighbor_thresh: 近傍外れ値除去の閾値 px + residual_thresh: 残差反復除去の閾値 px Returns: 線検出の結果 @@ -500,111 +349,20 @@ centers_x.append(float(np.mean(xs))) if len(centers_y) < MIN_FIT_ROWS: - return _no_detection(binary) + return no_detection(binary) cy = np.array(centers_y) cx = np.array(centers_x) - if ransac_thresh > 0 and ransac_iter > 0: - coeffs = _ransac_polyfit( - cy, cx, 2, ransac_iter, ransac_thresh, - ) - if coeffs is None: - return _no_detection(binary) - else: - coeffs = np.polyfit(cy, cx, 2) - - return _build_result(coeffs, binary) - - -def _ransac_polyfit( - ys: np.ndarray, xs: np.ndarray, - degree: int, n_iter: int, thresh: float, -) -> np.ndarray | None: - """RANSAC で外れ値を除去して多項式フィッティング - - Args: - ys: y 座標配列 - xs: x 座標配列 - degree: 多項式の次数 - n_iter: 反復回数 - thresh: 外れ値判定閾値(ピクセル) - - Returns: - 多項式係数(フィッティング失敗時は None) - """ - n = len(ys) - sample_size = degree + 1 - if n < sample_size: - return None - - best_coeffs: np.ndarray | None = None - best_inliers = 0 - rng = np.random.default_rng() - - for _ in range(n_iter): - idx = rng.choice(n, sample_size, replace=False) - coeffs = np.polyfit(ys[idx], xs[idx], degree) - poly = np.poly1d(coeffs) - residuals = np.abs(xs - poly(ys)) - n_inliers = int(np.sum(residuals < thresh)) - if n_inliers > best_inliers: - best_inliers = n_inliers - best_coeffs = coeffs - - # インライアで再フィッティング - if best_coeffs is not None: - poly = np.poly1d(best_coeffs) - inlier_mask = np.abs(xs - poly(ys)) < thresh - if np.sum(inlier_mask) >= sample_size: - best_coeffs = np.polyfit( - ys[inlier_mask], - xs[inlier_mask], - degree, - ) - - return best_coeffs - - -def _no_detection( - binary: np.ndarray, -) -> LineDetectResult: - """未検出の結果を返す""" - return LineDetectResult( - detected=False, - position_error=0.0, - heading=0.0, - curvature=0.0, - poly_coeffs=None, - binary_image=binary, + coeffs = clean_and_fit( + cy, cx, + median_ksize=median_ksize, + neighbor_thresh=neighbor_thresh, + residual_thresh=residual_thresh, + ransac_thresh=ransac_thresh, + ransac_iter=ransac_iter, ) + if coeffs is None: + return no_detection(binary) - -def _build_result( - coeffs: np.ndarray, - binary: np.ndarray, -) -> LineDetectResult: - """多項式係数から LineDetectResult を構築する""" - poly = np.poly1d(coeffs) - center_x = config.FRAME_WIDTH / 2.0 - - # 画像下端での位置偏差 - x_bottom = poly(DETECT_Y_END) - position_error = (center_x - x_bottom) / center_x - - # 傾き: dx/dy(画像下端での値) - poly_deriv = poly.deriv() - heading = float(poly_deriv(DETECT_Y_END)) - - # 曲率: d²x/dy² - poly_deriv2 = poly_deriv.deriv() - curvature = float(poly_deriv2(DETECT_Y_END)) - - return LineDetectResult( - detected=True, - position_error=position_error, - heading=heading, - curvature=curvature, - poly_coeffs=coeffs, - binary_image=binary, - ) + return build_result(coeffs, binary) diff --git a/src/pc/vision/morphology.py b/src/pc/vision/morphology.py new file mode 100644 index 0000000..9b7edd3 --- /dev/null +++ b/src/pc/vision/morphology.py @@ -0,0 +1,134 @@ +""" +morphology +二値画像の形態学的処理ユーティリティモジュール +""" + +import cv2 +import numpy as np + + +def apply_iso_closing( + binary: np.ndarray, size: int, +) -> np.ndarray: + """等方クロージングで穴を埋める + + Args: + binary: 二値画像 + size: カーネルサイズ + + Returns: + クロージング後の二値画像 + """ + if size < 3: + return binary + k = size | 1 + kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (k, k), + ) + return cv2.morphologyEx( + binary, cv2.MORPH_CLOSE, kernel, + ) + + +def apply_staged_closing( + binary: np.ndarray, + small_size: int, + min_area: int, + large_size: int, +) -> np.ndarray: + """段階クロージング: 小穴埋め → 孤立除去 → 大穴埋め + + Args: + binary: 二値画像 + small_size: 第1段クロージングのカーネルサイズ + min_area: 孤立領域除去の最小面積(0 で無効) + large_size: 第2段クロージングのカーネルサイズ(0 で無効) + + Returns: + 処理後の二値画像 + """ + # 第1段: 小さいクロージングで近接ピクセルをつなぐ + result = apply_iso_closing(binary, small_size) + + # 孤立領域の除去 + if min_area > 0: + contours, _ = cv2.findContours( + result, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE, + ) + mask = np.zeros_like(result) + for cnt in contours: + if cv2.contourArea(cnt) >= min_area: + cv2.drawContours( + mask, [cnt], -1, 255, -1, + ) + result = mask + + # 第2段: 大きいクロージングで中抜けを埋める + result = apply_iso_closing(result, large_size) + + return result + + +def apply_width_filter( + binary: np.ndarray, + width_near: int, + width_far: int, + tolerance: float, +) -> np.ndarray: + """透視補正付き幅フィルタで広がりすぎた行を除外する + + 各行の期待線幅を線形補間で算出し, + 実際の幅が上限(期待幅 × tolerance)を超える行をマスクする + + Args: + binary: 二値画像 + width_near: 画像下端での期待線幅(px) + width_far: 画像上端での期待線幅(px) + tolerance: 上限倍率 + + Returns: + 幅フィルタ適用後の二値画像 + """ + result = binary.copy() + h = binary.shape[0] + denom = max(h - 1, 1) + + for y_local in range(h): + xs = np.where(binary[y_local] > 0)[0] + if len(xs) == 0: + continue + # 画像下端(近距離)ほど t=1,上端(遠距離)ほど t=0 + t = (h - 1 - y_local) / denom + expected = float(width_far) + ( + float(width_near) - float(width_far) + ) * t + max_w = expected * tolerance + actual_w = int(xs[-1]) - int(xs[0]) + 1 + if actual_w > max_w: + result[y_local] = 0 + + return result + + +def apply_dist_mask( + binary: np.ndarray, thresh: float, +) -> np.ndarray: + """距離変換で中心部のみを残す + + Args: + binary: 二値画像 + thresh: 距離の閾値(ピクセル) + + Returns: + 中心部のみの二値画像 + """ + if thresh <= 0: + return binary + dist = cv2.distanceTransform( + binary, cv2.DIST_L2, 5, + ) + _, mask = cv2.threshold( + dist, thresh, 255, cv2.THRESH_BINARY, + ) + return mask.astype(np.uint8) diff --git a/src/pc/vision/overlay.py b/src/pc/vision/overlay.py index 188a1e8..27436f2 100644 --- a/src/pc/vision/overlay.py +++ b/src/pc/vision/overlay.py @@ -9,15 +9,15 @@ import cv2 import numpy as np -from common import config -from pc.vision import line_detector +from pc.vision.fitting import theil_sen_fit from pc.vision.line_detector import LineDetectResult # 描画色の定義 (BGR) COLOR_LINE: tuple = (0, 255, 0) COLOR_CENTER: tuple = (0, 255, 255) -COLOR_TEXT: tuple = (255, 255, 255) COLOR_REGION: tuple = (255, 0, 0) +COLOR_ROW_CENTER: tuple = (0, 165, 255) +COLOR_THEIL_SEN: tuple = (255, 0, 255) # 二値化オーバーレイの不透明度 BINARY_OPACITY: float = 0.4 @@ -31,14 +31,16 @@ binary: 二値化画像の半透明表示 detect_region: 検出領域の枠 poly_curve: フィッティング曲線 + row_centers: 各行の線中心点 + theil_sen: Theil-Sen 近似直線 center_line: 画像中心線 - info_text: 検出情報の数値表示 """ binary: bool = False detect_region: bool = False poly_curve: bool = False + row_centers: bool = False + theil_sen: bool = False center_line: bool = False - info_text: bool = False def draw_overlay( @@ -57,6 +59,7 @@ オーバーレイ描画済みの画像 """ display = frame.copy() + h, w = display.shape[:2] if result is None: return display @@ -71,21 +74,18 @@ if flags.detect_region: cv2.rectangle( display, - (0, line_detector.DETECT_Y_START), - ( - config.FRAME_WIDTH - 1, - line_detector.DETECT_Y_END - 1, - ), + (0, 0), + (w - 1, h - 1), COLOR_REGION, 1, ) # 画像中心線 if flags.center_line: - center_x = config.FRAME_WIDTH // 2 + center_x = w // 2 cv2.line( display, (center_x, 0), - (center_x, config.FRAME_HEIGHT), + (center_x, h), COLOR_CENTER, 1, ) @@ -93,9 +93,15 @@ if flags.poly_curve and result.poly_coeffs is not None: _draw_poly_curve(display, result.poly_coeffs) - # 検出情報の数値表示 - if flags.info_text: - _draw_info_text(display, result) + # 各行の線中心点 + if flags.row_centers and result.row_centers is not None: + _draw_row_centers(display, result.row_centers) + + # Theil-Sen 近似直線 + if flags.theil_sen and result.row_centers is not None: + _draw_theil_sen_line( + display, result.row_centers, + ) return display @@ -122,6 +128,55 @@ ) +def _draw_row_centers( + frame: np.ndarray, + centers: np.ndarray, +) -> None: + """各行の線中心点を描画する + + Args: + frame: 描画先の画像 + centers: 各行の中心 x 座標(NaN=線なし) + """ + w = frame.shape[1] + for y, cx in enumerate(centers): + if np.isnan(cx): + continue + ix = int(round(cx)) + if 0 <= ix < w: + frame[y, ix] = COLOR_ROW_CENTER + + +def _draw_theil_sen_line( + frame: np.ndarray, + centers: np.ndarray, +) -> None: + """行中心点から Theil-Sen 近似直線を描画する + + Args: + frame: 描画先の画像 + centers: 各行の中心 x 座標(NaN=線なし) + """ + h, w = frame.shape[:2] + valid = ~np.isnan(centers) + ys = np.where(valid)[0].astype(float) + xs = centers[valid] + + if len(ys) < 2: + return + + slope, intercept = theil_sen_fit(ys, xs) + + # 直線の両端を計算して描画 + x0 = int(round(intercept)) + x1 = int(round(slope * (h - 1) + intercept)) + cv2.line( + frame, + (x0, 0), (x1, h - 1), + COLOR_THEIL_SEN, 1, + ) + + def _draw_poly_curve( frame: np.ndarray, coeffs: np.ndarray, @@ -132,55 +187,23 @@ frame: 描画先の画像 coeffs: 多項式の係数 """ + h, w = frame.shape[:2] poly = np.poly1d(coeffs) - y_start = line_detector.DETECT_Y_START - y_end = line_detector.DETECT_Y_END # 曲線上の点を生成 - ys = np.arange(y_start, y_end) + ys = np.arange(0, h) xs = poly(ys) # 画像範囲内の点のみ描画 points = [] for x, y in zip(xs, ys): ix = int(round(x)) - if 0 <= ix < config.FRAME_WIDTH: + if 0 <= ix < w: points.append([ix, int(y)]) if len(points) >= 2: pts = np.array(points, dtype=np.int32) cv2.polylines( frame, [pts], False, - COLOR_LINE, 2, - ) - - -def _draw_info_text( - frame: np.ndarray, - result: LineDetectResult, -) -> None: - """検出情報の数値を画像に描画する - - Args: - frame: 描画先の画像 - result: 線検出の結果 - """ - if not result.detected: - cv2.putText( - frame, "LINE: N/A", (5, 15), - cv2.FONT_HERSHEY_SIMPLEX, 0.4, - COLOR_TEXT, 1, - ) - return - - lines = [ - f"pos: {result.position_error:+.3f}", - f"head: {result.heading:+.4f}", - f"curv: {result.curvature:+.6f}", - ] - for i, text in enumerate(lines): - cv2.putText( - frame, text, (5, 15 + i * 15), - cv2.FONT_HERSHEY_SIMPLEX, 0.35, - COLOR_TEXT, 1, + COLOR_LINE, 1, ) diff --git a/src/pi/camera/capture.py b/src/pi/camera/capture.py index 712e7c4..9b89163 100644 --- a/src/pi/camera/capture.py +++ b/src/pi/camera/capture.py @@ -3,6 +3,7 @@ Picamera2 を使用してカメラ画像を取得するモジュール """ +import cv2 import numpy as np from picamera2 import Picamera2 @@ -20,8 +21,11 @@ self._camera = Picamera2() camera_config = self._camera.create_preview_configuration( main={ - "size": (config.FRAME_WIDTH, config.FRAME_HEIGHT), - "format": "Y8", + "size": ( + config.CAPTURE_WIDTH, + config.CAPTURE_HEIGHT, + ), + "format": "YUV420", }, ) self._camera.configure(camera_config) @@ -30,10 +34,21 @@ def capture(self) -> np.ndarray: """1フレームを取得する + 撮影後に INTER_AREA で縮小して返す + Returns: グレースケールの画像(NumPy 配列) """ - return self._camera.capture_array() + yuv = self._camera.capture_array() + gray = yuv[ + :config.CAPTURE_HEIGHT, + :config.CAPTURE_WIDTH, + ] + return cv2.resize( + gray, + (config.FRAME_WIDTH, config.FRAME_HEIGHT), + interpolation=cv2.INTER_AREA, + ) def stop(self) -> None: """カメラを停止する""" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5d06702 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,44 @@ +"""テスト共通フィクスチャ""" + +import numpy as np +import pytest + +from common import config + + +@pytest.fixture() +def straight_line_image() -> np.ndarray: + """中央に暗い縦線があるグレースケール画像""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + img = np.full((h, w), 200, dtype=np.uint8) + cx = w // 2 + img[:, cx - 1 : cx + 2] = 30 + return img + + +@pytest.fixture() +def blank_image() -> np.ndarray: + """線のない均一なグレースケール画像""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + return np.full((h, w), 180, dtype=np.uint8) + + +@pytest.fixture() +def binary_with_hole() -> np.ndarray: + """中央に穴がある二値画像(クロージングテスト用)""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + binary[3:h - 3, w // 2 - 2 : w // 2 + 3] = 255 + # 中央に 2px の穴をあける + binary[h // 2 : h // 2 + 2, :] = 0 + return binary + + +@pytest.fixture() +def binary_line() -> np.ndarray: + """中央に太い白線がある二値画像""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + cx = w // 2 + binary[:, cx - 3 : cx + 4] = 255 + return binary diff --git a/tests/test_fitting.py b/tests/test_fitting.py new file mode 100644 index 0000000..5791fbf --- /dev/null +++ b/tests/test_fitting.py @@ -0,0 +1,147 @@ +"""fitting モジュールのテスト""" + +import numpy as np +import pytest + +from pc.vision.fitting import ( + MIN_FIT_ROWS, + clean_and_fit, + ransac_polyfit, + theil_sen_fit, +) + + +class TestTheilSenFit: + """theil_sen_fit のテスト""" + + def test_linear_data(self) -> None: + """直線データから正しい slope と intercept を復元できる""" + y = np.arange(10, dtype=float) + # x = 2.0 * y + 5.0 + x = 2.0 * y + 5.0 + slope, intercept = theil_sen_fit(y, x) + assert slope == pytest.approx(2.0, abs=1e-6) + assert intercept == pytest.approx(5.0, abs=1e-6) + + def test_with_outlier(self) -> None: + """外れ値が1つあっても正しい傾きを推定できる""" + y = np.arange(11, dtype=float) + x = 1.0 * y + 3.0 + # 1点を大きく外す + x[5] = 100.0 + slope, intercept = theil_sen_fit(y, x) + assert slope == pytest.approx(1.0, abs=0.2) + assert intercept == pytest.approx(3.0, abs=1.0) + + def test_two_points(self) -> None: + """2点でも傾きを計算できる""" + y = np.array([0.0, 10.0]) + x = np.array([5.0, 15.0]) + slope, intercept = theil_sen_fit(y, x) + assert slope == pytest.approx(1.0, abs=1e-6) + assert intercept == pytest.approx(5.0, abs=1e-6) + + def test_single_point(self) -> None: + """1点しかない場合は slope=0, intercept=median(x)""" + y = np.array([5.0]) + x = np.array([10.0]) + slope, intercept = theil_sen_fit(y, x) + assert slope == 0.0 + assert intercept == pytest.approx(10.0) + + def test_horizontal_line(self) -> None: + """水平な線(slope=0)を正しく推定できる""" + y = np.arange(10, dtype=float) + x = np.full(10, 7.0) + slope, intercept = theil_sen_fit(y, x) + assert slope == pytest.approx(0.0, abs=1e-6) + assert intercept == pytest.approx(7.0, abs=1e-6) + + +class TestRansacPolyfit: + """ransac_polyfit のテスト""" + + def test_clean_quadratic(self) -> None: + """ノイズなしの2次曲線を正しくフィットできる""" + ys = np.arange(20, dtype=float) + # x = 0.1 * y^2 - 2.0 * y + 10.0 + xs = 0.1 * ys**2 - 2.0 * ys + 10.0 + coeffs = ransac_polyfit(ys, xs, 2, 50, 5.0) + assert coeffs is not None + assert coeffs[0] == pytest.approx(0.1, abs=0.01) + assert coeffs[1] == pytest.approx(-2.0, abs=0.1) + + def test_with_outliers(self) -> None: + """30% の外れ値があっても正しくフィットできる""" + rng = np.random.default_rng(42) + ys = np.arange(30, dtype=float) + xs = 0.05 * ys**2 + 3.0 + # 30% を大きく外す + outlier_idx = rng.choice(30, 9, replace=False) + xs[outlier_idx] += rng.uniform(50, 100, 9) + coeffs = ransac_polyfit(ys, xs, 2, 100, 5.0) + assert coeffs is not None + assert coeffs[0] == pytest.approx(0.05, abs=0.02) + + def test_too_few_points(self) -> None: + """点が不足している場合は None を返す""" + ys = np.array([1.0, 2.0]) + xs = np.array([3.0, 4.0]) + assert ransac_polyfit(ys, xs, 2, 50, 5.0) is None + + +class TestCleanAndFit: + """clean_and_fit のテスト""" + + def test_basic_fit(self) -> None: + """正常なデータでフィッティングできる""" + ys = np.arange(15, dtype=float) + xs = 0.5 * ys + 10.0 + coeffs = clean_and_fit( + ys, xs, + median_ksize=0, + neighbor_thresh=0.0, + ) + assert coeffs is not None + # 2次の係数はほぼ 0,1次はほぼ 0.5 + assert coeffs[-2] == pytest.approx(0.5, abs=0.1) + + def test_too_few_points(self) -> None: + """MIN_FIT_ROWS 未満のデータは None を返す""" + n = MIN_FIT_ROWS - 1 + ys = np.arange(n, dtype=float) + xs = np.arange(n, dtype=float) + assert clean_and_fit( + ys, xs, median_ksize=0, neighbor_thresh=0.0, + ) is None + + def test_neighbor_filter_removes_outlier(self) -> None: + """近傍フィルタが外れ値を除去できる""" + ys = np.arange(20, dtype=float) + xs = np.full(20, 15.0) + xs[10] = 100.0 # 大きな外れ値 + coeffs = clean_and_fit( + ys, xs, + median_ksize=0, + neighbor_thresh=5.0, + ) + assert coeffs is not None + # 外れ値除去後,x ≈ 15.0 の直線になる + poly = np.poly1d(coeffs) + assert poly(10) == pytest.approx(15.0, abs=2.0) + + def test_residual_removal(self) -> None: + """残差除去が外れ値を取り除ける""" + ys = np.arange(20, dtype=float) + xs = 1.0 * ys + 5.0 + xs[3] = 80.0 + xs[17] = -50.0 + coeffs = clean_and_fit( + ys, xs, + median_ksize=0, + neighbor_thresh=0.0, + residual_thresh=10.0, + ) + assert coeffs is not None + poly = np.poly1d(coeffs) + assert poly(10) == pytest.approx(15.0, abs=3.0) diff --git a/tests/test_json_utils.py b/tests/test_json_utils.py new file mode 100644 index 0000000..03020fc --- /dev/null +++ b/tests/test_json_utils.py @@ -0,0 +1,47 @@ +"""json_utils モジュールのテスト""" + +from pathlib import Path + +from common.json_utils import read_json, write_json + + +class TestWriteReadRoundtrip: + """write_json / read_json の往復テスト""" + + def test_dict_roundtrip(self, tmp_path: Path) -> None: + """dict を書き込んで読み込むと一致する""" + path = tmp_path / "test.json" + data = {"key": "value", "number": 42} + write_json(path, data) + assert read_json(path) == data + + def test_list_roundtrip(self, tmp_path: Path) -> None: + """list を書き込んで読み込むと一致する""" + path = tmp_path / "test.json" + data = [{"a": 1}, {"b": 2}] + write_json(path, data) + assert read_json(path) == data + + def test_creates_parent_dir( + self, tmp_path: Path, + ) -> None: + """親ディレクトリが存在しなくても自動作成される""" + path = tmp_path / "sub" / "dir" / "test.json" + data = {"created": True} + write_json(path, data) + assert path.exists() + assert read_json(path) == data + + def test_japanese_text(self, tmp_path: Path) -> None: + """日本語テキストが正しく保存・復元される""" + path = tmp_path / "test.json" + data = {"title": "テスト", "memo": "日本語メモ"} + write_json(path, data) + assert read_json(path) == data + + def test_overwrite(self, tmp_path: Path) -> None: + """既存ファイルを上書きできる""" + path = tmp_path / "test.json" + write_json(path, {"v": 1}) + write_json(path, {"v": 2}) + assert read_json(path) == {"v": 2} diff --git a/tests/test_line_detector.py b/tests/test_line_detector.py new file mode 100644 index 0000000..ce186c7 --- /dev/null +++ b/tests/test_line_detector.py @@ -0,0 +1,204 @@ +"""line_detector モジュールのテスト""" + +import numpy as np +import pytest + +from common import config +from pc.vision.line_detector import ( + ImageParams, + LineDetectResult, + build_result, + detect_line, + fit_row_centers, + no_detection, +) + + +class TestNoDetection: + """no_detection のテスト""" + + def test_returns_not_detected( + self, blank_image: np.ndarray, + ) -> None: + """detected=False で全フィールドがデフォルト値""" + result = no_detection(blank_image) + assert result.detected is False + assert result.position_error == 0.0 + assert result.heading == 0.0 + assert result.curvature == 0.0 + assert result.poly_coeffs is None + assert result.row_centers is None + assert result.binary_image is not None + + +class TestBuildResult: + """build_result のテスト""" + + def test_straight_center_line(self) -> None: + """画像中央の直線は position_error ≈ 0""" + h = config.FRAME_HEIGHT + w = config.FRAME_WIDTH + center_x = w / 2.0 + # x = center_x (定数) → coeffs = [0, 0, center_x] + coeffs = np.array([0.0, 0.0, center_x]) + binary = np.zeros((h, w), dtype=np.uint8) + result = build_result(coeffs, binary) + assert result.detected is True + assert result.position_error == pytest.approx( + 0.0, abs=0.01, + ) + assert result.heading == pytest.approx( + 0.0, abs=0.01, + ) + assert result.curvature == pytest.approx( + 0.0, abs=0.01, + ) + + def test_offset_line(self) -> None: + """左にオフセットした直線は position_error > 0""" + h = config.FRAME_HEIGHT + w = config.FRAME_WIDTH + # 左寄りの直線 + offset_x = w / 4.0 + coeffs = np.array([0.0, 0.0, offset_x]) + binary = np.zeros((h, w), dtype=np.uint8) + result = build_result(coeffs, binary) + assert result.position_error > 0 # 中心より左 + + +class TestDetectLine: + """detect_line のテスト""" + + def test_current_detects_straight_line( + self, straight_line_image: np.ndarray, + ) -> None: + """現行手法で中央の直線を検出できる""" + # 小さいテスト画像用にパラメータを調整 + params = ImageParams( + method="current", + clahe_grid=2, blur_size=3, + open_size=1, close_width=3, + close_height=1, + ) + result = detect_line( + straight_line_image, params, + ) + assert result.detected is True + assert abs(result.position_error) < 0.5 + + def test_current_no_line( + self, blank_image: np.ndarray, + ) -> None: + """均一画像では線を検出しない""" + params = ImageParams(method="current") + result = detect_line(blank_image, params) + assert result.detected is False + + def test_blackhat_detects_straight_line( + self, straight_line_image: np.ndarray, + ) -> None: + """案A で中央の直線を検出できる""" + params = ImageParams( + method="blackhat", + blackhat_ksize=15, + binary_thresh=30, + blur_size=3, + iso_close_size=1, + dist_thresh=0.0, + min_line_width=1, + median_ksize=0, + neighbor_thresh=0.0, + residual_thresh=0.0, + ) + result = detect_line( + straight_line_image, params, + ) + assert result.detected is True + + def test_dual_norm_detects_straight_line( + self, straight_line_image: np.ndarray, + ) -> None: + """案B で中央の直線を検出できる""" + params = ImageParams( + method="dual_norm", + bg_blur_ksize=21, + adaptive_block=11, adaptive_c=5, + iso_close_size=1, + dist_thresh=0.0, + min_line_width=1, + median_ksize=0, + neighbor_thresh=0.0, + residual_thresh=0.0, + ) + result = detect_line( + straight_line_image, params, + ) + assert result.detected is True + + def test_robust_detects_straight_line( + self, straight_line_image: np.ndarray, + ) -> None: + """案C で中央の直線を検出できる""" + params = ImageParams( + method="robust", + blackhat_ksize=15, + adaptive_block=11, adaptive_c=5, + iso_close_size=1, + dist_thresh=0.0, + min_line_width=1, + median_ksize=0, + neighbor_thresh=0.0, + residual_thresh=0.0, + ) + result = detect_line( + straight_line_image, params, + ) + assert result.detected is True + + def test_valley_detects_straight_line( + self, straight_line_image: np.ndarray, + ) -> None: + """案D で中央の直線を検出できる""" + params = ImageParams(method="valley") + result = detect_line( + straight_line_image, params, + ) + assert result.detected is True + + def test_default_params( + self, straight_line_image: np.ndarray, + ) -> None: + """params=None でもデフォルトで動作する""" + result = detect_line(straight_line_image) + assert isinstance(result, LineDetectResult) + + def test_result_has_binary_image( + self, straight_line_image: np.ndarray, + ) -> None: + """結果に二値化画像が含まれる""" + result = detect_line(straight_line_image) + assert result.binary_image is not None + assert result.binary_image.shape == ( + config.FRAME_HEIGHT, config.FRAME_WIDTH, + ) + + +class TestFitRowCenters: + """fit_row_centers のテスト""" + + def test_detects_binary_line( + self, binary_line: np.ndarray, + ) -> None: + """二値画像の白線から中心をフィッティングできる""" + result = fit_row_centers( + binary_line, min_width=1, + ) + assert result.detected is True + assert abs(result.position_error) < 0.3 + + def test_empty_binary(self) -> None: + """白ピクセルがない二値画像では検出しない""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + empty = np.zeros((h, w), dtype=np.uint8) + result = fit_row_centers(empty, min_width=1) + assert result.detected is False diff --git a/tests/test_morphology.py b/tests/test_morphology.py new file mode 100644 index 0000000..1cff783 --- /dev/null +++ b/tests/test_morphology.py @@ -0,0 +1,142 @@ +"""morphology モジュールのテスト""" + +import numpy as np +import pytest + +from common import config +from pc.vision.morphology import ( + apply_dist_mask, + apply_iso_closing, + apply_staged_closing, + apply_width_filter, +) + + +class TestApplyIsoClosing: + """apply_iso_closing のテスト""" + + def test_fills_small_hole( + self, binary_with_hole: np.ndarray, + ) -> None: + """小さい穴をクロージングで埋められる""" + result = apply_iso_closing(binary_with_hole, 5) + h = config.FRAME_HEIGHT + w = config.FRAME_WIDTH + # 穴があった中央付近が白になっている + assert result[h // 2, w // 2] == 255 + + def test_small_size_noop(self) -> None: + """size < 3 の場合は何もしない""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + binary[5, 5] = 255 + result = apply_iso_closing(binary, 1) + assert np.array_equal(result, binary) + + def test_preserves_shape( + self, binary_with_hole: np.ndarray, + ) -> None: + """出力画像のサイズが入力と同じ""" + result = apply_iso_closing(binary_with_hole, 7) + assert result.shape == binary_with_hole.shape + + +class TestApplyStagedClosing: + """apply_staged_closing のテスト""" + + def test_removes_small_regions(self) -> None: + """min_area で小さい領域を除去できる""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + # 大きな領域 + binary[2:h - 2, w // 2 - 3 : w // 2 + 4] = 255 + # 小さな孤立点 + binary[0, 0] = 255 + result = apply_staged_closing( + binary, small_size=3, min_area=5, + large_size=0, + ) + assert result[0, 0] == 0 # 孤立点が除去されている + assert result[h // 2, w // 2] == 255 # 大領域は残る + + def test_no_min_area(self) -> None: + """min_area=0 なら孤立除去をスキップする""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + binary[0, 0] = 255 + result = apply_staged_closing( + binary, small_size=1, min_area=0, + large_size=0, + ) + assert result[0, 0] == 255 + + +class TestApplyWidthFilter: + """apply_width_filter のテスト""" + + def test_keeps_narrow_line(self) -> None: + """期待幅以内の線は残す""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + cx = w // 2 + binary[:, cx - 1 : cx + 2] = 255 # 幅 3 + result = apply_width_filter( + binary, width_near=10, width_far=5, + tolerance=2.0, + ) + # 幅 3 は期待幅内なので残る + assert np.any(result[h // 2] > 0) + + def test_removes_wide_row(self) -> None: + """期待幅を超える行を除去する""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + # 下端に幅いっぱいの白(幅=w) + binary[h - 1, :] = 255 + result = apply_width_filter( + binary, width_near=3, width_far=2, + tolerance=1.5, + ) + # 幅 w >> 3*1.5 なので除去される + assert np.all(result[h - 1] == 0) + + def test_empty_rows_unchanged(self) -> None: + """白ピクセルがない行はそのまま""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + result = apply_width_filter( + binary, width_near=5, width_far=3, + tolerance=2.0, + ) + assert np.array_equal(result, binary) + + +class TestApplyDistMask: + """apply_dist_mask のテスト""" + + def test_keeps_center_of_thick_line(self) -> None: + """太い線の中心部を残す""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + cx = w // 2 + binary[:, cx - 5 : cx + 6] = 255 # 幅 11 + result = apply_dist_mask(binary, thresh=2.0) + # 中心は残る + assert result[h // 2, cx] > 0 + + def test_thresh_zero_noop(self) -> None: + """thresh <= 0 の場合は何もしない""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + binary[5, 5] = 255 + result = apply_dist_mask(binary, thresh=0.0) + assert np.array_equal(result, binary) + + def test_removes_thin_line(self) -> None: + """細い線(幅1px)は距離変換で除去される""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + binary = np.zeros((h, w), dtype=np.uint8) + binary[:, w // 2] = 255 # 幅 1px + result = apply_dist_mask(binary, thresh=1.0) + # 幅 1px の距離変換最大値は 1.0 未満 + assert np.all(result == 0) diff --git a/tests/test_params.py b/tests/test_params.py new file mode 100644 index 0000000..d7bf6fa --- /dev/null +++ b/tests/test_params.py @@ -0,0 +1,145 @@ +"""auto_params / param_store モジュールのテスト""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from pc.steering.pd_control import PdParams +from pc.vision.line_detector import ImageParams + + +class TestAutoParams: + """auto_params の保存・読み込みテスト""" + + @pytest.fixture(autouse=True) + def _use_tmp_dir(self, tmp_path: Path) -> None: + """PARAMS_DIR を一時ディレクトリに差し替える""" + with patch( + "pc.steering.auto_params.PARAMS_DIR", + tmp_path, + ): + with patch( + "pc.steering.auto_params._CONTROL_FILE", + tmp_path / "control.json", + ): + yield + + def test_save_load_control_roundtrip(self) -> None: + """PD パラメータと手法を保存→読込で復元できる""" + from pc.steering.auto_params import ( + load_control, + save_control, + ) + params = PdParams( + kp=1.0, kh=0.5, kd=0.2, + max_steer_rate=0.05, + max_throttle=0.6, speed_k=0.4, + ) + save_control(params, "blackhat") + loaded, method = load_control() + assert method == "blackhat" + assert loaded.kp == pytest.approx(1.0) + assert loaded.kh == pytest.approx(0.5) + assert loaded.max_throttle == pytest.approx(0.6) + + def test_load_control_missing_file(self) -> None: + """ファイルがない場合はデフォルト値を返す""" + from pc.steering.auto_params import load_control + params, method = load_control() + assert method == "current" + assert params.kp == PdParams().kp + + def test_save_load_detect_params( + self, tmp_path: Path, + ) -> None: + """検出パラメータを手法別に保存→復元できる""" + from pc.steering.auto_params import ( + load_detect_params, + save_detect_params, + ) + ip = ImageParams( + method="dual_norm", + adaptive_block=31, + adaptive_c=15, + ) + save_detect_params("dual_norm", ip) + loaded = load_detect_params("dual_norm") + assert loaded.method == "dual_norm" + assert loaded.adaptive_block == 31 + assert loaded.adaptive_c == 15 + + def test_load_detect_unknown_method(self) -> None: + """未知の手法は指定手法のデフォルト値を返す""" + from pc.steering.auto_params import ( + load_detect_params, + ) + loaded = load_detect_params("unknown") + assert loaded.method == "unknown" + + +class TestParamStore: + """param_store の保存・読み込みテスト""" + + @pytest.fixture(autouse=True) + def _use_tmp_dir(self, tmp_path: Path) -> None: + """プリセットファイルを一時ディレクトリに差し替える""" + with patch( + "pc.steering.param_store._PD_FILE", + tmp_path / "presets_pd.json", + ): + with patch( + "pc.steering.param_store._IMAGE_FILE", + tmp_path / "presets_image.json", + ): + yield + + def test_pd_preset_add_load_delete(self) -> None: + """PD プリセットの追加・読込・削除""" + from pc.steering.param_store import ( + PdPreset, + add_pd_preset, + delete_pd_preset, + load_pd_presets, + ) + # 追加 + preset = PdPreset( + title="テスト", + memo="メモ", + params=PdParams(kp=2.0), + ) + add_pd_preset(preset) + presets = load_pd_presets() + assert len(presets) == 1 + assert presets[0].title == "テスト" + assert presets[0].params.kp == pytest.approx(2.0) + + # 削除 + delete_pd_preset(0) + assert len(load_pd_presets()) == 0 + + def test_image_preset_add_load(self) -> None: + """画像処理プリセットの追加・読込""" + from pc.steering.param_store import ( + ImagePreset, + add_image_preset, + load_image_presets, + ) + ip = ImageParams( + method="blackhat", blackhat_ksize=51, + ) + add_image_preset(ImagePreset( + title="BH51", memo="テスト", image_params=ip, + )) + presets = load_image_presets() + assert len(presets) == 1 + assert presets[0].image_params.blackhat_ksize == 51 + + def test_load_empty(self) -> None: + """ファイルがない場合は空リストを返す""" + from pc.steering.param_store import ( + load_image_presets, + load_pd_presets, + ) + assert load_pd_presets() == [] + assert load_image_presets() == []