diff --git a/.gitignore b/.gitignore index b4966c7..1aad68b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ pd_params.json params/ +# 学習データ +/data/ + # 旧コード(参照用,Git 管理外) src_old/ diff --git a/CLAUDE.md b/CLAUDE.md index 9325a4a..57de90b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -22,9 +22,13 @@ - `docs/03_TECH/TECH_03_デバッグオーバーレイ仕様.txt` — オーバーレイ表示項目、描画色、GUI 操作 - `docs/03_TECH/TECH_04_線検出精度向上方針.txt` — 線検出が最重要ファクターである理由、照明・影の課題、改善の方向性 - `docs/03_TECH/TECH_05_コースアウト復帰仕様.txt` — 復帰判定ロジック、復帰動作、パラメータ一覧、GUI 仕様 +- `docs/03_TECH/TECH_06_十字路分類モデル評価.txt` — モデル比較結果、採用モデル、F1 スコア(学習実行時に自動生成) + +### テスト(テスト作成・実行時に参照) +- `docs/05_TEST/TEST_01_テスト方針.txt` — テスト方針、実行方法、構成、追加ルール ### 環境(セットアップ時に参照) -- `docs/04_ENV/ENV_01_技術スタック選定.txt` — ZMQ, PySide6, OpenCV, Picamera2, RPi.GPIO, python-dotenv +- `docs/04_ENV/ENV_01_技術スタック選定.txt` — ZMQ, PySide6, OpenCV, Picamera2, RPi.GPIO, python-dotenv, scikit-learn - `docs/04_ENV/ENV_02_PC環境構築手順.txt` — venv 作成、ライブラリインストール - `docs/04_ENV/ENV_03_RaspPi環境構築手順.txt` — SSH 接続、deploy.sh による転送、venv 構築、動作確認 - `docs/04_ENV/ENV_04_ディレクトリ構成.txt` — src/ の構成と実装状態 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..cc57693 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Rinto Hasegawa + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a779d61 --- /dev/null +++ b/README.md @@ -0,0 +1,164 @@ +# RobotCar — カメラ搭載ライントレース自律走行システム + +カメラ画像のみを用いて黒線コースを自律走行するロボットカーの制御システム. +Raspberry Pi がリアルタイム制御を担当し,PC が GUI モニタリング・パラメータ調整を担当する. + +## システム概要 + +``` +Raspberry Pi PC +──────────── ── +カメラ撮影 テレメトリ受信 + │ │ +画像処理・線検出 映像・状態表示 + │ │ +操舵量計算(3手法切替) パラメータ調整 + │ │ +モーター制御 コマンド送信 + │ (モード切替・手動操作) + └──── テレメトリ(ZMQ) ──────> │ + ┌──── コマンド受信 <────── │ +``` + +- **制御ループ**(カメラ取得 → 画像処理 → 操舵計算 → モーター制御)は Pi 内で完結し,通信遅延の影響を受けない +- **PC** はモニタリング GUI・パラメータ調整・手動操作のみを担当 + +## 主な機能 + +| 機能 | 概要 | +|------|------| +| 線検出 | 5 種類の検出手法を GUI から切替可能(CLAHE,Black-hat,二重正規化,ロバスト,谷検出) | +| 操舵制御 | PD 制御・2点パシュート制御・Theil-Sen PD 制御の 3 手法を切替可能 | +| 速度制御 | カーブの度合いに応じて動的に減速・加速 | +| 十字路判定 | SVM 分類器で十字路を検出し直進に切替 | +| コースアウト復帰 | 線を見失った場合に自動で復帰動作を実行 | +| パラメータ調整 | GUI スライダーでリアルタイムに変更・プリセット保存 | +| 手動操作 | キーボード入力による手動走行モード | + +## ハードウェア構成 + +- **メインコンピュータ**: Raspberry Pi +- **カメラ**: Raspberry Pi カメラモジュール(車体前方,真下向き) +- **モータードライバ**: TB6612FNG +- **駆動方式**: 差動 2 輪駆動(左右モーターの回転速度差で旋回) + +## 技術スタック + +| 用途 | ライブラリ | +|------|-----------| +| 通信 | ZeroMQ (pyzmq) — PUB/SUB + CONFLATE | +| GUI | PySide6 | +| 画像処理 | OpenCV | +| 数値計算 | NumPy | +| カメラ制御 | Picamera2 | +| モーター制御 | RPi.GPIO | +| 十字路分類 | scikit-learn | +| 環境変数 | python-dotenv | + +## ディレクトリ構成 + +``` +RobotCar/ +├── src/ +│ ├── common/ 共通設定・画像処理・操舵量計算 +│ │ ├── vision/ 線検出パイプライン・検出手法・十字路分類 +│ │ └── steering/ PD制御・パシュート制御・Theil-Sen PD制御・復帰制御 +│ ├── pc/ PC側(GUI・テレメトリ受信・コマンド送信) +│ └── pi/ Pi側(カメラ・モーター・自律制御ループ) +├── docs/ ドキュメント +├── tests/ ユニットテスト +├── params/ パラメータ・モデル +├── data/ 学習データ(十字路分類用) +├── deploy.sh Pi への転送スクリプト +├── requirements_pc.txt PC 用依存パッケージ +└── requirements_pi.txt Pi 用依存パッケージ +``` + +## セットアップ + +### PC + +```bash +python -m venv .venv +source .venv/bin/activate +pip install -r requirements_pc.txt +cp .env.example .env # IP アドレス・ポート番号を設定 +``` + +### Raspberry Pi + +```bash +# PC から Pi へファイルを転送 +./deploy.sh + +# Pi 上で仮想環境を構築 +python -m venv .venv +source .venv/bin/activate +pip install -r requirements_pi.txt +``` + +詳細は [docs/04_ENV/](docs/04_ENV/) を参照. + +## 実行方法 + +```bash +# Pi 側(自律制御ループ) +python -m src.pi.main + +# PC 側(GUI) +python -m src.pc.main +``` + +## ドキュメント + +詳細な仕様は `docs/` 以下を参照. + +### ガイドライン(01_GUIDE) + +| ファイル | 内容 | +|---------|------| +| [GUIDE_01_ドキュメント作成ガイド](docs/01_GUIDE/GUIDE_01_ドキュメント作成ガイド.txt) | 句読点は「,」「.」,見出し書式等 | +| [GUIDE_02_ドキュメント命名規則](docs/01_GUIDE/GUIDE_02_ドキュメント命名規則.txt) | `[カテゴリ]_[連番]_[ファイル名].txt` | +| [GUIDE_03_Git運用ルール](docs/01_GUIDE/GUIDE_03_Git運用ルール.txt) | ブランチ命名,コミットメッセージ規約 | +| [GUIDE_04_コーディング規則](docs/01_GUIDE/GUIDE_04_コーディング規則.txt) | PEP 8 ベース,snake_case,型ヒント必須 | +| [GUIDE_05_コードコメント規則](docs/01_GUIDE/GUIDE_05_コードコメント規則.txt) | docstring は日本語,Google スタイル | + +### プロジェクト計画(02_PLAN) + +| ファイル | 内容 | +|---------|------| +| [PLAN_01_プロジェクト概要](docs/02_PLAN/PLAN_01_プロジェクト概要.txt) | 目的・最終目標・ハードウェア構成・システム概要 | + +### 技術仕様(03_TECH) + +| ファイル | 内容 | +|---------|------| +| [TECH_01_操舵量計算仕様](docs/03_TECH/TECH_01_操舵量計算仕様.txt) | PD 制御・パシュート制御・Theil-Sen PD 制御,速度制御,レートリミッター | +| [TECH_02_システム構成仕様](docs/03_TECH/TECH_02_システム構成仕様.txt) | Pi/PC の役割分担,通信フロー,ZMQ プロトコル詳細 | +| [TECH_03_デバッグオーバーレイ仕様](docs/03_TECH/TECH_03_デバッグオーバーレイ仕様.txt) | オーバーレイ表示項目,描画色,GUI 操作 | +| [TECH_04_線検出精度向上方針](docs/03_TECH/TECH_04_線検出精度向上方針.txt) | 線検出の課題(照明・影),5 種類の検出手法の方針 | +| [TECH_05_コースアウト復帰仕様](docs/03_TECH/TECH_05_コースアウト復帰仕様.txt) | 復帰判定ロジック,復帰動作,パラメータ一覧 | +| [TECH_06_十字路分類モデル評価](docs/03_TECH/TECH_06_十字路分類モデル評価.txt) | モデル比較結果,採用モデル,F1 スコア(学習時に自動生成) | + +### 環境構築(04_ENV) + +| ファイル | 内容 | +|---------|------| +| [ENV_01_技術スタック選定](docs/04_ENV/ENV_01_技術スタック選定.txt) | ZMQ,PySide6,OpenCV,Picamera2 等の選定理由 | +| [ENV_02_PC環境構築手順](docs/04_ENV/ENV_02_PC環境構築手順.txt) | venv 作成,ライブラリインストール | +| [ENV_03_RaspPi環境構築手順](docs/04_ENV/ENV_03_RaspPi環境構築手順.txt) | SSH 接続,deploy.sh による転送,venv 構築 | +| [ENV_04_ディレクトリ構成](docs/04_ENV/ENV_04_ディレクトリ構成.txt) | src/ の詳細構成と各ファイルの役割 | + +### テスト(05_TEST) + +| ファイル | 内容 | +|---------|------| +| [TEST_01_テスト方針](docs/05_TEST/TEST_01_テスト方針.txt) | テスト方針,実行方法,構成,追加ルール | + +## 作成者 + +Rinto Hasegawa + +## ライセンス + +[MIT License](LICENSE) diff --git a/deploy.sh b/deploy.sh index ba1ed4d..4cb5aa2 100644 --- a/deploy.sh +++ b/deploy.sh @@ -14,7 +14,7 @@ # ── Pi 側の既存フォルダを削除 ───────────────────────────── echo "Pi 側のフォルダを初期化中..." -ssh "${PI_HOST}" "rm -rf ${PI_DIR}/common ${PI_DIR}/pi" +ssh "${PI_HOST}" "rm -rf ${PI_DIR}/common ${PI_DIR}/pi ${PI_DIR}/params" # ── ファイル転送 ────────────────────────────────────────── echo "common/ を転送中..." @@ -23,6 +23,10 @@ echo "pi/ を転送中..." scp -r "${SRC_DIR}/pi" "${PI_HOST}:${PI_DIR}/" +# ── モデルファイルの転送 ──────────────────────────────────── +echo "params/ を転送中..." +scp -r "${SCRIPT_DIR}/params" "${PI_HOST}:${PI_DIR}/" + # ── 設定ファイルの転送 ──────────────────────────────────── echo ".env を転送中..." scp "${SCRIPT_DIR}/.env" "${PI_HOST}:${PI_DIR}/.env" diff --git "a/docs/02_PLAN/PLAN_01_\343\203\227\343\203\255\343\202\270\343\202\247\343\202\257\343\203\210\346\246\202\350\246\201.txt" "b/docs/02_PLAN/PLAN_01_\343\203\227\343\203\255\343\202\270\343\202\247\343\202\257\343\203\210\346\246\202\350\246\201.txt" index d071520..0d2d747 100644 --- "a/docs/02_PLAN/PLAN_01_\343\203\227\343\203\255\343\202\270\343\202\247\343\202\257\343\203\210\346\246\202\350\246\201.txt" +++ "b/docs/02_PLAN/PLAN_01_\343\203\227\343\203\255\343\202\270\343\202\247\343\202\257\343\203\210\346\246\202\350\246\201.txt" @@ -59,16 +59,19 @@ 3-2. 操舵量計算の方針 - ・制御方式: PD 制御(ルールベース)を基本とする. ・偏差の算出: カメラ画像から黒線の位置を検出し, 画像中心からのずれを偏差として用いる. - ・操舵: 偏差に基づく PD 制御で操舵量を決定する. - ・速度: 操舵量が大きいほど減速し,直線では加速する. - ※ 詳細な計算方法・パラメータは別ドキュメントにて定義する. + ・操舵: 偏差に基づく制御で操舵量を決定する. + 複数の制御手法を GUI から切り替えて使用できる. + ・速度: カーブの度合いに応じて動的に調整する(カーブ減速,直線加速). + ※ 制御手法・計算方法・パラメータの詳細は + `TECH_01_操舵量計算仕様.txt` を参照する. 3-3. 実行環境 - ・Raspberry Pi を使用する. - ・Raspberry Pi 単体で完結させるか,PC と連携させるかは未定. - ※ 現状は Raspberry Pi をメインの実行環境として想定する. + ・制御ループ(カメラ取得→画像処理→操舵計算→モーター制御)は + Pi 内で完結させ,通信遅延の影響を排除する. + ・PC はモニタリング GUI・パラメータ調整・手動操作を担当する. + ※ Pi と PC の役割分担・通信の詳細は + `TECH_02_システム構成仕様.txt` を参照する. diff --git "a/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" "b/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" index b1730a3..505df76 100644 --- "a/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" +++ "b/docs/03_TECH/TECH_01_\346\223\215\350\210\265\351\207\217\350\250\210\347\256\227\344\273\225\346\247\230.txt" @@ -176,10 +176,9 @@ GUI のコンボボックスで検出手法を切り替えられる. 手法ごとに使用するパラメータが異なり, GUI では選択中の手法に関連するパラメータのみ表示される. - 詳細は `TECH_04_線検出精度向上方針.txt` を参照する. 1. 現行手法(CLAHE + 固定閾値) - ・画像解像度: 320x240 + ・撮影解像度: 320×240(処理時は 40×30 に縮小) ・CLAHE 強度(clahe_clip): 2.0 ・CLAHE 分割数(clahe_grid): 8 ・ブラーカーネルサイズ(blur_size): 5 @@ -188,30 +187,9 @@ ・クロージング横幅(close_width): 25 ・クロージング高さ(close_height): 3 - 2. 案A: Black-hat 中心型 - ・Black-hat カーネルサイズ(blackhat_ksize): 45 - ・二値化閾値(binary_thresh): 80 - ・等方クロージングサイズ(iso_close_size): 15 - ・距離変換閾値(dist_thresh): 3.0 - ・最小線幅(min_line_width): 3 - - 3. 案B: 二重正規化型 - ・背景ブラーカーネルサイズ(bg_blur_ksize): 101 - ・固定閾値(global_thresh): 0(0 で無効,適応的閾値との AND) - ・適応的閾値ブロックサイズ(adaptive_block): 51 - ・適応的閾値定数 C(adaptive_c): 10 - ・等方クロージングサイズ(iso_close_size): 15 - ・距離変換閾値(dist_thresh): 3.0 - ・最小線幅(min_line_width): 3 - - 4. 案C: 最高ロバスト型 - ・Black-hat カーネルサイズ(blackhat_ksize): 45 - ・適応的閾値ブロックサイズ(adaptive_block): 51 - ・適応的閾値定数 C(adaptive_c): 10 - ・等方クロージングサイズ(iso_close_size): 15 - ・距離変換閾値(dist_thresh): 3.0 - ・最小線幅(min_line_width): 3 - ・RANSAC 閾値(ransac_thresh): 5.0 + 2. 案A〜D および共通後処理パラメータ + 各手法のパイプライン・パラメータ・デフォルト値は + `TECH_04_線検出精度向上方針.txt` を参照する. ■ PD 制御パラメータ(GUI で調整可能) @@ -230,7 +208,8 @@ ・全パラメータ(画像処理 + PD 制御 + 速度制御)を タイトル・メモ付きで JSON ファイルに保存できる. ・GUI のコンボボックスで保存済みパラメータを選択・読み込み可能. - ・保存ファイル: pd_params.json(.gitignore に登録済み) + ・保存ファイル: params/ ディレクトリ配下の各 JSON ファイル + (control.json,detect_*.json,pursuit.json,ts_pd.json 等) 7. 2点パシュート制御 (Two-Point Pursuit Control) @@ -278,7 +257,8 @@ 7-5. 実装ファイル - ・src/pc/steering/pursuit_control.py: PursuitControl クラス + ・src/common/steering/pursuit_control.py: PursuitControl クラス + ・src/common/steering/base.py: SteeringBase(線検出・レートリミッター・reset の共通処理) ・src/pc/gui/main_window.py: 制御手法の切替 UI @@ -336,29 +316,16 @@ 8-6. 実装ファイル - ・src/pc/steering/ts_pd_control.py: TsPdControl クラス + ・src/common/steering/ts_pd_control.py: TsPdControl クラス + ・src/common/steering/base.py: SteeringBase(線検出・レートリミッター・reset の共通処理) ・src/pc/gui/main_window.py: 制御手法の切替 UI 9. コースアウト復帰 (Course-Out Recovery) ------------------------------------------------------------------------ - 9-1. 概要 - 自動操縦中に黒線を一定時間検出できなかった場合に, 最後に検出した方向へ旋回しながら走行して復帰を試みる機能. 全制御手法に共通で適用される. + 判定ロジック・復帰動作・パラメータ・GUI 仕様・実装ファイルの 詳細は `TECH_05_コースアウト復帰仕様.txt` を参照する. - - 9-2. パラメータ一覧(GUI で調整可能) - - ・enabled(有効/無効): True - ・timeout_sec(判定時間): 0.5 秒 - ・steer_amount(操舵量): 0.5 - ・throttle(速度): -0.3(負で後退,正で前進) - - 9-3. 実装ファイル - - ・src/pc/steering/recovery.py: RecoveryParams,RecoveryController - ・src/pc/gui/panels/recovery_panel.py: RecoveryPanel - ・src/pc/gui/main_window.py: 復帰ロジックの統合 diff --git "a/docs/03_TECH/TECH_02_\343\202\267\343\202\271\343\203\206\343\203\240\346\247\213\346\210\220\344\273\225\346\247\230.txt" "b/docs/03_TECH/TECH_02_\343\202\267\343\202\271\343\203\206\343\203\240\346\247\213\346\210\220\344\273\225\346\247\230.txt" index e00c748..4e477d1 100644 --- "a/docs/03_TECH/TECH_02_\343\202\267\343\202\271\343\203\206\343\203\240\346\247\213\346\210\220\344\273\225\346\247\230.txt" +++ "b/docs/03_TECH/TECH_02_\343\202\267\343\202\271\343\203\206\343\203\240\346\247\213\346\210\220\344\273\225\346\247\230.txt" @@ -14,21 +14,22 @@ 1-1. 全体構成 - Raspberry Pi はカメラ画像の取得とモーター制御を担当し, - PC は画像処理・操舵量計算・GUI 表示を担当する. + Raspberry Pi はカメラ画像の取得・画像処理・操舵量計算・モーター制御を + すべて担当し,PC はモニタリング GUI・パラメータ調整・手動操作を担当する. + 制御ループが Pi 内で完結するため,通信遅延の影響を受けない. - Raspberry Pi PC - ──────────── ── - カメラ撮影 画像受信 + Raspberry Pi PC + ──────────── ── + カメラ撮影 テレメトリ受信 │ │ - └──── 画像送信 ──────> │ - 画像処理・線検出 - │ - 操舵量計算(PD 制御) - │ - ┌──── 操舵量受信 <────── 操舵量送信 - │ - モーター制御 + 画像処理・線検出 映像・状態表示 + │ │ + 操舵量計算(3手法切替) パラメータ調整 + │ │ + モーター制御 コマンド送信 + │ (モード切替・手動操作) + └──── テレメトリ ──────> │ + ┌──── コマンド受信 <────── │ 2. Raspberry Pi 側の処理 (Raspberry Pi Processing) @@ -37,63 +38,71 @@ 2-1. カメラ画像の取得 ・Picamera2 を使用してグレースケール(Y8 フォーマット)でフレームを取得する. - ・取得した画像を JPEG 圧縮して PC に送信する. ※ グレースケールで取得することで,転送データ量を BGR 比で 1/3 に削減する. - 2-2. 操舵量の受信 + 2-2. 画像処理・線検出 - ・PC から送信された操舵量(throttle,steer)を受信する. + ・取得した画像を 40×30 に縮小し,黒線の位置を検出する. + ・5 種類の検出手法を GUI から切り替えて使用できる: + - 現行(CLAHE + 固定閾値) + - 案A(Black-hat 中心) + - 案B(二重正規化) + - 案C(最高ロバスト) + - 案D(谷検出+追跡) + ・各手法の処理パイプラインは異なるが, + 出力は共通のデータ構造(位置偏差・傾き・二値画像等)に統一される. + ・詳細は `TECH_01_操舵量計算仕様.txt` および + `TECH_04_線検出精度向上方針.txt` を参照する. - 2-3. モーター制御 + 2-3. 操舵量計算 - ・受信した throttle,steer を既存の `MotorDriver.set_drive()` に渡し, + ・3 種類の制御手法を GUI から切り替えて使用できる: + - PD 制御: 位置偏差・傾き・微分項で操舵,曲率で速度を調整 + - 2点パシュート制御: 近い点と遠い点の偏差で操舵,曲がり度合いで速度を調整 + - Theil-Sen PD 制御: Theil-Sen 直線近似と PD 制御のハイブリッド, + 傾きで速度を調整 + ・レートリミッターで急激な操舵変化を抑制する. + ・十字路判定: SVM 分類器で十字路を検出し,直進に切り替える. + ・コースアウト復帰: 一定時間線を検出できない場合に復帰動作を行う. + ・詳細は `TECH_01_操舵量計算仕様.txt` を参照する. + + 2-4. モーター制御 + + ・計算した throttle,steer を `MotorDriver.set_drive()` に渡し, 左右モーターを制御する. - ・差動2輪駆動の計算は既存コードを流用する. + ・差動2輪駆動の計算: - left = throttle + steer - right = throttle - steer ・極性補正・PWM 出力も既存コードに従う. - 2-4. フェイルセーフ + 2-5. テレメトリ送信 - ・一定時間(例: 0.5秒)操舵量を受信しなかった場合, - モーターを自動停止する. - ・通信切断時の暴走を防止する. + ・毎フレーム,以下のテレメトリを PC に送信する: + - カメラ画像(JPEG 圧縮) + - 操舵量(throttle,steer) + - 検出結果(検出成否,位置偏差,傾き) + - 十字路判定結果,復帰状態,処理 FPS + - 二値画像(デバッグ用) + ・ZMQ PUB/SUB + CONFLATE で,古いフレームを自動破棄する. + + 2-6. コマンド受信 + + ・PC からのコマンドを非ブロッキングで受信する: + - mode: "auto"(自律走行),"manual"(手動操作),"stop"(停止) + - 手動モード時の throttle,steer + - パラメータ更新(画像処理・操舵・復帰パラメータ) + - 十字路判定の有効/無効 + ・PC が切断しても,auto モードなら自律走行を継続する. 3. PC 側の処理 (PC Processing) ------------------------------------------------------------------------ - 3-1. 画像の受信 + 3-1. テレメトリの受信 - ・Raspberry Pi から送信されたカメラ画像を受信する. + ・Pi から送信されたテレメトリ(画像+検出結果+操舵量)を受信する. - 3-2. 画像処理・線検出 - - ・受信した画像から黒線の位置を検出する. - ・処理手順: - 1. CLAHE によるコントラスト強調 - 2. ガウシアンブラー(ノイズ除去) - 3. 固定閾値で二値化 - 4. オープニングで孤立ノイズ除去 - 5. 横方向クロージングで途切れ補間 - 6. 白ピクセルに2次多項式フィッティング - ※ グレースケール変換は Pi 側(撮影時)で完了しているため不要 - ・位置偏差・傾き・曲率を算出する. - ・詳細は `TECH_01_操舵量計算仕様.txt` を参照する. - - 3-3. 操舵量計算 - - ・多項式フィッティングから得た位置偏差と傾きで - PD 制御により操舵量を計算する. - ・速度は曲率に応じて動的に調整する. - ・レートリミッターで急激な操舵変化を抑制する. - ・詳細は `TECH_01_操舵量計算仕様.txt` を参照する. - - 3-4. 操舵量の送信 - - ・計算した throttle,steer を Raspberry Pi に送信する. - - 3-5. GUI 表示 + 3-2. GUI 表示 ■ カメラ映像表示 ・受信した画像をリアルタイムで表示する. @@ -102,15 +111,15 @@ ■ 自動操縦の切り替え ・自動操縦の ON / OFF を切り替えるボタンを設ける. - ・ON: 画像処理の結果に基づいて操舵量を自動計算・送信する. - ・OFF: 手動操作モードに切り替わる. + ・ON: Pi に "auto" コマンドを送信し,自律走行を開始する. + ・OFF: Pi に "manual" コマンドを送信し,手動操作に切り替える. ■ パラメータ調整 ・PD 制御パラメータ(Kp,Kh,Kd 等)をリアルタイムに変更できる UI を設ける. ・二値化パラメータ(二値化閾値,CLAHE 強度等)も リアルタイムに変更できる. - ・変更したパラメータは即座に処理に反映される. + ・変更したパラメータはコマンドとして Pi に送信し,即座に反映される. ■ パラメータ保存・読み込み ・調整したパラメータをタイトル・メモ付きで JSON に保存できる. @@ -119,7 +128,7 @@ ■ 手動操作 ・自動操縦 OFF 時に,ユーザーが手動で車体を操作できる. - ・操作方式は操作性を重視して設計する. + ・キー入力を throttle,steer に変換し,コマンドとして Pi に送信する. 4. 設計方針 (Design Policy) @@ -138,26 +147,70 @@ この入出力を維持する限り,計算の中身(偏差の取り方,制御式, 速度調整の方法等)を自由に変更できる. + 4-2. 制御ループの完結性 + + 画像取得からモーター制御までの制御ループを Pi 内で完結させ, + 通信遅延の影響を排除する.PC はモニタリングとパラメータ調整のみを + 担当し,制御のクリティカルパスには含まれない. + 5. 通信の流れ (Communication Flow) ------------------------------------------------------------------------ 5-1. 全体のループ - 以下のサイクルを毎フレーム繰り返す. - + Pi 側の制御ループ(毎フレーム): 1. Pi: カメラでフレームを取得する. - 2. Pi → PC: 画像を送信する. - 3. PC: 画像処理・線検出を行う. - 4. PC: 操舵量を計算する(自動時)またはユーザー入力を取得する(手動時). - 5. PC → Pi: 操舵量(throttle,steer)を送信する. - 6. Pi: 受信した操舵量でモーターを制御する. + 2. Pi: 画像処理・線検出を行う. + 3. Pi: 操舵量を計算する(自動時). + 4. Pi: モーターを制御する. + 5. Pi → PC: テレメトリを送信する. + 6. Pi: PC からのコマンドを確認する. - 5-2. 通信要件 + PC 側のループ(タイマー駆動): + 1. PC: テレメトリを受信して映像・状態を表示する. + 2. PC → Pi: コマンドを送信する(モード切替・手動操作・パラメータ更新). - ・双方向通信: Pi → PC(画像),PC → Pi(操舵量). - ・低遅延: 画像取得からモーター反映までの遅延を最小限にする. - ※ 遅延が大きいとコースアウトのリスクが増加する. - ・信頼性: パケットロス時の振る舞いを定義する. - - 画像が届かない場合: 前フレームの操舵量を維持する. - - 操舵量が届かない場合: フェイルセーフでモーター停止する. + 5-2. 通信プロトコル + + ■ テレメトリ(Pi → PC,ZMQ PUB/SUB) + ・メッセージ形式: + - 4 バイト: JSON ヘッダ長(uint32 LE) + - N バイト: JSON テレメトリ(操舵量・検出結果・状態) + - 4 バイト: カメラ画像長(uint32 LE) + - M バイト: JPEG 圧縮カメラ画像 + - 残り: JPEG 圧縮二値画像(デバッグ用,省略可) + ・JSON テレメトリのフィールド: + - v: プロトコルバージョン(config.TELEMETRY_VERSION と一致必須) + - ts: 送信時刻(Unix タイムスタンプ) + - throttle, steer: 現在の操舵量 + - detected: 線検出の成否 + - pos_error: 位置偏差 + - heading: 線の傾き + - is_intersection: 十字路判定結果 + - is_recovering: 復帰動作中か + - fps: Pi 側の処理 FPS + - intersection_available: 十字路分類器の読み込み済み状態 + - compute_ms: 操舵計算の平均処理時間(ミリ秒) + ・PC 側はバージョン不一致のメッセージを破棄する. + + ■ コマンド(PC → Pi,ZMQ PUB/SUB) + ・JSON 形式: + - mode: "auto" | "manual" | "stop" + - throttle, steer: 手動モード時の操舵量 + - steering_method: "pd" | "pursuit" | "ts_pd" + - image_params: 画像処理パラメータの辞書 + - pd_params: PD 制御パラメータの辞書 + - pursuit_params: Pursuit 制御パラメータの辞書 + - steering_params: Theil-Sen PD 制御パラメータの辞書 + - recovery_params: 復帰パラメータの辞書 + - intersection_enabled: 十字路判定の有効/無効 + - intersection_throttle: 十字路通過時の速度 + + 5-3. 通信要件 + + ・双方向通信: Pi → PC(テレメトリ),PC → Pi(コマンド). + ・テレメトリ: CONFLATE により最新フレームのみ保持(古いデータを自動破棄). + ・コマンド: 非ブロッキング受信で制御ループを阻害しない. + ・PC 切断時: Pi は最後のモードで動作を継続する. + auto モードなら自律走行を続け,stop モードなら停止を維持する. diff --git "a/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" "b/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" index 23d8b3e..4259891 100644 --- "a/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" +++ "b/docs/03_TECH/TECH_03_\343\203\207\343\203\220\343\203\203\343\202\260\343\202\252\343\203\274\343\203\220\343\203\274\343\203\254\343\202\244\344\273\225\346\247\230.txt" @@ -15,8 +15,8 @@ 1-1. 基本方針 ・オーバーレイはカメラ映像に重ねて描画する. - ・線検出は接続中は常に実行し,検出情報は映像下のラベルに表示する. - ・自動操縦中は操舵量計算で実行済みの検出結果を再利用する. + ・線検出は Pi 側で実行し,結果をテレメトリで PC に送信する. + ・検出情報は映像下のラベルに表示する. 2. 表示項目 (Overlay Items) @@ -33,6 +33,8 @@ ・パシュート目標点: 2点パシュートの near/far 目標点をシアンの 円(半径 2px)で描画.制御手法がパシュートのとき有効 (手動操作中は検出結果からプレビュー表示) + ・十字路判定枠: 十字路と判定された場合に画像全体の枠を赤色で + 描画する(常時有効,チェックボックスによる切替なし) 2-2. 検出情報ラベル(常時表示) @@ -47,6 +49,7 @@ ・行中心点: (0, 165, 255) オレンジ ・Theil-Sen 直線: (255, 0, 255) マゼンタ ・パシュート目標点: (255, 255, 0) シアン + ・十字路判定枠: (0, 0, 255) 赤 ・二値化オーバーレイ: 赤チャンネルに二値化画像を割り当て @@ -61,6 +64,6 @@ 3-2. 動作モードとの関係 - ・手動操作中: 線検出を常に実行し,検出情報ラベルを更新する - ・自動操縦中: 操舵量計算の線検出結果をそのまま使用する + ・手動操作中・自動操縦中: Pi からのテレメトリに含まれる + 検出結果・二値画像を使用してオーバーレイを描画する ・未接続時: オーバーレイは表示されない(映像がないため) diff --git "a/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" "b/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" index f715361..a8dd2ce 100644 --- "a/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" +++ "b/docs/03_TECH/TECH_04_\347\267\232\346\244\234\345\207\272\347\262\276\345\272\246\345\220\221\344\270\212\346\226\271\351\207\235.txt" @@ -612,7 +612,7 @@ (デフォルト: 7,0 で無効) - neighbor_thresh: 近傍除去の閾値 px (デフォルト: 10.0,0 で無効) - ・実装: _clean_and_fit() 関数(line_detector.py) + ・実装: clean_and_fit() 関数(fitting.py) ・備考: RANSAC と併用可能.RANSAC が有効な場合は メディアン → 近傍除去 → RANSAC の順に適用される @@ -886,13 +886,18 @@ 11-5. 実装ファイル - ・src/pc/vision/line_detector.py + ・src/pc/vision/line_detector.py(PC 側) + ・src/pi/vision/line_detector.py(Pi 側) - ValleyTracker クラス: 時系列追跡の状態管理 - _find_row_valley(): 1行の谷検出 - _detect_valley(): 案Dのメイン処理 - _build_valley_binary(): デバッグ用二値画像生成 - reset_valley_tracker(): 追跡状態のリセット - ・src/pc/gui/main_window.py + ・src/pc/vision/detectors/valley.py(PC 側) + ・src/pi/vision/detectors/valley.py(Pi 側) + - 谷検出の実装 + ・src/pc/gui/panels/image_param_panel.py - 案D用パラメータの GUI コントロール - ・src/pc/steering/pd_control.py + ・src/pc/steering/pd_control.py(PC 側) + ・src/pi/steering/pd_control.py(Pi 側) - reset() 時に追跡状態もリセット diff --git "a/docs/03_TECH/TECH_05_\343\202\263\343\203\274\343\202\271\343\202\242\343\202\246\343\203\210\345\276\251\345\270\260\344\273\225\346\247\230.txt" "b/docs/03_TECH/TECH_05_\343\202\263\343\203\274\343\202\271\343\202\242\343\202\246\343\203\210\345\276\251\345\270\260\344\273\225\346\247\230.txt" index b7c3e89..6b9902b 100644 --- "a/docs/03_TECH/TECH_05_\343\202\263\343\203\274\343\202\271\343\202\242\343\202\246\343\203\210\345\276\251\345\270\260\344\273\225\346\247\230.txt" +++ "b/docs/03_TECH/TECH_05_\343\202\263\343\203\274\343\202\271\343\202\242\343\202\246\343\203\210\345\276\251\345\270\260\344\273\225\346\247\230.txt" @@ -15,9 +15,10 @@ 1-1. 基本方針 ・全制御手法(PD / 2点パシュート / Theil-Sen PD)に共通で適用する - ・制御手法の外側(MainWindow)で復帰判定と操舵量の上書きを行う + ・Pi 側のメインループ(pi/main.py)で復帰判定と操舵量の上書きを行う ・復帰パラメータは GUI の折りたたみパネルでリアルタイムに調整可能 ・パラメータは JSON ファイルに自動保存・復元される + ・PC から復帰パラメータを更新すると Pi に送信され,即座に反映される 2. 復帰判定 (Recovery Trigger) @@ -64,13 +65,13 @@ 用途に応じて GUI で調整する. - 3-3. 処理フロー + 3-3. 処理フロー(Pi 側メインループ) 1. 各制御手法の compute() を通常通り呼び出す 2. 検出結果を RecoveryController.update() に渡す 3. update() が SteeringOutput を返した場合, 制御手法の出力を復帰用の出力で上書きする - 4. 上書きされた操舵量を Pi に送信する + 4. 上書きされた操舵量でモーターを制御する 4. パラメータ一覧 (Parameters) @@ -116,7 +117,8 @@ 6. 実装ファイル (Implementation Files) ------------------------------------------------------------------------ - ・src/pc/steering/recovery.py: RecoveryParams,RecoveryController + ・src/common/steering/recovery.py: RecoveryParams,RecoveryController + ・src/pi/main.py: Pi 側メインループでの復帰ロジック統合 ・src/pc/gui/panels/recovery_panel.py: RecoveryPanel - ・src/pc/gui/main_window.py: 復帰ロジックの統合 + ・src/pc/gui/main_window.py: 復帰パラメータ管理・Pi への送信 ・src/pc/steering/auto_params.py: save_recovery / load_recovery diff --git "a/docs/03_TECH/TECH_06_\345\215\201\345\255\227\350\267\257\345\210\206\351\241\236\343\203\242\343\203\207\343\203\253\350\251\225\344\276\241.txt" "b/docs/03_TECH/TECH_06_\345\215\201\345\255\227\350\267\257\345\210\206\351\241\236\343\203\242\343\203\207\343\203\253\350\251\225\344\276\241.txt" new file mode 100644 index 0000000..19fc01e --- /dev/null +++ "b/docs/03_TECH/TECH_06_\345\215\201\345\255\227\350\267\257\345\210\206\351\241\236\343\203\242\343\203\207\343\203\253\350\251\225\344\276\241.txt" @@ -0,0 +1,76 @@ +======================================================================== +十字路分類モデル評価 (Intersection Classifier Evaluation) +======================================================================== + + +1. 概要 (Overview) +------------------------------------------------------------------------ + + 1-0. 目的 + + 十字路(intersection)と通常区間(normal)を分類する + 二値画像分類モデルの比較評価結果を記録する. + + 1-1. 評価日時 + + ・実施日時: 2026-03-26 00:17 + + 1-2. データセット + + ・入力: 40×30 二値画像(1200 特徴量,0.0/1.0) + ・全サンプル数: 633 + - intersection: 127 + - normal: 506 + ・クラス比率: intersection:normal = 127:506 + + +2. 評価方法 (Evaluation Method) +------------------------------------------------------------------------ + + 2-1. 交差検証 + + ・手法: Stratified 5-Fold Cross-Validation + ・指標: マクロ平均 F1 スコア + ・前処理: StandardScaler(fold ごとに fit) + ・乱数シード: 42 + + +3. 評価結果 (Results) +------------------------------------------------------------------------ + + 3-1. モデル比較(F1 スコア降順) + + モデル F1(平均) F1(標準偏差) + ─────────────────────────────────────────────────────── + SVM_RBF 0.9745 0.0186 ← 採用 + MLP_2layer 0.9665 0.0272 + MLP_1layer 0.9609 0.0307 + RandomForest 0.9356 0.0400 + LogisticRegression 0.9096 0.0197 + SVM_Linear 0.9067 0.0223 + + 3-2. 各 Fold の F1 スコア + + ・SVM_RBF: [0.9881, 0.9632, 1.0000, 0.9743, 0.9468] + ・MLP_2layer: [0.9758, 0.9368, 1.0000, 0.9873, 0.9324] + ・MLP_1layer: [0.9632, 0.9368, 1.0000, 0.9873, 0.9174] + ・RandomForest: [0.9758, 0.9486, 0.8740, 0.9743, 0.9053] + ・LogisticRegression: [0.8972, 0.8938, 0.9325, 0.9346, 0.8899] + ・SVM_Linear: [0.8825, 0.8938, 0.9325, 0.9346, 0.8899] + + 3-3. 採用モデル + + ・モデル: SVM_RBF + ・保存先: params/intersection_model.pkl + ・スケーラ: params/intersection_scaler.pkl + + 3-4. 全データでの分類レポート(再学習後) + + precision recall f1-score support + + normal 1.00 1.00 1.00 506 + intersection 1.00 0.99 1.00 127 + + accuracy 1.00 633 + macro avg 1.00 1.00 1.00 633 + weighted avg 1.00 1.00 1.00 633 diff --git "a/docs/04_ENV/ENV_01_\346\212\200\350\241\223\343\202\271\343\202\277\343\203\203\343\202\257\351\201\270\345\256\232.txt" "b/docs/04_ENV/ENV_01_\346\212\200\350\241\223\343\202\271\343\202\277\343\203\203\343\202\257\351\201\270\345\256\232.txt" index 924416d..74e4f3b 100644 --- "a/docs/04_ENV/ENV_01_\346\212\200\350\241\223\343\202\271\343\202\277\343\203\203\343\202\257\351\201\270\345\256\232.txt" +++ "b/docs/04_ENV/ENV_01_\346\212\200\350\241\223\343\202\271\343\202\277\343\203\203\343\202\257\351\201\270\345\256\232.txt" @@ -15,7 +15,7 @@ 1-1. 全体方針 ・言語: Python で統一する(Pi 側・PC 側ともに). - ・既存資産: モーター制御コード(`src_old/pi/motor.py`)を参考にする. + ・既存資産: モーター制御コード(`src/pi/motor/driver.py`)を参考にする. ・選定基準: 低遅延・軽量・Python との親和性を重視する. @@ -28,8 +28,8 @@ ・通信パターン: PUB/SUB ・オプション: ZMQ_CONFLATE=1(受信側で最新メッセージのみ保持) ・用途: - - Pi → PC: カメラ画像の送信 - - PC → Pi: 操舵量(throttle,steer)の送信 + - Pi → PC: テレメトリ(カメラ画像・検出結果・操舵量)の送信 + - PC → Pi: コマンド(モード切替・パラメータ更新・手動操作)の送信 ■ 選定理由 ・ブローカー不要で軽量,低遅延に適している. @@ -63,18 +63,29 @@ 2-3. 画像処理: OpenCV(cv2) - ・用途: PC 側での画像処理・線検出 - - グレースケール変換 - - ガウシアンブラー - - 二値化 - - 重心算出 + ・用途: Pi 側・PC 側での画像処理・線検出 + - グレースケール変換・ガウシアンブラー・二値化(固定・適応的閾値) + - CLAHE(コントラスト制限付き適応ヒストグラム均等化) + - モルフォロジー処理(Opening・Closing・Black-hat・距離変換) + - 画像リサイズ・JPEG 圧縮 ■ 選定理由 ・ライントレースに必要な処理がすべて揃っている. ・NumPy ベースで PySide6 への画像受け渡しが容易である. ・既存コードでの使用実績がある. - 2-4. カメラ制御: Picamera2 + 2-4. 数値計算: NumPy + + ・用途: Pi 側・PC 側での数値計算・配列操作 + - 多項式フィッティング(numpy.polyfit) + - Theil-Sen 直線近似・RANSAC 外れ値除去 + - 画像の配列操作・統計処理 + + ■ 選定理由 + ・OpenCV と同一のデータ形式(ndarray)で連携が容易である. + ・多項式フィッティングや統計処理の関数が揃っている. + + 2-5. カメラ制御: Picamera2 ・用途: Pi 側でのカメラフレーム取得 @@ -82,16 +93,32 @@ ・Raspberry Pi カメラの標準ライブラリである. ・既存コードでの使用実績がある. - 2-5. モーター制御: RPi.GPIO + 2-6. モーター制御: RPi.GPIO ・用途: Pi 側での TB6612FNG モータードライバ制御 ■ 選定理由 - ・既存コード(`src/pi/motor.py`)をそのまま流用できる. + ・既存コード(`src/pi/motor/driver.py`)をそのまま流用できる. ・PWM 制御・GPIO 出力に必要な機能が揃っている. - 2-6. 環境変数管理: python-dotenv + 2-7. 機械学習: scikit-learn + + ・ライブラリ: scikit-learn, joblib + ・用途: 十字路分類モデルの学習(PC 側)・推論(Pi 側・PC 側) + - 複数モデルの交差検証による比較評価 + - 最良モデルの保存・読み込み + + ■ 選定理由 + ・40×30 二値画像の 2 クラス分類であり,軽量なモデルで十分である. + ・SVM,ロジスティック回帰,MLP 等を統一的な API で比較できる. + ・NumPy ベースで既存コードとの統合が容易である. + + ■ 不採用とした候補 + ・PyTorch/TensorFlow: 入力が 1200 特徴量と小さく, + 深層学習フレームワークは過剰である. + + 2-8. 環境変数管理: python-dotenv ・用途: .env ファイルから環境変数を読み込む - PC の IP アドレス @@ -110,6 +137,10 @@ ・Picamera2: カメラフレーム取得 ・RPi.GPIO: モーター制御 + ・OpenCV: 画像処理・線検出 + ・NumPy: 多項式フィッティング・Theil-Sen 近似・配列操作 + ・scikit-learn: 十字路分類モデルの推論 + ・joblib: モデルとスケーラの読み込み ・pyzmq: PC との通信 ・python-dotenv: 環境変数管理 @@ -117,5 +148,8 @@ ・PySide6: GUI アプリケーション ・OpenCV: 画像処理・線検出 + ・NumPy: 多項式フィッティング・Theil-Sen 近似・配列操作 + ・scikit-learn: 十字路分類モデルの学習・推論 + ・joblib: モデルとスケーラの保存・読み込み ・pyzmq: Pi との通信 ・python-dotenv: 環境変数管理 diff --git "a/docs/04_ENV/ENV_02_PC\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" "b/docs/04_ENV/ENV_02_PC\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" index 3c8a5eb..50997a0 100644 --- "a/docs/04_ENV/ENV_02_PC\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" +++ "b/docs/04_ENV/ENV_02_PC\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" @@ -30,7 +30,7 @@ プロジェクトルートで以下を実行する. - $ cd c:\Users\rinto\source\RobotCar + $ cd <プロジェクトルート> $ python -m venv .venv ・`.venv/` ディレクトリが作成される. @@ -73,6 +73,9 @@ ・pyzmq (27.1.0): ZMQ 通信 ・numpy (2.4.3): 数値計算(OpenCV の依存ライブラリ) ・python-dotenv (1.2.2): .env ファイルからの環境変数読み込み + ・scikit-learn (1.6.1): 十字路分類モデルの学習・推論 + ・joblib (1.4.2): モデルのシリアライズ + ・pytest (9.0.2): ユニットテスト 3-3. インストール確認 diff --git "a/docs/04_ENV/ENV_03_RaspPi\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" "b/docs/04_ENV/ENV_03_RaspPi\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" index 5b4f0ce..0a72ae5 100644 --- "a/docs/04_ENV/ENV_03_RaspPi\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" +++ "b/docs/04_ENV/ENV_03_RaspPi\347\222\260\345\242\203\346\247\213\347\257\211\346\211\213\351\240\206.txt" @@ -59,8 +59,8 @@ $ bash deploy.sh 処理内容: - 1. Pi 側の common/,pi/ を削除 - 2. src/common/,src/pi/ を転送 + 1. Pi 側の common/,pi/,params/ を削除 + 2. src/common/,src/pi/,params/ を転送 3. .env,requirements_pi.txt を転送 ※ パスワードを複数回入力する必要がある. @@ -133,7 +133,7 @@ $ source .venv/bin/activate (.venv) $ python -m pi.main - 「Pi: カメラ・通信を開始」と表示されれば成功. + 「Pi: カメラ・通信・モーターを開始(自律モード)」と表示されれば成功. 4-2. PC 側の起動 diff --git "a/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" "b/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" index 178535d..2e61175 100644 --- "a/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" +++ "b/docs/04_ENV/ENV_04_\343\203\207\343\202\243\343\203\254\343\202\257\343\203\210\343\203\252\346\247\213\346\210\220.txt" @@ -26,59 +26,90 @@ ├── CLAUDE.md ├── requirements_pc.txt ├── requirements_pi.txt + ├── pytest.ini テスト設定(pythonpath・testpaths) ├── deploy.sh Pi への転送スクリプト ├── .env.example 環境変数テンプレート - ├── pd_params.json パラメータ保存ファイル(.gitignore) ├── docs/ ドキュメント + ├── params/ パラメータ・モデル(.gitignore) + ├── data/ 学習データ(.gitignore) + │ ├── raw/ 収集した二値画像(未仕分け) + │ └── confirmed/ 仕分け済み画像 + │ ├── intersection/ 十字路画像 + │ └── normal/ 通常画像 + ├── tests/ ユニットテスト ├── src/ 自律走行用ソースコード │ ├── common/ 共通設定(PC・Pi 両方で使用) │ ├── pc/ PC 側 │ └── pi/ Pi 側 - └── src_old/ 旧コード(参照用) + └── .venv/ 仮想環境(.gitignore) 2-2. src/common/ common/ - ├── config.py - └── json_utils.py JSON 読み書き共通ユーティリティ + ├── config.py ネットワーク・画像・通信設定 + ├── json_utils.py JSON 読み書き共通ユーティリティ + ├── vision/ 画像処理(PC・Pi 共通) + │ ├── line_detector.py 線検出 API(データクラス・手法ディスパッチ) + │ ├── fitting.py 直線・曲線近似(Theil-Sen・RANSAC・外れ値除去) + │ ├── morphology.py 形態学的処理ユーティリティ + │ ├── intersection.py 十字路分類モデルの推論 + │ └── detectors/ 検出手法の実装 + │ ├── current.py 現行(CLAHE + 固定閾値) + │ ├── blackhat.py 案A(Black-hat 中心) + │ ├── dual_norm.py 案B(二重正規化) + │ ├── robust.py 案C(最高ロバスト) + │ └── valley.py 案D(谷検出+追跡) + └── steering/ 操舵量計算(PC・Pi 共通) + ├── base.py 共通基底クラス(線検出・レートリミッター・reset) + ├── pd_control.py PD 制御の実装 + ├── pursuit_control.py 2点パシュート制御の実装 + ├── ts_pd_control.py Theil-Sen PD 制御の実装 + └── recovery.py コースアウト復帰 - ・PC・Pi 間で共有する設定値・ユーティリティを定義する. - ・config.py: ネットワーク設定,画像フォーマット,通信設定等. + ・PC・Pi 間で共有する設定値・画像処理・操舵量計算を定義する. + ・config.py: ネットワーク設定,画像フォーマット,通信設定, + テレメトリバージョン,表示倍率,ログ間隔等. ・json_utils.py: JSON ファイル読み書きとパラメータディレクトリの定義. + ・vision/: 線検出パイプライン・検出手法・十字路分類を共通化. + ・steering/: PD 制御・パシュート制御・Theil-Sen PD 制御・復帰制御を共通化. 2-3. src/pc/ pc/ ├── main.py エントリーポイント + ├── review.py データ仕分け GUI エントリーポイント ├── gui/ GUI 関連 - │ └── main_window.py メインウィンドウ + │ ├── main_window.py メインウィンドウ(レイアウト・ライフサイクル管理) + │ ├── telemetry_display.py テレメトリ受信・映像表示・オーバーレイ + │ ├── command_sender.py コマンド構築・dirty 管理・ZMQ 送信 + │ ├── manual_controller.py キー入力 → throttle/steer 変換 + │ └── panels/ パラメータ調整パネル群 + │ ├── collapsible_group_box.py 折りたたみ GroupBox + │ ├── control_param_panel.py 制御パラメータ + │ ├── image_param_panel.py 二値化パラメータ + │ ├── intersection_panel.py 十字路判定 + │ ├── overlay_panel.py デバッグ表示 + │ └── recovery_panel.py コースアウト復帰 ├── comm/ 通信関連 - │ └── zmq_client.py ZMQ 送受信 - ├── steering/ 操舵量計算(独立モジュール) - │ ├── base.py 共通インターフェース - │ ├── pd_control.py PD 制御の実装 - │ ├── pursuit_control.py 2点パシュート制御の実装 - │ ├── ts_pd_control.py Theil-Sen PD 制御の実装 + │ └── zmq_client.py ZMQ テレメトリ受信・コマンド送信 + ├── data/ 学習データ収集・仕分け・学習 + │ ├── __main__.py 学習スクリプトエントリーポイント + │ ├── collector.py 二値画像のラベル付き保存 + │ ├── reviewer.py 仕分けレビュー GUI + │ ├── dataset.py データ読み込み + │ └── train.py モデル学習・評価・保存 + ├── steering/ PC 固有の操舵パラメータ管理 │ ├── param_store.py プリセット保存・読み込み │ └── auto_params.py パラメータ自動保存・復元 - └── vision/ 画像処理 - ├── line_detector.py 線検出 API(データクラス・手法ディスパッチ) - ├── fitting.py 直線・曲線近似(Theil-Sen・RANSAC・外れ値除去) - ├── morphology.py 形態学的処理ユーティリティ - ├── overlay.py デバッグオーバーレイ描画 - └── detectors/ 検出手法の実装 - ├── current.py 現行(CLAHE + 固定閾値) - ├── blackhat.py 案A(Black-hat 中心) - ├── dual_norm.py 案B(二重正規化) - ├── robust.py 案C(最高ロバスト) - └── valley.py 案D(谷検出+追跡) + └── vision/ PC 固有の画像処理 + └── overlay.py デバッグオーバーレイ描画 2-4. src/pi/ pi/ - ├── main.py エントリーポイント + ├── main.py エントリーポイント(自律制御ループ) ├── comm/ 通信関連 - │ └── zmq_client.py ZMQ 送受信 + │ └── zmq_client.py ZMQ テレメトリ送信・コマンド受信 ├── camera/ カメラ関連 │ └── capture.py フレーム取得 └── motor/ モーター関連 diff --git "a/docs/05_TEST/TEST_01_\343\203\206\343\202\271\343\203\210\346\226\271\351\207\235.txt" "b/docs/05_TEST/TEST_01_\343\203\206\343\202\271\343\203\210\346\226\271\351\207\235.txt" new file mode 100644 index 0000000..b1f6c53 --- /dev/null +++ "b/docs/05_TEST/TEST_01_\343\203\206\343\202\271\343\203\210\346\226\271\351\207\235.txt" @@ -0,0 +1,150 @@ +======================================================================== +テスト方針 (Testing Policy) +======================================================================== + + +1. 概要 (Overview) +------------------------------------------------------------------------ + + 1-0. 目的 + + 本プロジェクトにおけるテストの方針・実行方法・構成を定義する. + テストファイルを追加・変更した場合は本ドキュメントも更新すること. + + 1-1. テストの位置づけ + + ライントレース自律走行の性能は,線検出精度と操舵量計算の品質に + 直結する.これらの数値ロジックは試行錯誤で頻繁に変更されるため, + 回帰を防止するユニットテストを整備し,変更の安全性を担保する. + + +2. テスト方針 (Policy) +------------------------------------------------------------------------ + + 2-1. テスト対象の優先順位 + + 以下の優先順位でテストを整備する: + + 1. 操舵量計算(steering/): 制御ロジックの正しさを担保する + 2. 線検出(vision/): 検出手法ごとの検出成否・偏差の妥当性を検証する + 3. 近似・フィッティング(vision/fitting.py): 外れ値処理・境界条件を検証する + 4. 形態学的処理(vision/morphology.py): 画像フィルタの挙動を検証する + 5. ユーティリティ(json_utils.py 等): データ永続化の正しさを検証する + + 2-2. テストの種類 + + ■ ユニットテスト + ・対象: `src/common/` 配下の画像処理・操舵量計算・ユーティリティ + ・方針: 合成画像(フィクスチャ)を入力に使用し, + 外部デバイス(カメラ・モーター・ネットワーク)に依存しない + ・配置: `tests/` ディレクトリに集約する + + ■ 実機テスト + ・対象: Pi 上でのカメラ取得・モーター制御・通信 + ・方針: 自動化の対象外とし,手動で実施する + ・手順: 本ドキュメントの範囲外とする + + 2-3. テスト設計の原則 + + ・合成画像を使用する: 実画像に依存せず,テストの再現性を確保する + ・境界条件を含める: 最小データ点数,空画像,未検出時の挙動を検証する + ・外れ値耐性を検証する: ロバスト推定手法の外れ値処理を確認する + ・浮動小数点の比較: `pytest.approx()` で許容誤差を明示する + + +3. テスト環境 (Environment) +------------------------------------------------------------------------ + + 3-1. フレームワーク + + ・pytest を使用する + ・設定ファイル: `pytest.ini`(pythonpath・testpaths を定義) + + 3-2. テスト実行 + + ・全テスト実行: + + $ pytest + + ・特定ファイルの実行: + + $ pytest tests/test_steering.py + + ・特定テストクラスの実行: + + $ pytest tests/test_steering.py::TestPdControl + + ・詳細出力: + + $ pytest -v + + 3-3. 前提条件 + + ・PC 環境(`requirements_pc.txt`)がインストール済みであること + ・Pi 固有のライブラリ(Picamera2,RPi.GPIO)は不要 + + +4. テスト構成 (Structure) +------------------------------------------------------------------------ + + 4-1. ディレクトリ構成 + + tests/ + ├── conftest.py 共通フィクスチャ + ├── test_steering.py 操舵量計算(PD・Pursuit・TsPd) + ├── test_line_detector.py 線検出(5手法 + フィッティング検出) + ├── test_fitting.py 直線・曲線近似(Theil-Sen・RANSAC) + ├── test_morphology.py 形態学的処理 + ├── test_params.py パラメータ保存・復元・プリセット + └── test_json_utils.py JSON ユーティリティ + + 4-2. 共通フィクスチャ(conftest.py) + + テスト用の合成画像を `conftest.py` で定義し,複数のテストファイルで + 共有する: + + ・`straight_line_image`: 中央に暗い縦線があるグレースケール画像 + ・`blank_image`: 線のない均一なグレースケール画像 + ・`binary_with_hole`: 中央に穴がある二値画像(クロージングテスト用) + ・`binary_line`: 中央に太い白線がある二値画像 + + ※ 画像サイズは `config.FRAME_WIDTH` x `config.FRAME_HEIGHT` に従う. + + 4-3. テストファイルと対象モジュールの対応 + + テストファイル 対象モジュール + ───────────────────── ─────────────────────────────── + test_steering.py common.steering.pd_control + common.steering.pursuit_control + common.steering.ts_pd_control + test_line_detector.py common.vision.line_detector + test_fitting.py common.vision.fitting + test_morphology.py common.vision.morphology + test_params.py pc.steering.auto_params + pc.steering.param_store + test_json_utils.py common.json_utils + + +5. テスト追加時のルール (Rules for Adding Tests) +------------------------------------------------------------------------ + + 5-1. ファイル命名 + + ・`test_<対象モジュール名>.py` とする + ・新規モジュールを追加した場合,対応するテストファイルも作成する + + 5-2. クラス・関数命名 + + ・テストクラス: `Test<テスト対象>` (PascalCase) + ・テスト関数: `test_<テスト内容>` (snake_case) + ・テスト内容が一読で分かる命名にする + + 5-3. フィクスチャの追加 + + ・複数テストファイルで共有するフィクスチャは `conftest.py` に定義する + ・単一ファイル内でのみ使用するフィクスチャはそのファイル内に定義する + + 5-4. コーディング規則 + + ・`GUIDE_04_コーディング規則.txt` に従う + ・docstring は日本語,Google スタイル,句点なし diff --git a/requirements_pc.txt b/requirements_pc.txt index 89bfd92..ec15433 100644 --- a/requirements_pc.txt +++ b/requirements_pc.txt @@ -3,4 +3,6 @@ pyzmq==27.1.0 numpy==2.4.3 python-dotenv==1.2.2 +scikit-learn==1.6.1 +joblib==1.4.2 pytest==9.0.2 diff --git a/requirements_pi.txt b/requirements_pi.txt index 8b753d4..2b73096 100644 --- a/requirements_pi.txt +++ b/requirements_pi.txt @@ -2,3 +2,7 @@ picamera2 RPi.GPIO python-dotenv==1.2.2 +opencv-python-headless +numpy +scikit-learn +joblib diff --git a/src/common/config.py b/src/common/config.py index 8596713..db73bd7 100644 --- a/src/common/config.py +++ b/src/common/config.py @@ -18,6 +18,12 @@ _env_path = _search_dir / ".env" load_dotenv(_env_path) +# ── テレメトリプロトコル ────────────────────────────────────── + +# テレメトリメッセージのプロトコルバージョン +# Pi/PC 間でこの値が一致しないと正しくパースできない +TELEMETRY_VERSION: int = 3 + # ── ネットワーク設定(.env から読み込み) ────────────────────── # PC の IP アドレス @@ -46,6 +52,19 @@ # JPEG 圧縮品質 (0-100) JPEG_QUALITY: int = 55 +# 二値画像の JPEG 圧縮品質 (0-100) +JPEG_QUALITY_BINARY: int = 80 + +# ── 表示設定 ────────────────────────────────────────────── + +# GUI 表示倍率(FRAME_WIDTH/HEIGHT → 表示サイズ) +DISPLAY_SCALE: float = 16.0 + +# ── Pi 側ログ設定 ───────────────────────────────────────── + +# FPS ログの出力間隔(秒) +LOG_INTERVAL_SEC: float = 3.0 + # ── 通信設定 ────────────────────────────────────────────── # 操舵量の送信頻度 (Hz) diff --git a/src/common/json_utils.py b/src/common/json_utils.py index 0ce705a..969c5f0 100644 --- a/src/common/json_utils.py +++ b/src/common/json_utils.py @@ -7,9 +7,21 @@ from pathlib import Path # プロジェクトルートの params/ ディレクトリ -PARAMS_DIR: Path = ( - Path(__file__).resolve().parent.parent.parent / "params" -) +# PC では src/common/ の3階層上,Pi では common/ の2階層上になるため +# 上方向に探索して params/ を見つける +def _find_params_dir() -> Path: + """params/ ディレクトリを上方向に探索して返す""" + d = Path(__file__).resolve().parent + while d != d.parent: + candidate = d / "params" + if candidate.is_dir(): + return candidate + d = d.parent + # 見つからない場合はデフォルト + return Path(__file__).resolve().parent.parent / "params" + + +PARAMS_DIR: Path = _find_params_dir() def read_json(path: Path) -> dict: diff --git a/src/common/steering/__init__.py b/src/common/steering/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/common/steering/__init__.py diff --git a/src/common/steering/base.py b/src/common/steering/base.py new file mode 100644 index 0000000..cbb365f --- /dev/null +++ b/src/common/steering/base.py @@ -0,0 +1,109 @@ +""" +base +操舵量計算の共通インターフェースを定義するモジュール +全ての操舵量計算クラスはこのインターフェースに従う +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import numpy as np + +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, + detect_line, + reset_valley_tracker, +) + + +@dataclass +class SteeringOutput: + """操舵量計算の出力を格納するデータクラス + + Attributes: + throttle: 前後方向の出力 (-1.0 ~ +1.0) + steer: 左右方向の出力 (-1.0 ~ +1.0) + """ + throttle: float + steer: float + + +class SteeringBase(ABC): + """操舵量計算の基底クラス + + 線検出・レートリミッター・状態管理の共通ロジックを提供し, + サブクラスは _compute_from_result で操舵計算のみ実装する + """ + + def __init__( + self, + image_params: ImageParams | None = None, + ) -> None: + self.image_params: ImageParams = ( + image_params or ImageParams() + ) + self._prev_steer: float = 0.0 + self._last_result: LineDetectResult | None = None + + def compute( + self, frame: np.ndarray, + ) -> SteeringOutput: + """カメラ画像から操舵量を計算する + + 線検出 → サブクラスの操舵計算 → レートリミッター + の共通フローを実行する + + Args: + frame: グレースケールのカメラ画像 + + Returns: + 計算された操舵量 + """ + result = detect_line(frame, self.image_params) + self._last_result = result + + output = self._compute_from_result(result) + + # レートリミッター + max_rate = self._max_steer_rate() + delta = output.steer - self._prev_steer + delta = max(-max_rate, min(max_rate, delta)) + output.steer = self._prev_steer + delta + self._prev_steer = output.steer + + return output + + @abstractmethod + def _compute_from_result( + self, result: LineDetectResult, + ) -> SteeringOutput: + """線検出結果から操舵量を計算する + + サブクラスで操舵アルゴリズムを実装する. + レートリミッターは基底クラスが適用するため, + ここでは素の操舵量を返せばよい + + Args: + result: 線検出の結果 + + Returns: + 計算された操舵量(レートリミッター適用前) + """ + + @abstractmethod + def _max_steer_rate(self) -> float: + """1フレームあたりの最大操舵変化量を返す""" + + def reset(self) -> None: + """内部状態をリセットする""" + self._prev_steer = 0.0 + self._last_result = None + reset_valley_tracker() + + @property + def last_detect_result( + self, + ) -> LineDetectResult | None: + """直近の線検出結果を取得する""" + return self._last_result diff --git a/src/common/steering/pd_control.py b/src/common/steering/pd_control.py new file mode 100644 index 0000000..2579b30 --- /dev/null +++ b/src/common/steering/pd_control.py @@ -0,0 +1,114 @@ +""" +pd_control +PD 制御による操舵量計算モジュール +多項式フィッティングの位置・傾き・曲率から操舵量と速度を算出する +""" + +import time +from dataclasses import dataclass + +import numpy as np + +from common.steering.base import SteeringBase, SteeringOutput +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, +) + + +@dataclass +class PdParams: + """PD 制御のパラメータ + + Attributes: + kp: 位置偏差ゲイン + kh: 傾き(ヘディング)ゲイン + kd: 微分ゲイン + max_steer_rate: 1フレームあたりの最大操舵変化量 + max_throttle: 直線での最大速度 + speed_k: 曲率ベースの減速係数 + """ + kp: float = 0.5 + kh: float = 0.3 + kd: float = 0.1 + max_steer_rate: float = 0.1 + max_throttle: float = 0.4 + speed_k: float = 0.3 + + +class PdControl(SteeringBase): + """PD 制御による操舵量計算クラス""" + + def __init__( + self, + params: PdParams | None = None, + image_params: ImageParams | None = None, + ) -> None: + super().__init__(image_params) + self.params: PdParams = params or PdParams() + self._prev_error: float = 0.0 + self._prev_time: float = 0.0 + + def _compute_from_result( + self, result: LineDetectResult, + ) -> SteeringOutput: + """PD 制御で操舵量を計算する + + Args: + result: 線検出の結果 + + Returns: + 計算された操舵量 + """ + if not result.detected: + return SteeringOutput( + throttle=0.0, steer=0.0, + ) + + p = self.params + + # 位置偏差 + 傾きによる操舵量 + error = ( + p.kp * result.position_error + + p.kh * result.heading + ) + + # 時間差分の計算 + now = time.time() + dt = ( + now - self._prev_time + if self._prev_time > 0 + else 0.033 + ) + dt = max(dt, 0.001) + + # 微分項 + derivative = (error - self._prev_error) / dt + steer = error + p.kd * derivative + + # 操舵量のクランプ + steer = max(-1.0, min(1.0, steer)) + + # 速度制御(曲率連動) + throttle = ( + p.max_throttle + - p.speed_k * abs(result.curvature) + ) + throttle = max(0.0, throttle) + + # 状態の更新 + self._prev_error = error + self._prev_time = now + + return SteeringOutput( + throttle=throttle, steer=steer, + ) + + def _max_steer_rate(self) -> float: + return self.params.max_steer_rate + + def reset(self) -> None: + """内部状態をリセットする""" + super().reset() + self._prev_error = 0.0 + self._prev_time = 0.0 diff --git a/src/common/steering/pursuit_control.py b/src/common/steering/pursuit_control.py new file mode 100644 index 0000000..4e00ac9 --- /dev/null +++ b/src/common/steering/pursuit_control.py @@ -0,0 +1,118 @@ +""" +pursuit_control +2点パシュートによる操舵量計算モジュール +行中心点に Theil-Sen 直線近似を適用し,外れ値に強い操舵量を算出する +""" + +from dataclasses import dataclass + +import numpy as np + +from common import config +from common.steering.base import SteeringBase, SteeringOutput +from common.vision.fitting import theil_sen_fit +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, +) + + +@dataclass +class PursuitParams: + """2点パシュート制御のパラメータ + + Attributes: + near_ratio: 近い目標点の位置(0.0=上端,1.0=下端) + far_ratio: 遠い目標点の位置(0.0=上端,1.0=下端) + k_near: 近い目標点の操舵ゲイン + k_far: 遠い目標点の操舵ゲイン + max_steer_rate: 1フレームあたりの最大操舵変化量 + max_throttle: 直線での最大速度 + speed_k: カーブ減速係数(2点の差に対する係数) + """ + near_ratio: float = 0.8 + far_ratio: float = 0.3 + k_near: float = 0.5 + k_far: float = 0.3 + max_steer_rate: float = 0.1 + max_throttle: float = 0.4 + speed_k: float = 2.0 + + +class PursuitControl(SteeringBase): + """2点パシュートによる操舵量計算クラス + + 行中心点から Theil-Sen 直線近似を行い, + 直線上の近い点と遠い点の偏差から操舵量を計算する + """ + + def __init__( + self, + params: PursuitParams | None = None, + image_params: ImageParams | None = None, + ) -> None: + super().__init__(image_params) + self.params: PursuitParams = ( + params or PursuitParams() + ) + + def _compute_from_result( + self, result: LineDetectResult, + ) -> SteeringOutput: + """2点パシュートで操舵量を計算する + + Args: + result: 線検出の結果 + + Returns: + 計算された操舵量 + """ + if not result.detected or result.row_centers is None: + return SteeringOutput( + throttle=0.0, steer=0.0, + ) + + p = self.params + centers = result.row_centers + + # 有効な点(NaN でない行)を抽出 + valid = ~np.isnan(centers) + ys = np.where(valid)[0].astype(float) + xs = centers[valid] + + if len(ys) < 2: + return SteeringOutput( + throttle=0.0, steer=0.0, + ) + + # Theil-Sen 直線近似 + slope, intercept = theil_sen_fit(ys, xs) + + center_x = config.FRAME_WIDTH / 2.0 + h = len(centers) + + # 直線上の 2 点の x 座標を取得 + near_y = h * p.near_ratio + far_y = h * p.far_ratio + near_x = slope * near_y + intercept + far_x = slope * far_y + intercept + + # 各点の偏差(正: 線が左にある → 右に曲がる) + near_err = (center_x - near_x) / center_x + far_err = (center_x - far_x) / center_x + + # 操舵量 + steer = p.k_near * near_err + p.k_far * far_err + steer = max(-1.0, min(1.0, steer)) + + # 速度制御(2点の x 差でカーブ度合いを判定) + curve = abs(near_x - far_x) / center_x + throttle = p.max_throttle - p.speed_k * curve + throttle = max(0.0, throttle) + + return SteeringOutput( + throttle=throttle, steer=steer, + ) + + def _max_steer_rate(self) -> float: + return self.params.max_steer_rate diff --git a/src/common/steering/recovery.py b/src/common/steering/recovery.py new file mode 100644 index 0000000..57087bf --- /dev/null +++ b/src/common/steering/recovery.py @@ -0,0 +1,104 @@ +""" +recovery +コースアウト復帰のパラメータと判定ロジックを定義するモジュール +黒線を一定時間検出できなかった場合に, +最後に検出した方向へ旋回しながら走行して復帰する +""" + +import time +from dataclasses import dataclass + +from common.steering.base import SteeringOutput + + +@dataclass +class RecoveryParams: + """コースアウト復帰のパラメータ + + Attributes: + enabled: 復帰機能の有効/無効 + timeout_sec: 線を見失ってから復帰動作を開始するまでの時間 + steer_amount: 復帰時の操舵量(0.0 ~ 1.0) + throttle: 復帰時の速度(負: 後退,正: 前進) + """ + enabled: bool = True + timeout_sec: float = 0.5 + steer_amount: float = 0.5 + throttle: float = -0.3 + + +class RecoveryController: + """コースアウト復帰の判定と操舵量算出を行うクラス + + 自動操縦中にフレームごとに呼び出し, + 線検出の成否を記録する.一定時間検出できなかった場合に + 復帰用の操舵量を返す + """ + + def __init__( + self, + params: RecoveryParams | None = None, + ) -> None: + self.params: RecoveryParams = ( + params or RecoveryParams() + ) + self._last_detected_time: float = 0.0 + self._last_error_sign: float = 0.0 + self._is_recovering: bool = False + + def reset(self) -> None: + """内部状態をリセットする + + 自動操縦の開始時に呼び出す + """ + self._last_detected_time = time.time() + self._last_error_sign = 0.0 + self._is_recovering = False + + def update( + self, + detected: bool, + position_error: float = 0.0, + ) -> SteeringOutput | None: + """検出結果を記録し,復帰が必要なら操舵量を返す + + Args: + detected: 線が検出できたか + position_error: 検出時の位置偏差(正: 線が左) + + Returns: + 復帰操舵量,または None(通常走行を継続) + """ + if not self.params.enabled: + return None + + now = time.time() + + if detected: + self._last_detected_time = now + if position_error != 0.0: + self._last_error_sign = ( + 1.0 if position_error > 0 else -1.0 + ) + self._is_recovering = False + return None + + elapsed = now - self._last_detected_time + if elapsed < self.params.timeout_sec: + return None + + # 復帰モード: 最後に検出した方向へ旋回 + self._is_recovering = True + steer = ( + self._last_error_sign + * self.params.steer_amount + ) + return SteeringOutput( + throttle=self.params.throttle, + steer=steer, + ) + + @property + def is_recovering(self) -> bool: + """現在復帰動作中かどうかを返す""" + return self._is_recovering diff --git a/src/common/steering/ts_pd_control.py b/src/common/steering/ts_pd_control.py new file mode 100644 index 0000000..e4f2f41 --- /dev/null +++ b/src/common/steering/ts_pd_control.py @@ -0,0 +1,138 @@ +""" +ts_pd_control +Theil-Sen 直線近似による PD 制御モジュール +行中心点に Theil-Sen 直線をフィッティングし, +位置偏差・傾き・微分項から操舵量を算出する +""" + +import time +from dataclasses import dataclass + +import numpy as np + +from common import config +from common.steering.base import SteeringBase, SteeringOutput +from common.vision.fitting import theil_sen_fit +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, +) + + +@dataclass +class TsPdParams: + """Theil-Sen PD 制御のパラメータ + + Attributes: + kp: 位置偏差ゲイン + kh: 傾き(Theil-Sen slope)ゲイン + kd: 微分ゲイン + max_steer_rate: 1フレームあたりの最大操舵変化量 + max_throttle: 直線での最大速度 + speed_k: 傾きベースの減速係数 + """ + kp: float = 0.5 + kh: float = 0.3 + kd: float = 0.1 + max_steer_rate: float = 0.1 + max_throttle: float = 0.4 + speed_k: float = 2.0 + + +class TsPdControl(SteeringBase): + """Theil-Sen 直線近似による PD 制御クラス + + 行中心点から Theil-Sen 直線近似を行い, + 画像下端での位置偏差と直線の傾きから PD 制御で操舵量を計算する + """ + + def __init__( + self, + params: TsPdParams | None = None, + image_params: ImageParams | None = None, + ) -> None: + super().__init__(image_params) + self.params: TsPdParams = ( + params or TsPdParams() + ) + self._prev_error: float = 0.0 + self._prev_time: float = 0.0 + + def _compute_from_result( + self, result: LineDetectResult, + ) -> SteeringOutput: + """Theil-Sen PD 制御で操舵量を計算する + + Args: + result: 線検出の結果 + + Returns: + 計算された操舵量 + """ + if not result.detected or result.row_centers is None: + return SteeringOutput( + throttle=0.0, steer=0.0, + ) + + p = self.params + centers = result.row_centers + + # 有効な点(NaN でない行)を抽出 + valid = ~np.isnan(centers) + ys = np.where(valid)[0].astype(float) + xs = centers[valid] + + if len(ys) < 2: + return SteeringOutput( + throttle=0.0, steer=0.0, + ) + + # Theil-Sen 直線近似: x = slope * y + intercept + slope, intercept = theil_sen_fit(ys, xs) + + center_x = config.FRAME_WIDTH / 2.0 + h = len(centers) + + # 画像下端での位置偏差 + bottom_x = slope * (h - 1) + intercept + position_error = (center_x - bottom_x) / center_x + + # 操舵量: P 項(位置偏差)+ Heading 項(傾き) + error = p.kp * position_error + p.kh * slope + + # 時間差分の計算 + now = time.time() + dt = ( + now - self._prev_time + if self._prev_time > 0 + else 0.033 + ) + dt = max(dt, 0.001) + + # D 項(微分項) + derivative = (error - self._prev_error) / dt + steer = error + p.kd * derivative + + # 操舵量のクランプ + steer = max(-1.0, min(1.0, steer)) + + # 速度制御(傾きベース) + throttle = p.max_throttle - p.speed_k * abs(slope) + throttle = max(0.0, throttle) + + # 状態の更新 + self._prev_error = error + self._prev_time = now + + return SteeringOutput( + throttle=throttle, steer=steer, + ) + + def _max_steer_rate(self) -> float: + return self.params.max_steer_rate + + def reset(self) -> None: + """内部状態をリセットする""" + super().reset() + self._prev_error = 0.0 + self._prev_time = 0.0 diff --git a/src/common/vision/__init__.py b/src/common/vision/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/common/vision/__init__.py diff --git a/src/common/vision/detectors/__init__.py b/src/common/vision/detectors/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/common/vision/detectors/__init__.py diff --git a/src/common/vision/detectors/blackhat.py b/src/common/vision/detectors/blackhat.py new file mode 100644 index 0000000..0860cd9 --- /dev/null +++ b/src/common/vision/detectors/blackhat.py @@ -0,0 +1,69 @@ +""" +blackhat +案A: Black-hat 中心型の線検出 +Black-hat 変換で背景より暗い構造を直接抽出し, +固定閾値 + 距離変換 + 行ごと中心抽出で検出する +""" + +import cv2 +import numpy as np + +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, + fit_row_centers, +) +from common.vision.morphology import ( + apply_dist_mask, + apply_iso_closing, + apply_width_filter, +) + + +def detect_blackhat( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """案A: Black-hat 中心型""" + # Black-hat 変換(暗い構造の抽出) + bh_k = params.blackhat_ksize | 1 + bh_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (bh_k, bh_k), + ) + blackhat = cv2.morphologyEx( + frame, cv2.MORPH_BLACKHAT, bh_kernel, + ) + + # ガウシアンブラー + blur_k = params.blur_size | 1 + blurred = cv2.GaussianBlur( + blackhat, (blur_k, blur_k), 0, + ) + + # 固定閾値(Black-hat 後は線が白) + _, binary = cv2.threshold( + blurred, params.binary_thresh, 255, + cv2.THRESH_BINARY, + ) + + # 等方クロージング + 距離変換マスク + 幅フィルタ + binary = apply_iso_closing( + binary, params.iso_close_size, + ) + binary = apply_dist_mask( + binary, params.dist_thresh, + ) + if params.width_near > 0 and params.width_far > 0: + binary = apply_width_filter( + binary, + params.width_near, + params.width_far, + params.width_tolerance, + ) + + # 行ごと中心抽出 + フィッティング + return fit_row_centers( + binary, params.min_line_width, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + ) diff --git a/src/common/vision/detectors/current.py b/src/common/vision/detectors/current.py new file mode 100644 index 0000000..98c4310 --- /dev/null +++ b/src/common/vision/detectors/current.py @@ -0,0 +1,76 @@ +""" +current +現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング +""" + +import cv2 +import numpy as np + +from common.vision.line_detector import ( + DETECT_Y_END, + DETECT_Y_START, + MIN_FIT_PIXELS, + ImageParams, + LineDetectResult, + build_result, + no_detection, +) + + +def detect_current( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング""" + # CLAHE でコントラスト強調 + clahe = cv2.createCLAHE( + clipLimit=params.clahe_clip, + tileGridSize=( + params.clahe_grid, + params.clahe_grid, + ), + ) + enhanced = clahe.apply(frame) + + # ガウシアンブラー + blur_k = params.blur_size | 1 + blurred = cv2.GaussianBlur( + enhanced, (blur_k, blur_k), 0, + ) + + # 固定閾値で二値化(黒線を白に反転) + _, binary = cv2.threshold( + blurred, params.binary_thresh, 255, + cv2.THRESH_BINARY_INV, + ) + + # オープニング(孤立ノイズ除去) + if params.open_size >= 3: + open_k = params.open_size | 1 + open_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (open_k, open_k), + ) + binary = cv2.morphologyEx( + binary, cv2.MORPH_OPEN, open_kernel, + ) + + # 横方向クロージング(途切れ補間) + if params.close_width >= 3: + close_h = max(params.close_height | 1, 1) + close_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, + (params.close_width, close_h), + ) + binary = cv2.morphologyEx( + binary, cv2.MORPH_CLOSE, close_kernel, + ) + + # 全ピクセルフィッティング + region = binary[DETECT_Y_START:DETECT_Y_END, :] + ys_local, xs = np.where(region > 0) + + if len(xs) < MIN_FIT_PIXELS: + return no_detection(binary) + + ys = ys_local + DETECT_Y_START + coeffs = np.polyfit(ys, xs, 2) + return build_result(coeffs, binary) diff --git a/src/common/vision/detectors/dual_norm.py b/src/common/vision/detectors/dual_norm.py new file mode 100644 index 0000000..9bd92af --- /dev/null +++ b/src/common/vision/detectors/dual_norm.py @@ -0,0 +1,148 @@ +""" +dual_norm +案B: 二重正規化型の線検出 +背景除算で照明勾配を除去し, +適応的閾値で局所ムラにも対応する二重防壁構成 +""" + +import time + +import cv2 +import numpy as np + +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, + fit_row_centers, +) +from common.vision.morphology import ( + apply_dist_mask, + apply_iso_closing, + apply_staged_closing, + apply_width_filter, +) + +# 内訳計測用の累積値 +_profile_count: int = 0 +_profile_sums: dict[str, float] = {} +_profile_start: float = 0.0 +_PROFILE_INTERVAL: float = 3.0 + + +def _profile_reset() -> None: + """計測値をリセットする""" + global _profile_count, _profile_sums, _profile_start + _profile_count = 0 + _profile_sums = {} + _profile_start = time.time() + + +def _profile_record(label: str, elapsed: float) -> None: + """計測値を記録する""" + _profile_sums[label] = ( + _profile_sums.get(label, 0.0) + elapsed + ) + + +def _profile_print() -> None: + """計測結果を出力する""" + global _profile_count, _profile_start + if _profile_count == 0: + return + elapsed = time.time() - _profile_start + if elapsed < _PROFILE_INTERVAL: + return + parts = [] + for label, total in _profile_sums.items(): + avg = total / _profile_count * 1000.0 + parts.append(f"{label}={avg:.1f}ms") + print(f"Pi: dual_norm内訳({_profile_count}f) " + + " ".join(parts)) + _profile_reset() + + +def detect_dual_norm( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """案B: 二重正規化型""" + global _profile_count + + if _profile_count == 0 and not _profile_sums: + _profile_reset() + + # 背景除算正規化 + t0 = time.time() + bg_k = params.bg_blur_ksize | 1 + bg = cv2.GaussianBlur( + frame, (bg_k, bg_k), 0, + ) + normalized = ( + frame.astype(np.float32) * 255.0 + / (bg.astype(np.float32) + 1.0) + ) + normalized = np.clip( + normalized, 0, 255, + ).astype(np.uint8) + t1 = time.time() + _profile_record("背景除算", t1 - t0) + + # 適応的閾値(ガウシアン,BINARY_INV) + block = max(params.adaptive_block | 1, 3) + binary = cv2.adaptiveThreshold( + normalized, 255, + cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY_INV, + block, params.adaptive_c, + ) + + # 固定閾値との AND(有効時のみ) + if params.global_thresh > 0: + _, global_mask = cv2.threshold( + normalized, params.global_thresh, + 255, cv2.THRESH_BINARY_INV, + ) + binary = cv2.bitwise_and(binary, global_mask) + t2 = time.time() + _profile_record("閾値処理", t2 - t1) + + # 段階クロージング or 等方クロージング + if params.stage_min_area > 0: + binary = apply_staged_closing( + binary, + params.stage_close_small, + params.stage_min_area, + params.stage_close_large, + ) + else: + binary = apply_iso_closing( + binary, params.iso_close_size, + ) + + # 距離変換マスク + 幅フィルタ + binary = apply_dist_mask( + binary, params.dist_thresh, + ) + if params.width_near > 0 and params.width_far > 0: + binary = apply_width_filter( + binary, + params.width_near, + params.width_far, + params.width_tolerance, + ) + t3 = time.time() + _profile_record("後処理", t3 - t2) + + # 行ごと中心抽出 + フィッティング + result = fit_row_centers( + binary, params.min_line_width, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + ) + t4 = time.time() + _profile_record("fitting", t4 - t3) + + _profile_count += 1 + _profile_print() + + return result diff --git a/src/common/vision/detectors/robust.py b/src/common/vision/detectors/robust.py new file mode 100644 index 0000000..04ebf58 --- /dev/null +++ b/src/common/vision/detectors/robust.py @@ -0,0 +1,69 @@ +""" +robust +案C: 最高ロバスト型の線検出 +Black-hat + 適応的閾値の二重正規化に加え, +RANSAC で外れ値を除去する最もロバストな構成 +""" + +import cv2 +import numpy as np + +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, + fit_row_centers, +) +from common.vision.morphology import ( + apply_dist_mask, + apply_iso_closing, + apply_width_filter, +) + + +def detect_robust( + frame: np.ndarray, params: ImageParams, +) -> LineDetectResult: + """案C: 最高ロバスト型""" + # Black-hat 変換 + bh_k = params.blackhat_ksize | 1 + bh_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (bh_k, bh_k), + ) + blackhat = cv2.morphologyEx( + frame, cv2.MORPH_BLACKHAT, bh_kernel, + ) + + # 適応的閾値(BINARY: Black-hat 後は線が白) + block = max(params.adaptive_block | 1, 3) + binary = cv2.adaptiveThreshold( + blackhat, 255, + cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY, + block, -params.adaptive_c, + ) + + # 等方クロージング + 距離変換マスク + 幅フィルタ + binary = apply_iso_closing( + binary, params.iso_close_size, + ) + binary = apply_dist_mask( + binary, params.dist_thresh, + ) + if params.width_near > 0 and params.width_far > 0: + binary = apply_width_filter( + binary, + params.width_near, + params.width_far, + params.width_tolerance, + ) + + # 行ごと中央値抽出 + RANSAC フィッティング + return fit_row_centers( + binary, params.min_line_width, + use_median=True, + ransac_thresh=params.ransac_thresh, + ransac_iter=params.ransac_iter, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + ) diff --git a/src/common/vision/detectors/valley.py b/src/common/vision/detectors/valley.py new file mode 100644 index 0000000..e29aa60 --- /dev/null +++ b/src/common/vision/detectors/valley.py @@ -0,0 +1,349 @@ +""" +valley +案D: 谷検出+追跡型の線検出 +各行の輝度信号から谷(暗い領域)を直接検出し, +時系列追跡で安定性を確保する.二値化を使用しない +""" + +import cv2 +import numpy as np + +from common import config +from common.vision.fitting import clean_and_fit +from common.vision.line_detector import ( + DETECT_Y_END, + DETECT_Y_START, + MIN_FIT_ROWS, + ImageParams, + LineDetectResult, + build_result, + no_detection, +) + + +class ValleyTracker: + """谷検出の時系列追跡を管理するクラス + + 前フレームの多項式係数を保持し,予測・平滑化・ + 検出失敗時のコースティングを提供する + """ + + def __init__(self) -> None: + self._prev_coeffs: np.ndarray | None = None + self._smoothed_coeffs: np.ndarray | None = None + self._frames_lost: int = 0 + + def predict_x(self, y: float) -> float | None: + """前フレームの多項式から x 座標を予測する + + Args: + y: 画像の y 座標 + + Returns: + 予測 x 座標(履歴なしの場合は None) + """ + if self._smoothed_coeffs is None: + return None + return float(np.poly1d(self._smoothed_coeffs)(y)) + + def update( + self, + coeffs: np.ndarray, + alpha: float, + ) -> np.ndarray: + """検出成功時に状態を更新する + + Args: + coeffs: 今フレームのフィッティング係数 + alpha: EMA 係数(1.0 で平滑化なし) + + Returns: + 平滑化後の多項式係数 + """ + self._frames_lost = 0 + if self._smoothed_coeffs is None: + self._smoothed_coeffs = coeffs.copy() + else: + self._smoothed_coeffs = ( + alpha * coeffs + + (1.0 - alpha) * self._smoothed_coeffs + ) + self._prev_coeffs = self._smoothed_coeffs.copy() + return self._smoothed_coeffs + + def coast( + self, max_frames: int, + ) -> LineDetectResult | None: + """検出失敗時に予測結果を返す + + Args: + max_frames: 予測を継続する最大フレーム数 + + Returns: + 予測による結果(継続不可の場合は None) + """ + if self._smoothed_coeffs is None: + return None + self._frames_lost += 1 + if self._frames_lost > max_frames: + return None + h = config.FRAME_HEIGHT + w = config.FRAME_WIDTH + blank = np.zeros((h, w), dtype=np.uint8) + return build_result(self._smoothed_coeffs, blank) + + def reset(self) -> None: + """追跡状態をリセットする""" + self._prev_coeffs = None + self._smoothed_coeffs = None + self._frames_lost = 0 + + +_valley_tracker: ValleyTracker | None = None + + +def get_valley_tracker() -> ValleyTracker: + """モジュール内のデフォルト ValleyTracker を取得する + + 初回呼び出し時にインスタンスを生成する + + Returns: + ValleyTracker インスタンス + """ + global _valley_tracker + if _valley_tracker is None: + _valley_tracker = ValleyTracker() + return _valley_tracker + + +def reset_valley_tracker() -> None: + """谷検出の追跡状態をリセットする""" + if _valley_tracker is not None: + _valley_tracker.reset() + + +def _find_row_valley( + row: np.ndarray, + min_depth: int, + expected_width: float, + width_tolerance: float, + predicted_x: float | None, + max_deviation: int, +) -> tuple[float, float] | None: + """1行の輝度信号から最適な谷を検出する + + Args: + row: スムージング済みの1行輝度信号 + min_depth: 最小谷深度 + expected_width: 期待線幅(px,0 で幅フィルタ無効) + width_tolerance: 幅フィルタの上限倍率 + predicted_x: 追跡による予測 x 座標(None で無効) + max_deviation: 予測からの最大許容偏差 + + Returns: + (谷の中心x, 谷の深度) または None + """ + n = len(row) + if n < 5: + return None + + signal = row.astype(np.float32) + + # 極小値を検出(前後より小さい点) + left = signal[:-2] + center = signal[1:-1] + right = signal[2:] + minima_mask = (center <= left) & (center <= right) + minima_indices = np.where(minima_mask)[0] + 1 + + if len(minima_indices) == 0: + return None + + best: tuple[float, float] | None = None + best_score = -1.0 + + for idx in minima_indices: + val = signal[idx] + + # 左の肩を探す + left_shoulder = idx + for i in range(idx - 1, -1, -1): + if signal[i] < signal[i + 1]: + break + left_shoulder = i + # 右の肩を探す + right_shoulder = idx + for i in range(idx + 1, n): + if signal[i] < signal[i - 1]: + break + right_shoulder = i + + # 谷の深度(肩の平均 - 谷底) + shoulder_avg = ( + signal[left_shoulder] + signal[right_shoulder] + ) / 2.0 + depth = shoulder_avg - val + if depth < min_depth: + continue + + # 谷の幅 + width = right_shoulder - left_shoulder + center_x = (left_shoulder + right_shoulder) / 2.0 + + # 幅フィルタ + if expected_width > 0: + max_w = expected_width * width_tolerance + min_w = expected_width / width_tolerance + if width > max_w or width < min_w: + continue + + # 予測との偏差チェック + if predicted_x is not None: + if abs(center_x - predicted_x) > max_deviation: + continue + + # スコア: 深度優先,予測がある場合は近さも考慮 + score = float(depth) + if predicted_x is not None: + dist = abs(center_x - predicted_x) + score += max(0.0, max_deviation - dist) + + if score > best_score: + best_score = score + best = (center_x, float(depth)) + + return best + + +def _build_valley_binary( + shape: tuple[int, int], + centers_y: list[int], + centers_x: list[float], +) -> np.ndarray: + """谷検出結果からデバッグ用二値画像を生成する + + Args: + shape: 出力画像の (高さ, 幅) + centers_y: 検出行の y 座標リスト + centers_x: 検出行の中心 x 座標リスト + + Returns: + デバッグ用二値画像 + """ + binary = np.zeros(shape, dtype=np.uint8) + half_w = 3 + w = shape[1] + for y, cx in zip(centers_y, centers_x): + x0 = max(0, int(cx) - half_w) + x1 = min(w, int(cx) + half_w + 1) + binary[y, x0:x1] = 255 + return binary + + +def detect_valley( + frame: np.ndarray, + params: ImageParams, + tracker: ValleyTracker | None = None, +) -> LineDetectResult: + """案D: 谷検出+追跡型 + + Args: + frame: グレースケールのカメラ画像 + params: 二値化パラメータ + tracker: 追跡状態(None でモジュール内デフォルトを使用) + """ + if tracker is None: + tracker = get_valley_tracker() + h, w = frame.shape[:2] + + # 行ごとにガウシアン平滑化するため画像全体をブラー + gauss_k = params.valley_gauss_ksize | 1 + blurred = cv2.GaussianBlur( + frame, (gauss_k, 1), 0, + ) + + # 透視補正の期待幅を計算するための準備 + use_width = ( + params.width_near > 0 and params.width_far > 0 + ) + detect_h = DETECT_Y_END - DETECT_Y_START + denom = max(detect_h - 1, 1) + + centers_y: list[int] = [] + centers_x: list[float] = [] + depths: list[float] = [] + + for y in range(DETECT_Y_START, DETECT_Y_END): + row = blurred[y] + + # 期待幅の計算 + if use_width: + t = (DETECT_Y_END - 1 - y) / denom + expected_w = float(params.width_far) + ( + float(params.width_near) + - float(params.width_far) + ) * t + else: + expected_w = 0.0 + + # 予測 x 座標 + predicted_x = tracker.predict_x( + float(y), + ) + + result = _find_row_valley( + row, + params.valley_min_depth, + expected_w, + params.width_tolerance, + predicted_x, + params.valley_max_deviation, + ) + if result is not None: + centers_y.append(y) + centers_x.append(result[0]) + depths.append(result[1]) + + # デバッグ用二値画像 + debug_binary = _build_valley_binary( + (h, w), centers_y, centers_x, + ) + + if len(centers_y) < MIN_FIT_ROWS: + coasted = tracker.coast( + params.valley_coast_frames, + ) + if coasted is not None: + coasted.binary_image = debug_binary + return coasted + return no_detection(debug_binary) + + cy = np.array(centers_y, dtype=np.float64) + cx = np.array(centers_x, dtype=np.float64) + w_arr = np.array(depths, dtype=np.float64) + + # ロバストフィッティング(深度を重みに使用) + coeffs = clean_and_fit( + cy, cx, + median_ksize=params.median_ksize, + neighbor_thresh=params.neighbor_thresh, + residual_thresh=params.residual_thresh, + weights=w_arr, + ransac_thresh=params.ransac_thresh, + ransac_iter=params.ransac_iter, + ) + if coeffs is None: + coasted = tracker.coast( + params.valley_coast_frames, + ) + if coasted is not None: + coasted.binary_image = debug_binary + return coasted + return no_detection(debug_binary) + + # EMA で平滑化 + smoothed = tracker.update( + coeffs, params.valley_ema_alpha, + ) + + return build_result(smoothed, debug_binary) diff --git a/src/common/vision/fitting.py b/src/common/vision/fitting.py new file mode 100644 index 0000000..c55139d --- /dev/null +++ b/src/common/vision/fitting.py @@ -0,0 +1,212 @@ +""" +fitting +直線・曲線近似の共通ユーティリティモジュール +Theil-Sen 推定,RANSAC,外れ値除去付きフィッティングを提供する +""" + +import numpy as np + +# フィッティングに必要な最小行数 +MIN_FIT_ROWS: int = 10 + +# 近傍外れ値除去の設定 +NEIGHBOR_HALF_WINDOW: int = 3 +NEIGHBOR_FILTER_PASSES: int = 3 + +# 残差ベース反復除去の最大回数 +RESIDUAL_REMOVAL_ITERATIONS: int = 5 + + +def theil_sen_fit( + y: np.ndarray, + x: np.ndarray, +) -> tuple[float, float]: + """Theil-Sen 推定で直線 x = slope * y + intercept を求める + + 全ペアの傾きの中央値を使い,外れ値に強い直線近似を行う + + Args: + y: y 座標の配列(行番号) + x: x 座標の配列(各行の中心) + + Returns: + (slope, intercept) のタプル + """ + n = len(y) + slopes = [] + for i in range(n): + for j in range(i + 1, n): + dy = y[j] - y[i] + if dy != 0: + slopes.append((x[j] - x[i]) / dy) + + if len(slopes) == 0: + return 0.0, float(np.median(x)) + + slope = float(np.median(slopes)) + intercept = float(np.median(x - slope * y)) + return slope, intercept + + +def ransac_polyfit( + ys: np.ndarray, xs: np.ndarray, + degree: int, n_iter: int, thresh: float, +) -> np.ndarray | None: + """RANSAC で外れ値を除去して多項式フィッティング + + Args: + ys: y 座標配列 + xs: x 座標配列 + degree: 多項式の次数 + n_iter: 反復回数 + thresh: 外れ値判定閾値(ピクセル) + + Returns: + 多項式係数(フィッティング失敗時は None) + """ + n = len(ys) + sample_size = degree + 1 + if n < sample_size: + return None + + best_coeffs: np.ndarray | None = None + best_inliers = 0 + rng = np.random.default_rng() + + for _ in range(n_iter): + idx = rng.choice(n, sample_size, replace=False) + coeffs = np.polyfit(ys[idx], xs[idx], degree) + poly = np.poly1d(coeffs) + residuals = np.abs(xs - poly(ys)) + n_inliers = int(np.sum(residuals < thresh)) + if n_inliers > best_inliers: + best_inliers = n_inliers + best_coeffs = coeffs + + # インライアで再フィッティング + if best_coeffs is not None: + poly = np.poly1d(best_coeffs) + inlier_mask = np.abs(xs - poly(ys)) < thresh + if np.sum(inlier_mask) >= sample_size: + best_coeffs = np.polyfit( + ys[inlier_mask], + xs[inlier_mask], + degree, + ) + + return best_coeffs + + +def clean_and_fit( + cy: np.ndarray, + cx: np.ndarray, + median_ksize: int, + neighbor_thresh: float, + residual_thresh: float = 0.0, + weights: np.ndarray | None = None, + ransac_thresh: float = 0.0, + ransac_iter: int = 0, +) -> np.ndarray | None: + """外れ値除去+重み付きフィッティングを行う + + 全検出手法で共通に使えるロバストなフィッティング + (1) 移動メディアンフィルタでスパイクを平滑化 + (2) 近傍中央値からの偏差で外れ値を除去(複数パス) + (3) 重み付き最小二乗(または RANSAC)でフィッティング + (4) 残差ベースの反復除去で外れ値を最終除去 + + Args: + cy: 中心点の y 座標配列 + cx: 中心点の x 座標配列 + median_ksize: 移動メディアンのカーネルサイズ(0 で無効) + neighbor_thresh: 近傍外れ値除去の閾値 px(0 で無効) + residual_thresh: 残差除去の閾値 px(0 で無効) + weights: 各点の信頼度(None で均等) + ransac_thresh: RANSAC 閾値(0 以下で無効) + ransac_iter: RANSAC 反復回数 + + Returns: + 多項式係数(フィッティング失敗時は None) + """ + if len(cy) < MIN_FIT_ROWS: + return None + + cx_clean = cx.copy() + mask = np.ones(len(cy), dtype=bool) + + # (1) 移動メディアンフィルタ + if median_ksize >= 3: + k = median_ksize | 1 + half = k // 2 + for i in range(len(cx_clean)): + lo = max(0, i - half) + hi = min(len(cx_clean), i + half + 1) + cx_clean[i] = float(np.median(cx[lo:hi])) + + # (2) 近傍外れ値除去(複数パス) + if neighbor_thresh > 0: + half_n = NEIGHBOR_HALF_WINDOW + for _ in range(NEIGHBOR_FILTER_PASSES): + new_mask = np.ones(len(cx_clean), dtype=bool) + for i in range(len(cx_clean)): + if not mask[i]: + continue + lo = max(0, i - half_n) + hi = min(len(cx_clean), i + half_n + 1) + neighbors = cx_clean[lo:hi][mask[lo:hi]] + if len(neighbors) == 0: + new_mask[i] = False + continue + local_med = float(np.median(neighbors)) + if ( + abs(cx_clean[i] - local_med) + > neighbor_thresh + ): + new_mask[i] = False + if np.array_equal(mask, mask & new_mask): + break + mask = mask & new_mask + + cy = cy[mask] + cx_clean = cx_clean[mask] + if weights is not None: + weights = weights[mask] + + if len(cy) < MIN_FIT_ROWS: + return None + + # (3) フィッティング + if ransac_thresh > 0 and ransac_iter > 0: + coeffs = ransac_polyfit( + cy, cx_clean, 2, ransac_iter, ransac_thresh, + ) + elif weights is not None: + coeffs = np.polyfit(cy, cx_clean, 2, w=weights) + else: + coeffs = np.polyfit(cy, cx_clean, 2) + + if coeffs is None: + return None + + # (4) 残差ベースの反復除去 + if residual_thresh > 0: + for _ in range(RESIDUAL_REMOVAL_ITERATIONS): + poly = np.poly1d(coeffs) + residuals = np.abs(cx_clean - poly(cy)) + inlier = residuals < residual_thresh + if np.all(inlier): + break + if np.sum(inlier) < MIN_FIT_ROWS: + break + cy = cy[inlier] + cx_clean = cx_clean[inlier] + if weights is not None: + weights = weights[inlier] + if weights is not None: + coeffs = np.polyfit( + cy, cx_clean, 2, w=weights, + ) + else: + coeffs = np.polyfit(cy, cx_clean, 2) + + return coeffs diff --git a/src/common/vision/intersection.py b/src/common/vision/intersection.py new file mode 100644 index 0000000..20a94e3 --- /dev/null +++ b/src/common/vision/intersection.py @@ -0,0 +1,75 @@ +""" +intersection +十字路分類モデルの読み込みと推論を行うモジュール + +学習済みモデルとスケーラを読み込み, +二値画像から十字路かどうかを判定する +""" + +from pathlib import Path + +import numpy as np + +from common.json_utils import PARAMS_DIR + +# モデル・スケーラの保存先 +_MODEL_PATH: Path = PARAMS_DIR / "intersection_model.pkl" +_SCALER_PATH: Path = PARAMS_DIR / "intersection_scaler.pkl" + + +class IntersectionClassifier: + """十字路分類器 + + 学習済みモデルを読み込み,二値画像から + 十字路かどうかを判定する + scikit-learn を遅延インポートして起動時間を短縮する + """ + + def __init__(self) -> None: + self._model: object | None = None + self._scaler: object | None = None + self._available: bool = False + + def load(self) -> None: + """モデルとスケーラを読み込む(遅延呼び出し用)""" + if not _MODEL_PATH.exists(): + print( + f" モデルが見つかりません: {_MODEL_PATH}" + ) + return + if not _SCALER_PATH.exists(): + print( + f" スケーラが見つかりません: {_SCALER_PATH}" + ) + return + try: + import joblib + + self._model = joblib.load(_MODEL_PATH) + self._scaler = joblib.load(_SCALER_PATH) + self._available = True + except Exception as e: + print(f" モデル読み込みエラー: {e}") + + @property + def available(self) -> bool: + """モデルが利用可能かどうか""" + return self._available + + def predict(self, binary_image: np.ndarray) -> bool: + """二値画像が十字路かどうかを判定する + + Args: + binary_image: 40×30 の二値画像(0/255) + + Returns: + 十字路なら True + """ + if not self._available: + return False + flat = (binary_image.flatten() / 255.0).astype( + np.float32, + ) + x = self._scaler.transform(flat.reshape(1, -1)) + pred = self._model.predict(x) + return bool(pred[0] == 1) diff --git a/src/common/vision/line_detector.py b/src/common/vision/line_detector.py new file mode 100644 index 0000000..018c10d --- /dev/null +++ b/src/common/vision/line_detector.py @@ -0,0 +1,389 @@ +""" +line_detector +カメラ画像から黒線の位置を検出するモジュール +複数の検出手法を切り替えて使用できる + +公開 API: + ImageParams, LineDetectResult, detect_line, + reset_valley_tracker, DETECT_METHODS +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import cv2 +import numpy as np + +if TYPE_CHECKING: + from collections.abc import Callable + +from common import config +from common.vision.fitting import clean_and_fit + +# 検出領域の y 範囲(画像全体) +DETECT_Y_START: int = 0 +DETECT_Y_END: int = config.FRAME_HEIGHT + +# フィッティングに必要な最小数 +MIN_FIT_PIXELS: int = 50 +MIN_FIT_ROWS: int = 10 + +# 検出手法の定義(キー: 識別子,値: 表示名) +DETECT_METHODS: dict[str, str] = { + "current": "現行(CLAHE + 固定閾値)", + "blackhat": "案A(Black-hat 中心)", + "dual_norm": "案B(二重正規化)", + "robust": "案C(最高ロバスト)", + "valley": "案D(谷検出+追跡)", +} + + +@dataclass +class ImageParams: + """二値化パラメータ + + Attributes: + method: 検出手法の識別子 + clahe_clip: CLAHE のコントラスト増幅上限 + clahe_grid: CLAHE の局所領域分割数 + blur_size: ガウシアンブラーのカーネルサイズ(奇数) + binary_thresh: 二値化の閾値 + open_size: オープニングのカーネルサイズ + close_width: クロージングの横幅 + close_height: クロージングの高さ + blackhat_ksize: Black-hat のカーネルサイズ + bg_blur_ksize: 背景除算のブラーカーネルサイズ + global_thresh: 固定閾値(0 で無効,適応的閾値との AND) + adaptive_block: 適応的閾値のブロックサイズ + adaptive_c: 適応的閾値の定数 C + iso_close_size: 等方クロージングのカーネルサイズ + dist_thresh: 距離変換の閾値 + min_line_width: 行ごと中心抽出の最小線幅 + stage_close_small: 段階クロージング第1段のサイズ + stage_min_area: 孤立除去の最小面積(0 で無効) + stage_close_large: 段階クロージング第2段のサイズ(0 で無効) + ransac_thresh: RANSAC の外れ値判定閾値 + ransac_iter: RANSAC の反復回数 + width_near: 画像下端での期待線幅(px,0 で無効) + width_far: 画像上端での期待線幅(px,0 で無効) + width_tolerance: 幅フィルタの上限倍率 + median_ksize: 中心点列の移動メディアンフィルタサイズ(0 で無効) + neighbor_thresh: 近傍外れ値除去の閾値(px,0 で無効) + residual_thresh: 残差反復除去の閾値(px,0 で無効) + valley_gauss_ksize: 谷検出の行ごとガウシアンカーネルサイズ + valley_min_depth: 谷として認識する最小深度 + valley_max_deviation: 追跡予測からの最大許容偏差(px) + valley_coast_frames: 検出失敗時の予測継続フレーム数 + valley_ema_alpha: 多項式係数の指数移動平均係数 + """ + + # 検出手法 + method: str = "current" + + # 現行手法パラメータ + clahe_clip: float = 2.0 + clahe_grid: int = 8 + blur_size: int = 5 + binary_thresh: int = 80 + open_size: int = 5 + close_width: int = 25 + close_height: int = 3 + + # 案A/C: Black-hat + blackhat_ksize: int = 45 + + # 案B: 背景除算 + bg_blur_ksize: int = 101 + global_thresh: int = 0 + + # 案B/C: 適応的閾値 + adaptive_block: int = 51 + adaptive_c: int = 10 + + # 案A/B/C: 後処理 + iso_close_size: int = 15 + dist_thresh: float = 3.0 + min_line_width: int = 3 + + # 案B: 段階クロージング + stage_close_small: int = 5 + stage_min_area: int = 0 + stage_close_large: int = 0 + + # 案C: RANSAC + ransac_thresh: float = 5.0 + ransac_iter: int = 50 + + # ロバストフィッティング(全手法共通) + median_ksize: int = 7 + neighbor_thresh: float = 10.0 + residual_thresh: float = 8.0 + + # 透視補正付き幅フィルタ(0 で無効) + width_near: int = 0 + width_far: int = 0 + width_tolerance: float = 1.8 + + # 案D: 谷検出+追跡 + valley_gauss_ksize: int = 15 + valley_min_depth: int = 15 + valley_max_deviation: int = 40 + valley_coast_frames: int = 3 + valley_ema_alpha: float = 0.7 + + +@dataclass +class LineDetectResult: + """線検出の結果を格納するデータクラス + + Attributes: + detected: 線が検出できたか + position_error: 画像下端での位置偏差(-1.0~+1.0) + heading: 線の傾き(dx/dy,画像下端での値) + curvature: 線の曲率(d²x/dy²) + poly_coeffs: 多項式の係数(描画用,未検出時は None) + row_centers: 各行の線中心 x 座標(index=行番号, + NaN=その行に線なし,未検出時は None) + binary_image: 二値化後の画像(デバッグ用) + """ + + detected: bool + position_error: float + heading: float + curvature: float + poly_coeffs: np.ndarray | None + row_centers: np.ndarray | None + binary_image: np.ndarray | None + + +# ── 公開 API ────────────────────────────────────── + + +def _get_detector_registry() -> dict[ + str, + "Callable[[np.ndarray, ImageParams], LineDetectResult]", +]: + """検出手法の辞書を遅延構築して返す + + Returns: + 手法識別子と検出関数の辞書 + """ + from common.vision.detectors.blackhat import ( + detect_blackhat, + ) + from common.vision.detectors.current import ( + detect_current, + ) + from common.vision.detectors.dual_norm import ( + detect_dual_norm, + ) + from common.vision.detectors.robust import ( + detect_robust, + ) + from common.vision.detectors.valley import ( + detect_valley, + ) + return { + "current": detect_current, + "blackhat": detect_blackhat, + "dual_norm": detect_dual_norm, + "robust": detect_robust, + "valley": detect_valley, + } + + +_detector_registry: dict | None = None + + +def detect_line( + frame: np.ndarray, + params: ImageParams | None = None, +) -> LineDetectResult: + """画像から黒線の位置を検出する + + params.method に応じて検出手法を切り替える + + Args: + frame: グレースケールのカメラ画像 + params: 二値化パラメータ(None でデフォルト) + + Returns: + 線検出の結果 + """ + global _detector_registry + if _detector_registry is None: + _detector_registry = _get_detector_registry() + + if params is None: + params = ImageParams() + + detector = _detector_registry.get( + params.method, + _detector_registry["current"], + ) + return detector(frame, params) + + +def reset_valley_tracker() -> None: + """谷検出の追跡状態をリセットする""" + from common.vision.detectors.valley import ( + reset_valley_tracker as _reset, + ) + _reset() + + +# ── 共通結果構築(各検出器から使用) ────────────── + + +def no_detection( + binary: np.ndarray, +) -> LineDetectResult: + """未検出の結果を返す""" + return LineDetectResult( + detected=False, + position_error=0.0, + heading=0.0, + curvature=0.0, + poly_coeffs=None, + row_centers=None, + binary_image=binary, + ) + + +def _extract_row_centers( + binary: np.ndarray, +) -> np.ndarray | None: + """二値画像の最大連結領域から各行の線中心を求める + + Args: + binary: 二値画像 + + Returns: + 各行の中心 x 座標(NaN=その行に線なし), + 最大領域が見つからない場合は None + """ + h, w = binary.shape[:2] + num_labels, labels, stats, _ = ( + cv2.connectedComponentsWithStats(binary) + ) + + if num_labels <= 1: + return None + + # 背景(ラベル 0)を除いた最大領域を取得 + areas = stats[1:, cv2.CC_STAT_AREA] + largest_label = int(np.argmax(areas)) + 1 + + # 最大領域のマスク + mask = (labels == largest_label).astype(np.uint8) + + # 各行の左右端から中心を計算 + centers = np.full(h, np.nan) + for y in range(h): + row = mask[y] + cols = np.where(row > 0)[0] + if len(cols) > 0: + centers[y] = (cols[0] + cols[-1]) / 2.0 + + return centers + + +def build_result( + coeffs: np.ndarray, + binary: np.ndarray, + row_centers: np.ndarray | None = None, +) -> LineDetectResult: + """多項式係数から LineDetectResult を構築する + + row_centers が None の場合は binary から自動抽出する + """ + poly = np.poly1d(coeffs) + center_x = config.FRAME_WIDTH / 2.0 + + # 画像下端での位置偏差 + x_bottom = poly(DETECT_Y_END) + position_error = (center_x - x_bottom) / center_x + + # 傾き: dx/dy(画像下端での値) + poly_deriv = poly.deriv() + heading = float(poly_deriv(DETECT_Y_END)) + + # 曲率: d²x/dy² + poly_deriv2 = poly_deriv.deriv() + curvature = float(poly_deriv2(DETECT_Y_END)) + + # row_centers が未提供なら binary から抽出 + if row_centers is None: + row_centers = _extract_row_centers(binary) + + return LineDetectResult( + detected=True, + position_error=position_error, + heading=heading, + curvature=curvature, + poly_coeffs=coeffs, + row_centers=row_centers, + binary_image=binary, + ) + + +def fit_row_centers( + binary: np.ndarray, + min_width: int, + use_median: bool = False, + ransac_thresh: float = 0.0, + ransac_iter: int = 0, + median_ksize: int = 0, + neighbor_thresh: float = 0.0, + residual_thresh: float = 0.0, +) -> LineDetectResult: + """行ごとの中心点に多項式をフィッティングする + + Args: + binary: 二値画像 + min_width: 線として認識する最小ピクセル数 + use_median: True の場合は中央値を使用 + ransac_thresh: RANSAC 閾値(0 以下で無効) + ransac_iter: RANSAC 反復回数 + median_ksize: 移動メディアンのカーネルサイズ + neighbor_thresh: 近傍外れ値除去の閾値 px + residual_thresh: 残差反復除去の閾値 px + + Returns: + 線検出の結果 + """ + region = binary[DETECT_Y_START:DETECT_Y_END, :] + centers_y: list[float] = [] + centers_x: list[float] = [] + + for y_local in range(region.shape[0]): + xs = np.where(region[y_local] > 0)[0] + if len(xs) < min_width: + continue + y = float(y_local + DETECT_Y_START) + centers_y.append(y) + if use_median: + centers_x.append(float(np.median(xs))) + else: + centers_x.append(float(np.mean(xs))) + + if len(centers_y) < MIN_FIT_ROWS: + return no_detection(binary) + + cy = np.array(centers_y) + cx = np.array(centers_x) + + coeffs = clean_and_fit( + cy, cx, + median_ksize=median_ksize, + neighbor_thresh=neighbor_thresh, + residual_thresh=residual_thresh, + ransac_thresh=ransac_thresh, + ransac_iter=ransac_iter, + ) + if coeffs is None: + return no_detection(binary) + + return build_result(coeffs, binary) diff --git a/src/common/vision/morphology.py b/src/common/vision/morphology.py new file mode 100644 index 0000000..578dcb6 --- /dev/null +++ b/src/common/vision/morphology.py @@ -0,0 +1,127 @@ +""" +morphology +二値画像の形態学的処理ユーティリティモジュール +""" + +import cv2 +import numpy as np + + +def apply_iso_closing( + binary: np.ndarray, size: int, +) -> np.ndarray: + """等方クロージングで穴を埋める + + Args: + binary: 二値画像 + size: カーネルサイズ + + Returns: + クロージング後の二値画像 + """ + if size < 3: + return binary + k = size | 1 + kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (k, k), + ) + return cv2.morphologyEx( + binary, cv2.MORPH_CLOSE, kernel, + ) + + +def apply_staged_closing( + binary: np.ndarray, + small_size: int, + min_area: int, + large_size: int, +) -> np.ndarray: + """段階クロージング: 小穴埋め → 孤立除去 → 大穴埋め + + Args: + binary: 二値画像 + small_size: 第1段クロージングのカーネルサイズ + min_area: 孤立領域除去の最小面積(0 で無効) + large_size: 第2段クロージングのカーネルサイズ(0 で無効) + + Returns: + 処理後の二値画像 + """ + result = apply_iso_closing(binary, small_size) + + if min_area > 0: + contours, _ = cv2.findContours( + result, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE, + ) + mask = np.zeros_like(result) + for cnt in contours: + if cv2.contourArea(cnt) >= min_area: + cv2.drawContours( + mask, [cnt], -1, 255, -1, + ) + result = mask + + result = apply_iso_closing(result, large_size) + + return result + + +def apply_width_filter( + binary: np.ndarray, + width_near: int, + width_far: int, + tolerance: float, +) -> np.ndarray: + """透視補正付き幅フィルタで広がりすぎた行を除外する + + Args: + binary: 二値画像 + width_near: 画像下端での期待線幅(px) + width_far: 画像上端での期待線幅(px) + tolerance: 上限倍率 + + Returns: + 幅フィルタ適用後の二値画像 + """ + result = binary.copy() + h = binary.shape[0] + denom = max(h - 1, 1) + + for y_local in range(h): + xs = np.where(binary[y_local] > 0)[0] + if len(xs) == 0: + continue + t = (h - 1 - y_local) / denom + expected = float(width_far) + ( + float(width_near) - float(width_far) + ) * t + max_w = expected * tolerance + actual_w = int(xs[-1]) - int(xs[0]) + 1 + if actual_w > max_w: + result[y_local] = 0 + + return result + + +def apply_dist_mask( + binary: np.ndarray, thresh: float, +) -> np.ndarray: + """距離変換で中心部のみを残す + + Args: + binary: 二値画像 + thresh: 距離の閾値(ピクセル) + + Returns: + 中心部のみの二値画像 + """ + if thresh <= 0: + return binary + dist = cv2.distanceTransform( + binary, cv2.DIST_L2, 5, + ) + _, mask = cv2.threshold( + dist, thresh, 255, cv2.THRESH_BINARY, + ) + return mask.astype(np.uint8) diff --git a/src/pc/comm/zmq_client.py b/src/pc/comm/zmq_client.py index d373c71..d9e56e4 100644 --- a/src/pc/comm/zmq_client.py +++ b/src/pc/comm/zmq_client.py @@ -1,10 +1,12 @@ """ zmq_client PC 側の ZMQ 通信を担当するモジュール -画像の受信と操舵量の送信を行う +テレメトリ受信(画像+検出結果+操舵量)と +コマンド送信(モード切替・パラメータ更新・手動操作)を行う """ import json +import struct import cv2 import numpy as np @@ -16,71 +18,127 @@ class PcZmqClient: """PC 側の ZMQ 通信クライアント - 画像受信(SUB)と操舵量送信(PUB)の2チャネルを管理する + テレメトリ受信(SUB)とコマンド送信(PUB)の2チャネルを管理する """ def __init__(self) -> None: self._context: zmq.Context | None = None - self._image_socket: zmq.Socket | None = None - self._control_socket: zmq.Socket | None = None + self._telemetry_socket: zmq.Socket | None = None + self._command_socket: zmq.Socket | None = None def start(self) -> None: """通信ソケットを初期化してバインドする""" self._context = zmq.Context() - # 画像受信ソケット(SUB,Pi からの画像を受信) - self._image_socket = self._context.socket(zmq.SUB) - self._image_socket.setsockopt(zmq.CONFLATE, 1) - self._image_socket.setsockopt_string(zmq.SUBSCRIBE, "") - self._image_socket.bind(config.image_bind_address()) + # テレメトリ受信ソケット(SUB,Pi からの画像+状態を受信) + self._telemetry_socket = self._context.socket( + zmq.SUB, + ) + self._telemetry_socket.setsockopt( + zmq.CONFLATE, 1, + ) + self._telemetry_socket.setsockopt_string( + zmq.SUBSCRIBE, "", + ) + self._telemetry_socket.bind( + config.image_bind_address(), + ) - # 操舵量送信ソケット(PUB,Pi へ操舵量を送信) - self._control_socket = self._context.socket(zmq.PUB) - self._control_socket.bind(config.control_bind_address()) + # コマンド送信ソケット(PUB,Pi へコマンドを送信) + self._command_socket = self._context.socket( + zmq.PUB, + ) + self._command_socket.bind( + config.control_bind_address(), + ) - def receive_image(self) -> np.ndarray | None: - """画像を非ブロッキングで受信する + def receive_telemetry( + self, + ) -> tuple[dict, np.ndarray, np.ndarray | None] | None: + """テレメトリを非ブロッキングで受信する Returns: - 受信したグレースケール画像の NumPy 配列,受信データがない場合は None + (telemetry_dict, camera_frame, binary_image) のタプル, + 受信データがない場合は None. + binary_image はデータがない場合 None """ - if self._image_socket is None: + if self._telemetry_socket is None: return None try: - data = self._image_socket.recv(zmq.NOBLOCK) + raw = self._telemetry_socket.recv( + zmq.NOBLOCK, + ) + offset = 0 + + # JSON ヘッダを読み取り + json_len = struct.unpack_from( + " None: - """操舵量を送信する + def send_command(self, command: dict) -> None: + """コマンドを Pi に送信する Args: - throttle: 前後方向の出力 (-1.0 ~ +1.0) - steer: 左右方向の出力 (-1.0 ~ +1.0) + command: コマンド辞書 """ - if self._control_socket is None: + if self._command_socket is None: return - payload = json.dumps({ - "throttle": throttle, - "steer": steer, - }).encode("utf-8") - self._control_socket.send(payload, zmq.NOBLOCK) + payload = json.dumps(command).encode("utf-8") + self._command_socket.send( + payload, zmq.NOBLOCK, + ) def stop(self) -> None: """通信ソケットを閉じる""" - if self._image_socket is not None: - self._image_socket.close() - self._image_socket = None - if self._control_socket is not None: - self._control_socket.close() - self._control_socket = None + if self._telemetry_socket is not None: + self._telemetry_socket.close() + self._telemetry_socket = None + if self._command_socket is not None: + self._command_socket.close() + self._command_socket = None if self._context is not None: self._context.term() self._context = None diff --git a/src/pc/data/__init__.py b/src/pc/data/__init__.py new file mode 100644 index 0000000..e7d79a9 --- /dev/null +++ b/src/pc/data/__init__.py @@ -0,0 +1,9 @@ +""" +data +学習データの収集・管理 +""" + +from pc.data.collector import DataCollector +from pc.data.reviewer import ReviewWindow + +__all__ = ["DataCollector", "ReviewWindow"] diff --git a/src/pc/data/__main__.py b/src/pc/data/__main__.py new file mode 100644 index 0000000..915bc7e --- /dev/null +++ b/src/pc/data/__main__.py @@ -0,0 +1,5 @@ +"""pc.data パッケージの直接実行で学習スクリプトを起動する""" + +from pc.data.train import main + +main() diff --git a/src/pc/data/collector.py b/src/pc/data/collector.py new file mode 100644 index 0000000..ca2fd80 --- /dev/null +++ b/src/pc/data/collector.py @@ -0,0 +1,116 @@ +""" +collector +二値画像をラベル別ディレクトリに保存するデータ収集モジュール + +録画中にフレームごとに save() を呼び出すと, +指定ラベルのサブディレクトリへ連番 PNG として保存する +""" + +from datetime import datetime +from pathlib import Path + +import cv2 +import numpy as np + +# プロジェクトルート +_PROJECT_ROOT: Path = ( + Path(__file__).resolve().parent.parent.parent.parent +) + +# データディレクトリ +DATA_DIR: Path = _PROJECT_ROOT / "data" +RAW_DIR: Path = DATA_DIR / "raw" +CONFIRMED_DIR: Path = DATA_DIR / "confirmed" + +# ラベル名 +LABEL_INTERSECTION: str = "intersection" +LABEL_NORMAL: str = "normal" + + +class DataCollector: + """二値画像をラベル付きで保存するコレクタ + + Attributes: + raw_dir: 未確定データの保存先ルートディレクトリ + session_dir: 現在の録画セッションのディレクトリ + is_recording: 録画中かどうか + """ + + def __init__( + self, + raw_dir: Path = RAW_DIR, + ) -> None: + self._raw_dir = raw_dir + self._session_dir: Path | None = None + self._is_recording: bool = False + self._count_intersection: int = 0 + self._count_normal: int = 0 + + @property + def is_recording(self) -> bool: + """録画中かどうかを返す""" + return self._is_recording + + @property + def count_intersection(self) -> int: + """現在のセッションで保存した intersection 画像の枚数""" + return self._count_intersection + + @property + def count_normal(self) -> int: + """現在のセッションで保存した normal 画像の枚数""" + return self._count_normal + + def start(self) -> Path: + """録画を開始する + + タイムスタンプ付きのセッションディレクトリを作成し, + その中に intersection/ と normal/ を用意する + + Returns: + 作成したセッションディレクトリのパス + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self._session_dir = self._raw_dir / timestamp + (self._session_dir / LABEL_INTERSECTION).mkdir( + parents=True, exist_ok=True, + ) + (self._session_dir / LABEL_NORMAL).mkdir( + parents=True, exist_ok=True, + ) + self._count_intersection = 0 + self._count_normal = 0 + self._is_recording = True + return self._session_dir + + def stop(self) -> None: + """録画を停止する""" + self._is_recording = False + + def save( + self, + binary_image: np.ndarray, + is_intersection: bool, + ) -> None: + """二値画像をラベル付きで保存する + + Args: + binary_image: 保存する二値画像(0/255) + is_intersection: True なら intersection,False なら normal + """ + if not self._is_recording or self._session_dir is None: + return + + if is_intersection: + label = LABEL_INTERSECTION + self._count_intersection += 1 + idx = self._count_intersection + else: + label = LABEL_NORMAL + self._count_normal += 1 + idx = self._count_normal + + path = ( + self._session_dir / label / f"{idx:06d}.png" + ) + cv2.imwrite(str(path), binary_image) diff --git a/src/pc/data/dataset.py b/src/pc/data/dataset.py new file mode 100644 index 0000000..1e1980c --- /dev/null +++ b/src/pc/data/dataset.py @@ -0,0 +1,66 @@ +""" +dataset +confirmed/ から学習データを読み込むモジュール + +画像を flatten した特徴量ベクトルとラベルを返す +""" + +from pathlib import Path + +import cv2 +import numpy as np + +from pc.data.collector import ( + CONFIRMED_DIR, + LABEL_INTERSECTION, + LABEL_NORMAL, +) + + +def load_dataset( + confirmed_dir: Path = CONFIRMED_DIR, +) -> tuple[np.ndarray, np.ndarray]: + """confirmed/ から画像とラベルを読み込む + + Args: + confirmed_dir: 確定済みデータのディレクトリ + + Returns: + (X, y) のタプル + X: (n_samples, 1200) の特徴量行列(0.0/1.0) + y: (n_samples,) のラベル配列(1=intersection, 0=normal) + + Raises: + FileNotFoundError: 画像が見つからない場合 + """ + images: list[np.ndarray] = [] + labels: list[int] = [] + + for label_name, label_val in ( + (LABEL_INTERSECTION, 1), + (LABEL_NORMAL, 0), + ): + label_dir = confirmed_dir / label_name + if not label_dir.is_dir(): + continue + for img_path in sorted(label_dir.glob("*.png")): + img = cv2.imread( + str(img_path), cv2.IMREAD_GRAYSCALE, + ) + if img is None: + continue + # 0/255 → 0.0/1.0 に正規化して flatten + flat = (img.flatten() / 255.0).astype( + np.float32, + ) + images.append(flat) + labels.append(label_val) + + if len(images) == 0: + raise FileNotFoundError( + f"画像が見つかりません: {confirmed_dir}" + ) + + x = np.array(images) + y = np.array(labels) + return x, y diff --git a/src/pc/data/reviewer.py b/src/pc/data/reviewer.py new file mode 100644 index 0000000..670bdd8 --- /dev/null +++ b/src/pc/data/reviewer.py @@ -0,0 +1,413 @@ +""" +reviewer +収集した二値画像を閲覧・仕分けするレビュー GUI + +raw/ のセッションを開き,画像を1枚ずつ確認して +confirmed/ に確定 or 削除する + +キー操作: + ←: 前の画像に戻る + →/Enter: 現在のラベルで確定して次へ + M: ラベルを反転(intersection ↔ normal) + Delete/Backspace: 画像を削除して次へ + Escape: 終了 +""" + +from pathlib import Path + +import cv2 +import numpy as np +from PySide6.QtCore import Qt +from PySide6.QtGui import QImage, QKeyEvent, QPixmap +from PySide6.QtWidgets import ( + QFileDialog, + QHBoxLayout, + QLabel, + QMainWindow, + QPushButton, + QVBoxLayout, + QWidget, +) + +from pc.data.collector import ( + CONFIRMED_DIR, + LABEL_INTERSECTION, + LABEL_NORMAL, + RAW_DIR, +) + +# 画像表示の拡大倍率 +REVIEW_SCALE: int = 16 + + +def _load_image(path: Path) -> np.ndarray: + """画像ファイルを読み込む + + Args: + path: 画像ファイルのパス + + Returns: + グレースケール画像 + """ + img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE) + if img is None: + return np.zeros((30, 40), dtype=np.uint8) + return img + + +def _count_confirmed() -> tuple[int, int]: + """confirmed/ 内の画像数を集計する + + Returns: + (intersection の枚数, normal の枚数) + """ + n_int = len( + list( + (CONFIRMED_DIR / LABEL_INTERSECTION).glob( + "*.png", + ), + ) + ) if (CONFIRMED_DIR / LABEL_INTERSECTION).is_dir() else 0 + n_norm = len( + list( + (CONFIRMED_DIR / LABEL_NORMAL).glob("*.png"), + ) + ) if (CONFIRMED_DIR / LABEL_NORMAL).is_dir() else 0 + return n_int, n_norm + + +def _next_confirmed_index(label: str) -> int: + """confirmed/ の次の連番を返す + + Args: + label: ラベル名 + + Returns: + 次の連番(1始まり) + """ + label_dir = CONFIRMED_DIR / label + if not label_dir.is_dir(): + return 1 + existing = list(label_dir.glob("*.png")) + if not existing: + return 1 + max_num = 0 + for p in existing: + try: + num = int(p.stem.split("_")[0]) + max_num = max(max_num, num) + except ValueError: + continue + return max_num + 1 + + +class _ImageEntry: + """画像1枚の情報を保持する + + Attributes: + path: ファイルパス + label: 現在のラベル(intersection / normal) + """ + + def __init__(self, path: Path, label: str) -> None: + self.path = path + self.label = label + + +class ReviewWindow(QMainWindow): + """データ仕分けウィンドウ""" + + def __init__( + self, + session_dir: Path | None = None, + ) -> None: + super().__init__() + self._entries: list[_ImageEntry] = [] + self._index: int = 0 + + self._setup_ui() + + if session_dir is not None: + self._load_session(session_dir) + + def _setup_ui(self) -> None: + """UI を構築する""" + self.setWindowTitle("データ仕分け") + + central = QWidget() + self.setCentralWidget(central) + root = QVBoxLayout(central) + + # セッション選択 + top_bar = QHBoxLayout() + self._open_btn = QPushButton("セッションを開く") + self._open_btn.clicked.connect(self._on_open) + top_bar.addWidget(self._open_btn) + + self._session_label = QLabel("未選択") + self._session_label.setStyleSheet( + "font-size: 13px; color: #888;" + ) + top_bar.addWidget(self._session_label, stretch=1) + root.addLayout(top_bar) + + # 画像表示 + self._image_label = QLabel( + "セッションを選択してください", + ) + self._image_label.setAlignment( + Qt.AlignmentFlag.AlignCenter, + ) + self._image_label.setMinimumSize( + 40 * REVIEW_SCALE, 30 * REVIEW_SCALE, + ) + self._image_label.setStyleSheet( + "background-color: #222;" + " color: #aaa; font-size: 16px;" + ) + root.addWidget(self._image_label) + + # ラベル・進捗表示 + self._info_label = QLabel("") + self._info_label.setAlignment( + Qt.AlignmentFlag.AlignCenter, + ) + self._info_label.setStyleSheet( + "font-size: 16px; font-family: monospace;" + " padding: 6px;" + ) + root.addWidget(self._info_label) + + # 操作ボタン + btn_bar = QHBoxLayout() + + self._prev_btn = QPushButton("← 戻る (←)") + self._prev_btn.clicked.connect(self._go_prev) + btn_bar.addWidget(self._prev_btn) + + self._flip_btn = QPushButton("ラベル反転 (M)") + self._flip_btn.clicked.connect(self._flip_label) + btn_bar.addWidget(self._flip_btn) + + self._delete_btn = QPushButton("削除 (Del)") + self._delete_btn.clicked.connect( + self._delete_current, + ) + btn_bar.addWidget(self._delete_btn) + + self._confirm_btn = QPushButton("確定 → (→)") + self._confirm_btn.clicked.connect( + self._confirm_current, + ) + btn_bar.addWidget(self._confirm_btn) + + root.addLayout(btn_bar) + + # 集計表示(raw 残り + confirmed 累計) + self._summary_label = QLabel("") + self._summary_label.setAlignment( + Qt.AlignmentFlag.AlignCenter, + ) + self._summary_label.setStyleSheet( + "font-size: 13px; font-family: monospace;" + " color: #888; padding: 4px;" + ) + root.addWidget(self._summary_label) + + # 操作ガイド + guide = QLabel( + "←: 戻る →/Enter: 確定 " + "M: ラベル反転 Del/BS: 削除 Esc: 終了" + ) + guide.setAlignment(Qt.AlignmentFlag.AlignCenter) + guide.setStyleSheet( + "font-size: 12px; color: #666;" + ) + root.addWidget(guide) + + # ── セッション読み込み ──────────────────────────────── + + def _on_open(self) -> None: + """セッション選択ダイアログを開く""" + raw_dir = str(RAW_DIR) + dir_path = QFileDialog.getExistingDirectory( + self, "セッションディレクトリを選択", + raw_dir, + ) + if dir_path: + self._load_session(Path(dir_path)) + + def _load_session(self, session_dir: Path) -> None: + """セッションの画像一覧を読み込む + + Args: + session_dir: セッションディレクトリのパス + """ + self._entries.clear() + self._index = 0 + + for label in (LABEL_INTERSECTION, LABEL_NORMAL): + label_dir = session_dir / label + if not label_dir.is_dir(): + continue + for img_path in sorted(label_dir.glob("*.png")): + self._entries.append( + _ImageEntry(img_path, label), + ) + + self._session_label.setText( + f"セッション: {session_dir.name}" + ) + self._update_display() + + # ── ナビゲーション ──────────────────────────────────── + + def _go_prev(self) -> None: + """前の画像に戻る""" + if len(self._entries) == 0: + return + self._index = max(0, self._index - 1) + self._update_display() + + # ── ラベル操作 ──────────────────────────────────────── + + def _flip_label(self) -> None: + """現在の画像のラベルを反転する(raw 内で移動)""" + if len(self._entries) == 0: + return + + entry = self._entries[self._index] + new_label = ( + LABEL_NORMAL + if entry.label == LABEL_INTERSECTION + else LABEL_INTERSECTION + ) + entry.label = new_label + self._update_display() + + def _confirm_current(self) -> None: + """現在の画像を confirmed/ に移動して次へ""" + if len(self._entries) == 0: + return + + entry = self._entries[self._index] + label = entry.label + + # confirmed/ 内の保存先を確保 + dest_dir = CONFIRMED_DIR / label + dest_dir.mkdir(parents=True, exist_ok=True) + idx = _next_confirmed_index(label) + dest_path = dest_dir / f"{idx:06d}.png" + + # raw → confirmed へ移動 + entry.path.rename(dest_path) + self._entries.pop(self._index) + + if len(self._entries) == 0: + self._index = 0 + elif self._index >= len(self._entries): + self._index = len(self._entries) - 1 + + self._update_display() + + def _delete_current(self) -> None: + """現在の画像を削除して次へ""" + if len(self._entries) == 0: + return + + entry = self._entries[self._index] + entry.path.unlink(missing_ok=True) + self._entries.pop(self._index) + + if len(self._entries) == 0: + self._index = 0 + elif self._index >= len(self._entries): + self._index = len(self._entries) - 1 + + self._update_display() + + # ── 表示更新 ────────────────────────────────────────── + + def _update_display(self) -> None: + """画像・ラベル・進捗の表示を更新する""" + n = len(self._entries) + if n == 0: + self._image_label.setText("画像がありません") + self._info_label.setText("") + self._update_summary() + return + + entry = self._entries[self._index] + + # 画像表示 + img = _load_image(entry.path) + h, w = img.shape[:2] + qimg = QImage( + img.data, w, h, w, + QImage.Format.Format_Grayscale8, + ) + pixmap = QPixmap.fromImage(qimg).scaled( + w * REVIEW_SCALE, + h * REVIEW_SCALE, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.FastTransformation, + ) + self._image_label.setPixmap(pixmap) + + # ラベル色分け + if entry.label == LABEL_INTERSECTION: + color = "#f44" + label_text = "intersection" + else: + color = "#4a4" + label_text = "normal" + + self._info_label.setText( + f'' + f"{label_text}" + f" {self._index + 1} / {n}" + f" [{entry.path.name}]" + ) + + self._update_summary() + + def _update_summary(self) -> None: + """集計表示を更新する""" + # raw 残り + n_raw_int = sum( + 1 for e in self._entries + if e.label == LABEL_INTERSECTION + ) + n_raw_norm = len(self._entries) - n_raw_int + + # confirmed 累計 + c_int, c_norm = _count_confirmed() + + self._summary_label.setText( + f"[raw] int: {n_raw_int} norm: {n_raw_norm}" + f" [confirmed] int: {c_int}" + f" norm: {c_norm}" + ) + + # ── キー操作 ────────────────────────────────────────── + + def keyPressEvent(self, event: QKeyEvent) -> None: + """キー押下時の処理""" + key = event.key() + if key == Qt.Key.Key_Left: + self._go_prev() + elif key in ( + Qt.Key.Key_Right, Qt.Key.Key_Return, + ): + self._confirm_current() + elif key == Qt.Key.Key_M: + self._flip_label() + elif key in ( + Qt.Key.Key_Delete, Qt.Key.Key_Backspace, + ): + self._delete_current() + elif key == Qt.Key.Key_Escape: + self.close() + else: + super().keyPressEvent(event) diff --git a/src/pc/data/train.py b/src/pc/data/train.py new file mode 100644 index 0000000..1ac811d --- /dev/null +++ b/src/pc/data/train.py @@ -0,0 +1,393 @@ +""" +train +十字路分類モデルの学習・評価スクリプト + +複数モデルを 5-fold CV で比較し,結果をドキュメントに出力する +最良モデルを全データで再学習して params/ に保存する + +使い方: + $ cd src && python -m pc.data.train +""" + +import sys +from datetime import datetime +from pathlib import Path + +import joblib +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + classification_report, + f1_score, +) +from sklearn.model_selection import StratifiedKFold +from sklearn.neural_network import MLPClassifier +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC + +from common.json_utils import PARAMS_DIR +from pc.data.dataset import load_dataset + +# ── 定数 ────────────────────────────────────────────── + +N_FOLDS: int = 5 +RANDOM_STATE: int = 42 + +# モデル保存先 +MODEL_PATH: Path = PARAMS_DIR / "intersection_model.pkl" +SCALER_PATH: Path = PARAMS_DIR / "intersection_scaler.pkl" + +# ドキュメント出力先 +_PROJECT_ROOT: Path = ( + Path(__file__).resolve().parent.parent.parent.parent +) +REPORT_PATH: Path = ( + _PROJECT_ROOT / "docs" / "03_TECH" + / "TECH_06_十字路分類モデル評価.txt" +) + + +# ── モデル定義 ──────────────────────────────────────── + +def _build_models() -> list[tuple[str, object]]: + """比較対象のモデル一覧を返す + + Returns: + (名前, モデルインスタンス) のリスト + """ + return [ + ( + "LogisticRegression", + LogisticRegression( + max_iter=1000, + random_state=RANDOM_STATE, + ), + ), + ( + "SVM_RBF", + SVC( + kernel="rbf", + random_state=RANDOM_STATE, + ), + ), + ( + "SVM_Linear", + SVC( + kernel="linear", + random_state=RANDOM_STATE, + ), + ), + ( + "RandomForest", + RandomForestClassifier( + n_estimators=100, + random_state=RANDOM_STATE, + ), + ), + ( + "MLP_1layer", + MLPClassifier( + hidden_layer_sizes=(64,), + max_iter=500, + random_state=RANDOM_STATE, + ), + ), + ( + "MLP_2layer", + MLPClassifier( + hidden_layer_sizes=(128, 64), + max_iter=500, + random_state=RANDOM_STATE, + ), + ), + ] + + +# ── 評価 ────────────────────────────────────────────── + +def _evaluate_models( + x: np.ndarray, + y: np.ndarray, +) -> list[dict]: + """全モデルを StratifiedKFold CV で評価する + + Args: + x: 特徴量行列 + y: ラベル配列 + + Returns: + 各モデルの評価結果辞書のリスト + """ + skf = StratifiedKFold( + n_splits=N_FOLDS, + shuffle=True, + random_state=RANDOM_STATE, + ) + results: list[dict] = [] + + for name, model in _build_models(): + f1_scores: list[float] = [] + + for train_idx, test_idx in skf.split(x, y): + x_train, x_test = x[train_idx], x[test_idx] + y_train, y_test = y[train_idx], y[test_idx] + + # スケーリング + scaler = StandardScaler() + x_train_s = scaler.fit_transform(x_train) + x_test_s = scaler.transform(x_test) + + model_clone = _clone_model(name) + model_clone.fit(x_train_s, y_train) + y_pred = model_clone.predict(x_test_s) + + f1 = f1_score(y_test, y_pred, average="macro") + f1_scores.append(f1) + + results.append({ + "name": name, + "f1_mean": float(np.mean(f1_scores)), + "f1_std": float(np.std(f1_scores)), + "f1_scores": f1_scores, + }) + print( + f" {name:25s} " + f"F1={np.mean(f1_scores):.4f} " + f"(±{np.std(f1_scores):.4f})" + ) + + return results + + +def _clone_model(name: str) -> object: + """モデル名から新しいインスタンスを生成する + + Args: + name: モデル名 + + Returns: + モデルインスタンス + """ + for n, m in _build_models(): + if n == name: + return m + raise ValueError(f"不明なモデル: {name}") + + +# ── 最良モデルの再学習・保存 ────────────────────────── + +def _train_best_and_save( + best_name: str, + x: np.ndarray, + y: np.ndarray, +) -> str: + """最良モデルを全データで再学習して保存する + + Args: + best_name: 最良モデルの名前 + x: 全特徴量行列 + y: 全ラベル配列 + + Returns: + 全データでの classification_report 文字列 + """ + scaler = StandardScaler() + x_scaled = scaler.fit_transform(x) + + model = _clone_model(best_name) + model.fit(x_scaled, y) + y_pred = model.predict(x_scaled) + report = classification_report( + y, y_pred, + target_names=["normal", "intersection"], + ) + + # 保存 + PARAMS_DIR.mkdir(parents=True, exist_ok=True) + joblib.dump(model, MODEL_PATH) + joblib.dump(scaler, SCALER_PATH) + + return report + + +# ── ドキュメント出力 ────────────────────────────────── + +def _write_report( + results: list[dict], + best_name: str, + n_samples: int, + n_intersection: int, + n_normal: int, + full_report: str, +) -> None: + """評価結果をドキュメントとして出力する + + Args: + results: 各モデルの評価結果 + best_name: 最良モデルの名前 + n_samples: 全サンプル数 + n_intersection: intersection のサンプル数 + n_normal: normal のサンプル数 + full_report: 全データでの classification_report + """ + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M") + + # 結果テーブル + sorted_results = sorted( + results, key=lambda r: r["f1_mean"], reverse=True, + ) + table_lines: list[str] = [] + table_lines.append( + f" {'モデル':25s} {'F1(平均)':>10s}" + f" {'F1(標準偏差)':>14s}" + ) + table_lines.append(f" {'─' * 55}") + for r in sorted_results: + marker = " ← 採用" if r["name"] == best_name else "" + table_lines.append( + f" {r['name']:25s}" + f" {r['f1_mean']:10.4f}" + f" {r['f1_std']:14.4f}{marker}" + ) + table_str = "\n".join(table_lines) + + # fold 詳細 + fold_lines: list[str] = [] + for r in sorted_results: + scores_str = ", ".join( + f"{s:.4f}" for s in r["f1_scores"] + ) + fold_lines.append( + f" ・{r['name']}: [{scores_str}]" + ) + fold_str = "\n".join(fold_lines) + + # classification_report のインデント + report_indented = "\n".join( + f" {line}" for line in full_report.splitlines() + ) + + content = f"""\ +======================================================================== +十字路分類モデル評価 (Intersection Classifier Evaluation) +======================================================================== + + +1. 概要 (Overview) +------------------------------------------------------------------------ + + 1-0. 目的 + + 十字路(intersection)と通常区間(normal)を分類する + 二値画像分類モデルの比較評価結果を記録する. + + 1-1. 評価日時 + + ・実施日時: {timestamp} + + 1-2. データセット + + ・入力: 40×30 二値画像(1200 特徴量,0.0/1.0) + ・全サンプル数: {n_samples} + - intersection: {n_intersection} + - normal: {n_normal} + ・クラス比率: intersection:normal\ + = {n_intersection}:{n_normal} + + +2. 評価方法 (Evaluation Method) +------------------------------------------------------------------------ + + 2-1. 交差検証 + + ・手法: Stratified {N_FOLDS}-Fold Cross-Validation + ・指標: マクロ平均 F1 スコア + ・前処理: StandardScaler(fold ごとに fit) + ・乱数シード: {RANDOM_STATE} + + +3. 評価結果 (Results) +------------------------------------------------------------------------ + + 3-1. モデル比較(F1 スコア降順) + +{table_str} + + 3-2. 各 Fold の F1 スコア + +{fold_str} + + 3-3. 採用モデル + + ・モデル: {best_name} + ・保存先: params/intersection_model.pkl + ・スケーラ: params/intersection_scaler.pkl + + 3-4. 全データでの分類レポート(再学習後) + +{report_indented} +""" + + REPORT_PATH.parent.mkdir(parents=True, exist_ok=True) + with open( + REPORT_PATH, "w", encoding="utf-8", + ) as f: + f.write(content) + + print(f"\nドキュメント出力: {REPORT_PATH}") + + +# ── メイン ──────────────────────────────────────────── + +def main() -> None: + """学習・評価のメインフロー""" + print("=== 十字路分類モデルの学習・評価 ===\n") + + # データ読み込み + print("データ読み込み中...") + try: + x, y = load_dataset() + except FileNotFoundError as e: + print(f"エラー: {e}") + sys.exit(1) + + n_samples = len(y) + n_intersection = int(np.sum(y == 1)) + n_normal = int(np.sum(y == 0)) + print( + f" サンプル数: {n_samples}" + f" (intersection={n_intersection}," + f" normal={n_normal})\n" + ) + + # CV 評価 + print(f"{N_FOLDS}-Fold CV 評価中...") + results = _evaluate_models(x, y) + + # 最良モデル選定 + best = max(results, key=lambda r: r["f1_mean"]) + best_name = best["name"] + print( + f"\n最良モデル: {best_name}" + f" (F1={best['f1_mean']:.4f})\n" + ) + + # 全データで再学習・保存 + print("全データで再学習・保存中...") + full_report = _train_best_and_save(best_name, x, y) + print(f" モデル保存: {MODEL_PATH}") + print(f" スケーラ保存: {SCALER_PATH}") + + # ドキュメント出力 + _write_report( + results, best_name, + n_samples, n_intersection, n_normal, + full_report, + ) + + print("\n完了") + + +if __name__ == "__main__": + main() diff --git a/src/pc/gui/command_sender.py b/src/pc/gui/command_sender.py new file mode 100644 index 0000000..dc48b87 --- /dev/null +++ b/src/pc/gui/command_sender.py @@ -0,0 +1,110 @@ +""" +command_sender +パラメータの dirty 管理・コマンド辞書構築・ZMQ 送信を担当するモジュール +""" + +import dataclasses + +from common.steering.pd_control import PdParams +from common.steering.pursuit_control import PursuitParams +from common.steering.recovery import RecoveryParams +from common.steering.ts_pd_control import TsPdParams +from common.vision.line_detector import ImageParams +from pc.comm.zmq_client import PcZmqClient + + +class CommandSender: + """コマンドを構築して Pi に送信する + + パラメータの変更フラグを管理し, + モード・操舵量・パラメータをまとめて送信する + """ + + def __init__( + self, + zmq_client: PcZmqClient, + ) -> None: + self._zmq_client = zmq_client + self.params_dirty: bool = True + + def send( + self, + *, + is_auto: bool, + throttle: float, + steer: float, + intersection_enabled: bool, + intersection_throttle: float, + steering_method: str, + image_params: ImageParams, + pd_params: PdParams, + pursuit_params: PursuitParams, + ts_pd_params: TsPdParams, + recovery_params: RecoveryParams, + ) -> None: + """コマンドを構築して送信する + + Args: + is_auto: 自動操縦中か + throttle: 手動操作の throttle + steer: 手動操作の steer + intersection_enabled: 十字路判定の有効化 + intersection_throttle: 十字路通過時の throttle + steering_method: 制御手法名 + image_params: 二値化パラメータ + pd_params: PD 制御パラメータ + pursuit_params: Pursuit 制御パラメータ + ts_pd_params: Theil-Sen PD パラメータ + recovery_params: 復帰パラメータ + """ + cmd: dict = {} + + # モード + if is_auto: + cmd["mode"] = "auto" + elif throttle != 0.0 or steer != 0.0: + cmd["mode"] = "manual" + cmd["throttle"] = throttle + cmd["steer"] = steer + else: + cmd["mode"] = "stop" + + # 十字路設定 + cmd["intersection_enabled"] = ( + intersection_enabled + ) + cmd["intersection_throttle"] = ( + intersection_throttle + ) + + # 制御手法 + cmd["steering_method"] = steering_method + + # パラメータ更新(変更があった場合のみ) + if self.params_dirty: + cmd["image_params"] = dataclasses.asdict( + image_params, + ) + cmd["pd_params"] = dataclasses.asdict( + pd_params, + ) + cmd["pursuit_params"] = dataclasses.asdict( + pursuit_params, + ) + cmd["steering_params"] = dataclasses.asdict( + ts_pd_params, + ) + cmd["recovery_params"] = dataclasses.asdict( + recovery_params, + ) + self.params_dirty = False + + self._zmq_client.send_command(cmd) + + def send_stop(self) -> None: + """停止コマンドを送信する""" + self._zmq_client.send_command({"mode": "stop"}) + + def mark_dirty(self) -> None: + """パラメータ変更フラグを立てる""" + self.params_dirty = True diff --git a/src/pc/gui/main_window.py b/src/pc/gui/main_window.py index ec7c5cd..69f69d5 100644 --- a/src/pc/gui/main_window.py +++ b/src/pc/gui/main_window.py @@ -1,13 +1,12 @@ """ main_window PC 側のメインウィンドウを定義するモジュール -カメラ映像のリアルタイム表示と操作 UI を提供する +Pi からのテレメトリをリアルタイム表示し, +モード切替・パラメータ調整・手動操作のコマンドを送信する """ -import cv2 -import numpy as np from PySide6.QtCore import Qt, QTimer -from PySide6.QtGui import QImage, QKeyEvent, QPixmap +from PySide6.QtGui import QKeyEvent from PySide6.QtWidgets import ( QHBoxLayout, QLabel, @@ -20,12 +19,16 @@ from common import config from pc.comm.zmq_client import PcZmqClient +from pc.gui.command_sender import CommandSender +from pc.gui.manual_controller import ManualController from pc.gui.panels import ( ControlParamPanel, ImageParamPanel, + IntersectionPanel, OverlayPanel, RecoveryPanel, ) +from pc.gui.telemetry_display import TelemetryDisplay from pc.steering.auto_params import ( load_control, load_detect_params, @@ -39,46 +42,30 @@ save_recovery, save_ts_pd, ) -from pc.steering.base import SteeringBase -from pc.steering.pd_control import PdControl, PdParams -from pc.steering.pursuit_control import ( - PursuitControl, - PursuitParams, -) -from pc.steering.recovery import ( - RecoveryController, - RecoveryParams, -) -from pc.steering.ts_pd_control import ( - TsPdControl, - TsPdParams, -) -from pc.vision.fitting import theil_sen_fit -from pc.vision.line_detector import ( - ImageParams, - LineDetectResult, - detect_line, -) -from pc.vision.overlay import draw_overlay +from common.steering.pd_control import PdParams +from common.steering.pursuit_control import PursuitParams +from common.steering.recovery import RecoveryParams +from common.steering.ts_pd_control import TsPdParams +from common.vision.line_detector import ImageParams # 映像更新間隔 (ms) FRAME_INTERVAL_MS: int = 33 -# 操舵量送信間隔 (ms) -CONTROL_INTERVAL_MS: int = int( +# コマンド送信間隔 (ms) +COMMAND_INTERVAL_MS: int = int( 1000 / config.CONTROL_PUBLISH_HZ ) # 映像表示のスケール倍率(40x30 → 640x480 相当) -DISPLAY_SCALE: float = 16.0 - -# 手動操作の throttle / steer 量 -MANUAL_THROTTLE: float = 0.5 -MANUAL_STEER: float = 0.4 +DISPLAY_SCALE: float = config.DISPLAY_SCALE class MainWindow(QMainWindow): - """PC 側のメインウィンドウ""" + """PC 側のメインウィンドウ + + Pi からテレメトリを受信して映像・状態を表示し, + コマンドを Pi に送信する + """ def __init__(self) -> None: super().__init__() @@ -86,49 +73,39 @@ self._is_connected = False self._is_auto = False - # 手動操作の状態 - self._pressed_keys: set[int] = set() - self._throttle: float = 0.0 - self._steer: float = 0.0 + # 分割されたコンポーネント + self._manual = ManualController() + self._cmd_sender = CommandSender( + self._zmq_client, + ) # 前回のパラメータを復元 pd_params, last_method, last_steering = ( load_control() ) image_params = load_detect_params(last_method) - self._pd_control = PdControl( - params=pd_params, - image_params=image_params, - ) + self._image_params = image_params + self._pd_params = pd_params pursuit_params = load_pursuit() - self._pursuit_control = PursuitControl( - params=pursuit_params, - image_params=image_params, - ) + self._pursuit_params = pursuit_params ts_pd_params = load_ts_pd() - self._ts_pd_control = TsPdControl( - params=ts_pd_params, - image_params=image_params, - ) - - # 現在の制御手法("pd", "pursuit", "ts_pd") + self._ts_pd_params = ts_pd_params self._steering_method: str = last_steering - # コースアウト復帰 + # 復帰パラメータ recovery_params = load_recovery() - self._recovery = RecoveryController( - params=recovery_params, - ) - - # 最新フレームの保持(自動操縦で使用) - self._latest_frame: np.ndarray | None = None - - # 検出結果の保持 - self._last_detect_result: LineDetectResult | None = ( - None - ) + self._recovery_params = recovery_params self._setup_ui() + + # TelemetryDisplay はウィジェット生成後に初期化 + self._telemetry = TelemetryDisplay( + self._zmq_client, + self._video_label, + self._detect_info_label, + self._perf_label, + ) + self._setup_timers() def _setup_ui(self) -> None: @@ -157,7 +134,7 @@ left_layout.addWidget(self._video_label) self._detect_info_label = QLabel( - "pos: --- head: --- curv: ---" + "pos: --- head: ---" ) self._detect_info_label.setAlignment( Qt.AlignmentFlag.AlignLeft, @@ -168,6 +145,19 @@ " padding: 4px;" ) left_layout.addWidget(self._detect_info_label) + + self._perf_label = QLabel( + "recv FPS: --- Pi FPS: ---" + ) + self._perf_label.setAlignment( + Qt.AlignmentFlag.AlignLeft, + ) + self._perf_label.setStyleSheet( + "font-size: 14px; font-family: monospace;" + " color: #ff0; background-color: #222;" + " padding: 4px;" + ) + left_layout.addWidget(self._perf_label) root_layout.addLayout(left_layout, stretch=3) # 右側: スクロール可能なコントロールパネル @@ -215,7 +205,7 @@ # 二値化パラメータパネル self._image_panel = ImageParamPanel( - self._pd_control.image_params, + self._image_params, ) self._image_panel.image_params_changed.connect( self._on_image_params_changed, @@ -225,11 +215,17 @@ ) control_layout.addWidget(self._image_panel) + # 十字路判定パネル + self._intersection_panel = IntersectionPanel( + available=True, + ) + control_layout.addWidget(self._intersection_panel) + # 制御パラメータパネル self._control_panel = ControlParamPanel( - self._pd_control.params, - self._pursuit_control.params, - self._ts_pd_control.params, + self._pd_params, + self._pursuit_params, + self._ts_pd_params, self._steering_method, ) self._control_panel.pd_params_changed.connect( @@ -248,7 +244,7 @@ # コースアウト復帰パネル self._recovery_panel = RecoveryPanel( - self._recovery.params, + self._recovery_params, ) self._recovery_panel.recovery_params_changed.connect( self._on_recovery_params_changed, @@ -277,23 +273,16 @@ control_layout.addStretch() - @property - def _active_control(self) -> SteeringBase: - """現在選択中の制御クラスを返す""" - if self._steering_method == "pursuit": - return self._pursuit_control - if self._steering_method == "ts_pd": - return self._ts_pd_control - return self._pd_control - def _setup_timers(self) -> None: """タイマーを設定する""" self._frame_timer = QTimer(self) - self._frame_timer.timeout.connect(self._update_frame) + self._frame_timer.timeout.connect( + self._update_frame, + ) - self._control_timer = QTimer(self) - self._control_timer.timeout.connect( - self._send_control, + self._command_timer = QTimer(self) + self._command_timer.timeout.connect( + self._send_command, ) # ── パネルシグナルのスロット ─────────────────────────── @@ -301,46 +290,49 @@ def _on_image_params_changed( self, ip: ImageParams, ) -> None: - """二値化パラメータの変更を全制御クラスに反映する""" - self._pd_control.image_params = ip - self._pursuit_control.image_params = ip - self._ts_pd_control.image_params = ip + """二値化パラメータの変更をマークする""" + self._image_params = ip + self._cmd_sender.mark_dirty() def _on_method_changed(self, method: str) -> None: - """検出手法の変更に合わせて制御設定を保存する""" + """検出手法の変更を保存する""" save_control( - self._pd_control.params, method, + self._pd_params, method, self._steering_method, ) def _on_pd_params_changed(self, p: PdParams) -> None: - """PD パラメータの変更を制御クラスに反映して保存する""" - self._pd_control.params = p + """PD パラメータの変更を保存・マークする""" + self._pd_params = p save_control( - p, self._pd_control.image_params.method, + p, self._image_params.method, self._steering_method, ) + self._cmd_sender.mark_dirty() def _on_pursuit_params_changed( self, p: PursuitParams, ) -> None: - """Pursuit パラメータの変更を制御クラスに反映して保存する""" - self._pursuit_control.params = p + """Pursuit パラメータの変更を保存・マークする""" + self._pursuit_params = p save_pursuit(p) + self._cmd_sender.mark_dirty() def _on_ts_pd_params_changed( self, p: TsPdParams, ) -> None: - """Theil-Sen PD パラメータの変更を反映して保存する""" - self._ts_pd_control.params = p + """Theil-Sen PD パラメータの変更を保存・マークする""" + self._ts_pd_params = p save_ts_pd(p) + self._cmd_sender.mark_dirty() def _on_recovery_params_changed( self, p: RecoveryParams, ) -> None: - """復帰パラメータの変更を反映して保存する""" - self._recovery.params = p + """復帰パラメータの変更を保存・マークする""" + self._recovery_params = p save_recovery(p) + self._cmd_sender.mark_dirty() def _on_overlay_flags_changed(self) -> None: """オーバーレイフラグの変更を保存する""" @@ -352,10 +344,11 @@ """制御手法の切替を反映して保存する""" self._steering_method = method save_control( - self._pd_control.params, - self._pd_control.image_params.method, + self._pd_params, + self._image_params.method, method, ) + self._cmd_sender.mark_dirty() # ── 接続 ────────────────────────────────────────────── @@ -367,29 +360,32 @@ self._connect() def _connect(self) -> None: - """ZMQ 通信を開始して映像受信を始める""" + """ZMQ 通信を開始してテレメトリ受信を始める""" self._zmq_client.start() self._is_connected = True + self._cmd_sender.mark_dirty() self._connect_btn.setText("切断") self._auto_btn.setEnabled(True) self._status_label.setText("接続中 (手動操作)") self._frame_timer.start(FRAME_INTERVAL_MS) - self._control_timer.start(CONTROL_INTERVAL_MS) + self._command_timer.start(COMMAND_INTERVAL_MS) + # 初回接続時に stop コマンドを送信 + self._cmd_sender.send_stop() def _disconnect(self) -> None: """ZMQ 通信を停止する""" self._frame_timer.stop() - self._control_timer.stop() + self._command_timer.stop() + # Pi を停止 + if self._is_connected: + self._cmd_sender.send_stop() if self._is_auto: self._is_auto = False self._auto_btn.setText("自動操縦 ON") self._zmq_client.stop() self._is_connected = False self._auto_btn.setEnabled(False) - self._pressed_keys.clear() - self._throttle = 0.0 - self._steer = 0.0 - self._latest_frame = None + self._manual.reset() self._connect_btn.setText("接続開始") self._status_label.setText("未接続") self._video_label.setText("カメラ映像待機中...") @@ -409,203 +405,50 @@ def _enable_auto(self) -> None: """自動操縦を開始する""" self._is_auto = True - self._active_control.reset() - self._recovery.reset() - self._pressed_keys.clear() + self._manual.reset() self._auto_btn.setText("自動操縦 OFF") self._status_label.setText("接続中 (自動操縦)") + self._cmd_sender.mark_dirty() def _disable_auto(self) -> None: """自動操縦を停止して手動に戻る""" self._is_auto = False - self._throttle = 0.0 - self._steer = 0.0 + self._manual.reset() self._auto_btn.setText("自動操縦 ON") self._status_label.setText("接続中 (手動操作)") self._update_control_label() - # ── 映像更新 ────────────────────────────────────────── + # ── 映像更新(テレメトリ受信) ──────────────────────── def _update_frame(self) -> None: - """タイマーから呼び出され,最新フレームを表示する""" - frame = self._zmq_client.receive_image() - if frame is None: - return - self._latest_frame = frame - - # 線検出は常に実行(検出情報ラベル表示のため) - if self._is_auto: - ctrl = self._active_control - output = ctrl.compute(frame) - self._last_detect_result = ( - ctrl.last_detect_result - ) - - # コースアウト復帰判定 - det = self._last_detect_result - detected = det is not None and det.detected - pos_err = ( - det.position_error - if detected and det is not None - else 0.0 - ) - recovery_output = self._recovery.update( - detected, pos_err, - ) - if recovery_output is not None: - output = recovery_output - - self._throttle = output.throttle - self._steer = output.steer - self._update_control_label() - else: - self._last_detect_result = detect_line( - frame, - self._pd_control.image_params, - ) - - self._display_frame(frame) - - def _update_detect_info_label(self) -> None: - """検出情報ラベルを更新する""" - r = self._last_detect_result - if r is None or not r.detected: - self._detect_info_label.setText( - "pos: --- head: --- curv: ---" - ) - return - self._detect_info_label.setText( - f"pos: {r.position_error:+.3f}" - f" head: {r.heading:+.4f}" - f" curv: {r.curvature:+.6f}" - ) - - def _calc_pursuit_points_preview( - self, - ) -> ( - tuple[tuple[float, float], tuple[float, float]] - | None - ): - """手動操作中にパシュート目標点を算出する - - Returns: - ((near_x, near_y), (far_x, far_y)) または None - """ - r = self._last_detect_result - if r is None or not r.detected: - return None - if r.row_centers is None: - return None - - centers = r.row_centers - valid = ~np.isnan(centers) - ys = np.where(valid)[0].astype(float) - xs = centers[valid] - if len(ys) < 2: - return None - - slope, intercept = theil_sen_fit(ys, xs) - h = len(centers) - p = self._pursuit_control.params - near_y = h * p.near_ratio - far_y = h * p.far_ratio - near_x = slope * near_y + intercept - far_x = slope * far_y + intercept - return ((near_x, near_y), (far_x, far_y)) - - def _calc_ts_pd_line_preview( - self, - ) -> ( - tuple[tuple[float, float], tuple[float, float]] - | None - ): - """Theil-Sen PD の近似直線を表示用に算出する - - 直線の上端と下端の 2 点を返す - - Returns: - ((bottom_x, bottom_y), (top_x, top_y)) または None - """ - if self._is_auto: - fit = self._ts_pd_control.last_fit_line - if fit is None: - return None - slope, intercept = fit - r = self._last_detect_result - if r is None or r.row_centers is None: - return None - h = len(r.row_centers) - else: - r = self._last_detect_result - if r is None or not r.detected: - return None - if r.row_centers is None: - return None - centers = r.row_centers - valid = ~np.isnan(centers) - ys = np.where(valid)[0].astype(float) - xs = centers[valid] - if len(ys) < 2: - return None - slope, intercept = theil_sen_fit(ys, xs) - h = len(centers) - - bottom_y = float(h - 1) - top_y = 0.0 - bottom_x = slope * bottom_y + intercept - top_x = slope * top_y + intercept - return ((bottom_x, bottom_y), (top_x, top_y)) - - def _display_frame(self, frame: np.ndarray) -> None: - """NumPy 配列の画像を QLabel に表示する - - Args: - frame: グレースケールの画像 - """ - # グレースケール → BGR 変換(カラーオーバーレイ描画のため) - bgr = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) - - # オーバーレイ描画 - pursuit_pts = None - if self._steering_method == "pursuit": - if self._is_auto: - pursuit_pts = ( - self._pursuit_control - .last_pursuit_points - ) - else: - pursuit_pts = ( - self._calc_pursuit_points_preview() - ) - elif self._steering_method == "ts_pd": - pursuit_pts = ( - self._calc_ts_pd_line_preview() - ) - bgr = draw_overlay( - bgr, self._last_detect_result, + """タイマーから呼び出され,テレメトリを受信して表示する""" + received = self._telemetry.update( self._overlay_panel.get_flags(), - pursuit_points=pursuit_pts, ) + if not received: + return - # 検出情報をラベルに表示 - self._update_detect_info_label() + state = self._telemetry.state - # BGR → RGB 変換 - rgb = bgr[:, :, ::-1].copy() - h, w, ch = rgb.shape - image = QImage( - rgb.data, w, h, ch * w, - QImage.Format.Format_RGB888, - ) - disp_w = int(config.FRAME_WIDTH * DISPLAY_SCALE) - disp_h = int(config.FRAME_HEIGHT * DISPLAY_SCALE) - pixmap = QPixmap.fromImage(image).scaled( - disp_w, - disp_h, - Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation, - ) - self._video_label.setPixmap(pixmap) + # 自動時は Pi の操舵量を表示 + if self._is_auto: + self._update_control_label() + + # 十字路パネルの表示更新 + if self._intersection_panel.enabled: + self._intersection_panel.update_result( + state.is_intersection, + ) + else: + self._intersection_panel.clear_result() + + # 十字路有効時は分類器準備完了まで自動操縦を無効化 + if not self._is_auto: + can_auto = ( + not self._intersection_panel.enabled + or state.intersection_available + ) + self._auto_btn.setEnabled(can_auto) # ── 手動操作 ────────────────────────────────────────── @@ -624,79 +467,73 @@ if self._is_auto: return - self._pressed_keys.add(event.key()) - self._update_manual_control() + if self._manual.handle_key_press(event): + if self._manual.is_emergency_stop(): + if self._is_auto: + self._disable_auto() + self._update_control_label() def keyReleaseEvent(self, event: QKeyEvent) -> None: """キー離上時に操舵量を更新する""" - if event.isAutoRepeat() or self._is_auto: + if event.isAutoRepeat(): return - self._pressed_keys.discard(event.key()) - self._update_manual_control() - - def _update_manual_control(self) -> None: - """押下中のキーから throttle と steer を計算する""" - keys = self._pressed_keys - - # Space で緊急停止 - if Qt.Key.Key_Space in keys: - self._throttle = 0.0 - self._steer = 0.0 - self._pressed_keys.clear() - if self._is_auto: - self._disable_auto() + if self._is_auto: + return + if self._manual.handle_key_release(event): self._update_control_label() - return - - # throttle: W/↑ で前進,S/↓ で後退 - forward = ( - Qt.Key.Key_W in keys or Qt.Key.Key_Up in keys - ) - backward = ( - Qt.Key.Key_S in keys - or Qt.Key.Key_Down in keys - ) - if forward and not backward: - self._throttle = MANUAL_THROTTLE - elif backward and not forward: - self._throttle = -MANUAL_THROTTLE - else: - self._throttle = 0.0 - - # steer: A/← で左,D/→ で右 - left = ( - Qt.Key.Key_A in keys - or Qt.Key.Key_Left in keys - ) - right = ( - Qt.Key.Key_D in keys - or Qt.Key.Key_Right in keys - ) - if left and not right: - self._steer = -MANUAL_STEER - elif right and not left: - self._steer = MANUAL_STEER - else: - self._steer = 0.0 - - self._update_control_label() def _update_control_label(self) -> None: """操舵量の表示を更新する""" + if self._is_auto: + state = self._telemetry.state + throttle = state.throttle + steer = state.steer + else: + throttle = self._manual.throttle + steer = self._manual.steer + text = ( - f"throttle: {self._throttle:+.2f}\n" - f"steer: {self._steer:+.2f}" + f"throttle: {throttle:+.2f}\n" + f"steer: {steer:+.2f}" ) - if self._is_auto and self._recovery.is_recovering: - text += "\n[復帰中]" + if self._is_auto: + state = self._telemetry.state + if state.is_intersection: + text += "\n[十字路]" + if state.is_recovering: + text += "\n[復帰中]" self._control_label.setText(text) - def _send_control(self) -> None: - """操舵量を Pi に送信する""" + # ── コマンド送信 ────────────────────────────────────── + + def _send_command(self) -> None: + """コマンドを Pi に送信する""" if not self._is_connected: return - self._zmq_client.send_control( - self._throttle, self._steer, + + if self._is_auto: + throttle = 0.0 + steer = 0.0 + else: + throttle = self._manual.throttle + steer = self._manual.steer + + self._cmd_sender.send( + is_auto=self._is_auto, + throttle=throttle, + steer=steer, + intersection_enabled=( + self._intersection_panel.enabled + ), + intersection_throttle=( + self._intersection_panel.throttle + ), + steering_method=self._steering_method, + image_params=self._image_params, + pd_params=self._pd_params, + pursuit_params=self._pursuit_params, + ts_pd_params=self._ts_pd_params, + recovery_params=self._recovery_params, ) def closeEvent(self, event) -> None: diff --git a/src/pc/gui/manual_controller.py b/src/pc/gui/manual_controller.py new file mode 100644 index 0000000..cdeeed8 --- /dev/null +++ b/src/pc/gui/manual_controller.py @@ -0,0 +1,113 @@ +""" +manual_controller +キー入力から手動操作の throttle / steer を計算するモジュール +""" + +from PySide6.QtCore import Qt +from PySide6.QtGui import QKeyEvent + +# 手動操作の throttle / steer 量 +MANUAL_THROTTLE: float = 0.5 +MANUAL_STEER: float = 0.4 + + +class ManualController: + """キー入力を throttle / steer に変換する + + Qt のキーイベントを受け取り, + 押下中のキーセットから操舵量を計算する + """ + + def __init__(self) -> None: + self._pressed_keys: set[int] = set() + self.throttle: float = 0.0 + self.steer: float = 0.0 + + def reset(self) -> None: + """状態をリセットする""" + self._pressed_keys.clear() + self.throttle = 0.0 + self.steer = 0.0 + + def handle_key_press(self, event: QKeyEvent) -> bool: + """キー押下を処理する + + Args: + event: Qt のキーイベント + + Returns: + 操舵量が更新された場合 True + """ + if event.isAutoRepeat(): + return False + self._pressed_keys.add(event.key()) + self._update() + return True + + def handle_key_release(self, event: QKeyEvent) -> bool: + """キー離上を処理する + + Args: + event: Qt のキーイベント + + Returns: + 操舵量が更新された場合 True + """ + if event.isAutoRepeat(): + return False + self._pressed_keys.discard(event.key()) + self._update() + return True + + def is_emergency_stop(self) -> bool: + """Space キーによる緊急停止が発生したか判定する + + 緊急停止の場合は状態をリセットする + + Returns: + 緊急停止が発生した場合 True + """ + if Qt.Key.Key_Space in self._pressed_keys: + self.reset() + return True + return False + + def _update(self) -> None: + """押下中のキーから throttle と steer を計算する""" + keys = self._pressed_keys + + # Space で緊急停止 + if Qt.Key.Key_Space in keys: + self.reset() + return + + # throttle: W/↑ で前進,S/↓ で後退 + forward = ( + Qt.Key.Key_W in keys or Qt.Key.Key_Up in keys + ) + backward = ( + Qt.Key.Key_S in keys + or Qt.Key.Key_Down in keys + ) + if forward and not backward: + self.throttle = MANUAL_THROTTLE + elif backward and not forward: + self.throttle = -MANUAL_THROTTLE + else: + self.throttle = 0.0 + + # steer: A/← で左,D/→ で右 + left = ( + Qt.Key.Key_A in keys + or Qt.Key.Key_Left in keys + ) + right = ( + Qt.Key.Key_D in keys + or Qt.Key.Key_Right in keys + ) + if left and not right: + self.steer = -MANUAL_STEER + elif right and not left: + self.steer = MANUAL_STEER + else: + self.steer = 0.0 diff --git a/src/pc/gui/panels/__init__.py b/src/pc/gui/panels/__init__.py index 9814afe..f9c8691 100644 --- a/src/pc/gui/panels/__init__.py +++ b/src/pc/gui/panels/__init__.py @@ -8,6 +8,7 @@ ) from pc.gui.panels.control_param_panel import ControlParamPanel from pc.gui.panels.image_param_panel import ImageParamPanel +from pc.gui.panels.intersection_panel import IntersectionPanel from pc.gui.panels.overlay_panel import OverlayPanel from pc.gui.panels.recovery_panel import RecoveryPanel @@ -15,6 +16,7 @@ "CollapsibleGroupBox", "ControlParamPanel", "ImageParamPanel", + "IntersectionPanel", "OverlayPanel", "RecoveryPanel", ] diff --git a/src/pc/gui/panels/control_param_panel.py b/src/pc/gui/panels/control_param_panel.py index bddedfe..11572b5 100644 --- a/src/pc/gui/panels/control_param_panel.py +++ b/src/pc/gui/panels/control_param_panel.py @@ -32,9 +32,9 @@ load_pd_presets, load_ts_pd_presets, ) -from pc.steering.pd_control import PdParams -from pc.steering.pursuit_control import PursuitParams -from pc.steering.ts_pd_control import TsPdParams +from common.steering.pd_control import PdParams +from common.steering.pursuit_control import PursuitParams +from common.steering.ts_pd_control import TsPdParams class ControlParamPanel(CollapsibleGroupBox): diff --git a/src/pc/gui/panels/image_param_panel.py b/src/pc/gui/panels/image_param_panel.py index 570daa3..17f2e19 100644 --- a/src/pc/gui/panels/image_param_panel.py +++ b/src/pc/gui/panels/image_param_panel.py @@ -32,7 +32,7 @@ delete_image_preset, load_image_presets, ) -from pc.vision.line_detector import ( +from common.vision.line_detector import ( DETECT_METHODS, ImageParams, reset_valley_tracker, @@ -376,6 +376,10 @@ # 初期表示 self._on_method_changed() + def _on_expanded(self) -> None: + """展開後にパラメータの表示/非表示を再適用する""" + self._on_method_changed() + def _add_row( self, label: str, diff --git a/src/pc/gui/panels/intersection_panel.py b/src/pc/gui/panels/intersection_panel.py new file mode 100644 index 0000000..2363746 --- /dev/null +++ b/src/pc/gui/panels/intersection_panel.py @@ -0,0 +1,109 @@ +""" +intersection_panel +十字路判定の有効/無効を切り替える UI パネル +""" + +from PySide6.QtCore import Signal +from PySide6.QtWidgets import ( + QCheckBox, + QDoubleSpinBox, + QFormLayout, + QLabel, + QVBoxLayout, +) + +from pc.gui.panels.collapsible_group_box import ( + CollapsibleGroupBox, +) + + +class IntersectionPanel(CollapsibleGroupBox): + """十字路判定の切替 UI""" + + enabled_changed = Signal(bool) + + def __init__(self, available: bool = False) -> None: + super().__init__("十字路判定") + self._available = available + self._setup_ui() + + def _setup_ui(self) -> None: + """UI を構築する""" + layout = QVBoxLayout() + self.setLayout(layout) + + # 有効/無効チェックボックス + self._enabled_cb = QCheckBox("十字路判定を有効にする") + self._enabled_cb.setChecked(self._available) + self._enabled_cb.setEnabled(self._available) + self._enabled_cb.toggled.connect( + self._on_toggled, + ) + layout.addWidget(self._enabled_cb) + + # モデル状態の表示 + if self._available: + status_text = "モデル: 読み込み済み" + else: + status_text = "モデル: 未学習(params/ にありません)" + self._status_label = QLabel(status_text) + self._status_label.setStyleSheet( + "font-size: 11px; color: #888;" + ) + layout.addWidget(self._status_label) + + # 十字路時の速度 + form = QFormLayout() + layout.addLayout(form) + + self._spin_throttle = QDoubleSpinBox() + self._spin_throttle.setRange(0.0, 1.0) + self._spin_throttle.setSingleStep(0.05) + self._spin_throttle.setDecimals(2) + self._spin_throttle.setValue(0.4) + form.addRow("十字路速度:", self._spin_throttle) + + # 判定結果の表示 + self._result_label = QLabel("") + self._result_label.setStyleSheet( + "font-size: 13px; font-weight: bold;" + " padding: 2px;" + ) + layout.addWidget(self._result_label) + + @property + def enabled(self) -> bool: + """十字路判定が有効かどうか""" + return self._enabled_cb.isChecked() + + @property + def throttle(self) -> float: + """十字路時の速度""" + return self._spin_throttle.value() + + def update_result(self, is_intersection: bool) -> None: + """判定結果の表示を更新する + + Args: + is_intersection: 十字路と判定されたか + """ + if is_intersection: + self._result_label.setText("判定: 十字路(直進)") + self._result_label.setStyleSheet( + "font-size: 13px; font-weight: bold;" + " color: #f44; padding: 2px;" + ) + else: + self._result_label.setText("判定: 通常") + self._result_label.setStyleSheet( + "font-size: 13px; font-weight: bold;" + " color: #4f4; padding: 2px;" + ) + + def clear_result(self) -> None: + """判定結果の表示をクリアする""" + self._result_label.setText("") + + def _on_toggled(self, checked: bool) -> None: + """チェックボックスの切替をシグナルで通知する""" + self.enabled_changed.emit(checked) diff --git a/src/pc/gui/panels/recovery_panel.py b/src/pc/gui/panels/recovery_panel.py index a56bb25..bf07380 100644 --- a/src/pc/gui/panels/recovery_panel.py +++ b/src/pc/gui/panels/recovery_panel.py @@ -14,7 +14,7 @@ from pc.gui.panels.collapsible_group_box import ( CollapsibleGroupBox, ) -from pc.steering.recovery import RecoveryParams +from common.steering.recovery import RecoveryParams class RecoveryPanel(CollapsibleGroupBox): diff --git a/src/pc/gui/telemetry_display.py b/src/pc/gui/telemetry_display.py new file mode 100644 index 0000000..8cca816 --- /dev/null +++ b/src/pc/gui/telemetry_display.py @@ -0,0 +1,209 @@ +""" +telemetry_display +テレメトリ受信・状態抽出・映像表示・オーバーレイ描画を担当するモジュール +""" + +import time +from dataclasses import dataclass + +import cv2 +import numpy as np +from PySide6.QtCore import Qt +from PySide6.QtGui import QImage, QPixmap +from PySide6.QtWidgets import QLabel + +from common import config +from common.vision.line_detector import LineDetectResult +from pc.comm.zmq_client import PcZmqClient +from pc.vision.overlay import OverlayFlags, draw_overlay + +# 映像表示のスケール倍率 +DISPLAY_SCALE: float = config.DISPLAY_SCALE + + +@dataclass +class TelemetryState: + """テレメトリから受信した Pi 側の状態""" + + detected: bool = False + pos_error: float = 0.0 + heading: float = 0.0 + is_intersection: bool = False + is_recovering: bool = False + intersection_available: bool = False + compute_ms: float = 0.0 + fps: float = 0.0 + throttle: float = 0.0 + steer: float = 0.0 + + +class TelemetryDisplay: + """テレメトリを受信して映像・状態を表示する + + PcZmqClient からテレメトリを受信し, + 状態を抽出してオーバーレイ付き映像を QLabel に表示する + """ + + def __init__( + self, + zmq_client: PcZmqClient, + video_label: QLabel, + detect_info_label: QLabel, + perf_label: QLabel, + ) -> None: + self._zmq_client = zmq_client + self._video_label = video_label + self._detect_info_label = detect_info_label + self._perf_label = perf_label + + self.state = TelemetryState() + self._latest_binary: np.ndarray | None = None + + # 受信 FPS 計測 + self._recv_frame_count: int = 0 + self._recv_fps_start: float = time.time() + self._recv_fps: float = 0.0 + + def update( + self, + overlay_flags: OverlayFlags, + ) -> bool: + """テレメトリを受信して表示を更新する + + Args: + overlay_flags: オーバーレイ表示フラグ + + Returns: + テレメトリを受信できた場合 True + """ + result = self._zmq_client.receive_telemetry() + if result is None: + return False + + telemetry, frame, binary = result + + # テレメトリから状態を取得 + self.state.detected = telemetry.get( + "detected", False, + ) + self.state.pos_error = telemetry.get( + "pos_error", 0.0, + ) + self.state.heading = telemetry.get( + "heading", 0.0, + ) + self.state.is_intersection = telemetry.get( + "is_intersection", False, + ) + self.state.is_recovering = telemetry.get( + "is_recovering", False, + ) + self.state.intersection_available = telemetry.get( + "intersection_available", False, + ) + self.state.compute_ms = telemetry.get( + "compute_ms", 0.0, + ) + self.state.fps = telemetry.get("fps", 0.0) + self.state.throttle = telemetry.get( + "throttle", 0.0, + ) + self.state.steer = telemetry.get("steer", 0.0) + + self._latest_binary = binary + + # 受信 FPS 計測 + self._recv_frame_count += 1 + elapsed = time.time() - self._recv_fps_start + if elapsed >= 1.0: + self._recv_fps = ( + self._recv_frame_count / elapsed + ) + self._recv_frame_count = 0 + self._recv_fps_start = time.time() + + # 検出情報表示 + self._update_detect_info_label() + + # パフォーマンス表示 + perf_text = ( + f"recv FPS: {self._recv_fps:.1f}" + f" Pi FPS: {self.state.fps:.1f}" + ) + if self.state.compute_ms > 0.0: + perf_text += ( + f" 計算: {self.state.compute_ms:.1f}ms" + ) + self._perf_label.setText(perf_text) + + self._display_frame(frame, overlay_flags) + return True + + def _update_detect_info_label(self) -> None: + """検出情報ラベルを更新する""" + if not self.state.detected: + self._detect_info_label.setText( + "pos: --- head: ---" + ) + return + self._detect_info_label.setText( + f"pos: {self.state.pos_error:+.3f}" + f" head: {self.state.heading:+.4f}" + ) + + def _display_frame( + self, + frame: np.ndarray, + overlay_flags: OverlayFlags, + ) -> None: + """NumPy 配列の画像を QLabel に表示する + + Args: + frame: グレースケールの画像 + overlay_flags: オーバーレイ表示フラグ + """ + # グレースケール → BGR 変換 + bgr = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) + + # テレメトリから LineDetectResult を構築 + detect_result = None + if ( + self.state.detected + and self._latest_binary is not None + ): + detect_result = LineDetectResult( + detected=True, + position_error=self.state.pos_error, + heading=self.state.heading, + curvature=0.0, + poly_coeffs=None, + row_centers=None, + binary_image=self._latest_binary, + ) + + bgr = draw_overlay( + bgr, detect_result, + overlay_flags, + is_intersection=( + self.state.is_intersection + ), + ) + + # BGR → RGB 変換 + rgb = bgr[:, :, ::-1].copy() + h, w, ch = rgb.shape + image = QImage( + rgb.data, w, h, ch * w, + QImage.Format.Format_RGB888, + ) + disp_w = int(config.FRAME_WIDTH * DISPLAY_SCALE) + disp_h = int( + config.FRAME_HEIGHT * DISPLAY_SCALE + ) + pixmap = QPixmap.fromImage(image).scaled( + disp_w, + disp_h, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) + self._video_label.setPixmap(pixmap) diff --git a/src/pc/review.py b/src/pc/review.py new file mode 100644 index 0000000..b62364a --- /dev/null +++ b/src/pc/review.py @@ -0,0 +1,22 @@ +""" +review +データ仕分け GUI のエントリーポイント +""" + +import sys + +from PySide6.QtWidgets import QApplication + +from pc.data.reviewer import ReviewWindow + + +def main() -> None: + """データ仕分け GUI を起動する""" + app = QApplication(sys.argv) + window = ReviewWindow() + window.show() + sys.exit(app.exec()) + + +if __name__ == "__main__": + main() diff --git a/src/pc/steering/auto_params.py b/src/pc/steering/auto_params.py index 9b10c4c..e8ddb24 100644 --- a/src/pc/steering/auto_params.py +++ b/src/pc/steering/auto_params.py @@ -19,11 +19,11 @@ from dataclasses import asdict from common.json_utils import PARAMS_DIR, read_json, write_json -from pc.steering.pd_control import PdParams -from pc.steering.pursuit_control import PursuitParams -from pc.steering.recovery import RecoveryParams -from pc.steering.ts_pd_control import TsPdParams -from pc.vision.line_detector import ImageParams +from common.steering.pd_control import PdParams +from common.steering.pursuit_control import PursuitParams +from common.steering.recovery import RecoveryParams +from common.steering.ts_pd_control import TsPdParams +from common.vision.line_detector import ImageParams from pc.vision.overlay import OverlayFlags # PD 制御パラメータファイル diff --git a/src/pc/steering/base.py b/src/pc/steering/base.py deleted file mode 100644 index 40eff1a..0000000 --- a/src/pc/steering/base.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -base -操舵量計算の共通インターフェースを定義するモジュール -全ての操舵量計算クラスはこのインターフェースに従う -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass - -import numpy as np - - -@dataclass -class SteeringOutput: - """操舵量計算の出力を格納するデータクラス - - Attributes: - throttle: 前後方向の出力 (-1.0 ~ +1.0) - steer: 左右方向の出力 (-1.0 ~ +1.0) - """ - throttle: float - steer: float - - -class SteeringBase(ABC): - """操舵量計算の基底クラス - - 全ての操舵量計算クラスはこのクラスを継承し, - compute メソッドを実装する - """ - - @abstractmethod - def compute( - self, frame: np.ndarray, - ) -> SteeringOutput: - """カメラ画像から操舵量を計算する - - Args: - frame: BGR 形式のカメラ画像 - - Returns: - 計算された操舵量 - """ - - @abstractmethod - def reset(self) -> None: - """内部状態をリセットする - - 自動操縦の開始時に呼び出される - """ diff --git a/src/pc/steering/param_store.py b/src/pc/steering/param_store.py index 503bf7c..ea9cf04 100644 --- a/src/pc/steering/param_store.py +++ b/src/pc/steering/param_store.py @@ -7,9 +7,9 @@ from dataclasses import asdict, dataclass from common.json_utils import PARAMS_DIR, read_json, write_json -from pc.steering.pd_control import PdParams -from pc.steering.ts_pd_control import TsPdParams -from pc.vision.line_detector import ImageParams +from common.steering.pd_control import PdParams +from common.steering.ts_pd_control import TsPdParams +from common.vision.line_detector import ImageParams _PD_FILE = PARAMS_DIR / "presets_pd.json" _TS_PD_FILE = PARAMS_DIR / "presets_ts_pd.json" diff --git a/src/pc/steering/pd_control.py b/src/pc/steering/pd_control.py deleted file mode 100644 index 753c8e6..0000000 --- a/src/pc/steering/pd_control.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -pd_control -PD 制御による操舵量計算モジュール -多項式フィッティングの位置・傾き・曲率から操舵量と速度を算出する -""" - -import time -from dataclasses import dataclass - -import numpy as np - -from pc.steering.base import SteeringBase, SteeringOutput -from pc.vision.line_detector import ( - ImageParams, - detect_line, - reset_valley_tracker, -) - - -@dataclass -class PdParams: - """PD 制御のパラメータ - - Attributes: - kp: 位置偏差ゲイン - kh: 傾き(ヘディング)ゲイン - kd: 微分ゲイン - max_steer_rate: 1フレームあたりの最大操舵変化量 - max_throttle: 直線での最大速度 - speed_k: 曲率ベースの減速係数 - """ - kp: float = 0.5 - kh: float = 0.3 - kd: float = 0.1 - max_steer_rate: float = 0.1 - max_throttle: float = 0.4 - speed_k: float = 0.3 - - -class PdControl(SteeringBase): - """PD 制御による操舵量計算クラス""" - - def __init__( - self, - params: PdParams | None = None, - image_params: ImageParams | None = None, - ) -> None: - self.params: PdParams = params or PdParams() - self.image_params: ImageParams = ( - image_params or ImageParams() - ) - self._prev_error: float = 0.0 - self._prev_time: float = 0.0 - self._prev_steer: float = 0.0 - self._last_result = None - - def compute( - self, frame: np.ndarray, - ) -> SteeringOutput: - """カメラ画像から PD 制御で操舵量を計算する - - Args: - frame: グレースケールのカメラ画像 - - Returns: - 計算された操舵量 - """ - p = self.params - - # 線検出 - result = detect_line(frame, self.image_params) - self._last_result = result - - # 線が検出できなかった場合は停止 - if not result.detected: - return SteeringOutput(throttle=0.0, steer=0.0) - - # 位置偏差 + 傾きによる操舵量 - error = ( - p.kp * result.position_error - + p.kh * result.heading - ) - - # 時間差分の計算 - now = time.time() - dt = ( - now - self._prev_time - if self._prev_time > 0 - else 0.033 - ) - dt = max(dt, 0.001) - - # 微分項 - derivative = (error - self._prev_error) / dt - steer = error + p.kd * derivative - - # 操舵量のクランプ - steer = max(-1.0, min(1.0, steer)) - - # レートリミッター(急な操舵変化を制限) - delta = steer - self._prev_steer - max_delta = p.max_steer_rate - delta = max(-max_delta, min(max_delta, delta)) - steer = self._prev_steer + delta - - # 速度制御(曲率連動) - throttle = ( - p.max_throttle - - p.speed_k * abs(result.curvature) - ) - throttle = max(0.0, throttle) - - # 状態の更新 - self._prev_error = error - self._prev_time = now - self._prev_steer = steer - - return SteeringOutput( - throttle=throttle, steer=steer, - ) - - def reset(self) -> None: - """内部状態をリセットする""" - self._prev_error = 0.0 - self._prev_time = 0.0 - self._prev_steer = 0.0 - self._last_result = None - reset_valley_tracker() - - @property - def last_detect_result(self): - """直近の線検出結果を取得する""" - return self._last_result diff --git a/src/pc/steering/pursuit_control.py b/src/pc/steering/pursuit_control.py deleted file mode 100644 index db0b793..0000000 --- a/src/pc/steering/pursuit_control.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -pursuit_control -2点パシュートによる操舵量計算モジュール -行中心点に Theil-Sen 直線近似を適用し,外れ値に強い操舵量を算出する -""" - -from dataclasses import dataclass - -import numpy as np - -from common import config -from pc.steering.base import SteeringBase, SteeringOutput -from pc.vision.fitting import theil_sen_fit -from pc.vision.line_detector import ( - ImageParams, - detect_line, - reset_valley_tracker, -) - - -@dataclass -class PursuitParams: - """2点パシュート制御のパラメータ - - Attributes: - near_ratio: 近い目標点の位置(0.0=上端,1.0=下端) - far_ratio: 遠い目標点の位置(0.0=上端,1.0=下端) - k_near: 近い目標点の操舵ゲイン - k_far: 遠い目標点の操舵ゲイン - max_steer_rate: 1フレームあたりの最大操舵変化量 - max_throttle: 直線での最大速度 - speed_k: カーブ減速係数(2点の差に対する係数) - """ - near_ratio: float = 0.8 - far_ratio: float = 0.3 - k_near: float = 0.5 - k_far: float = 0.3 - max_steer_rate: float = 0.1 - max_throttle: float = 0.4 - speed_k: float = 2.0 - - -class PursuitControl(SteeringBase): - """2点パシュートによる操舵量計算クラス - - 行中心点から Theil-Sen 直線近似を行い, - 直線上の近い点と遠い点の偏差から操舵量を計算する - """ - - def __init__( - self, - params: PursuitParams | None = None, - image_params: ImageParams | None = None, - ) -> None: - self.params: PursuitParams = ( - params or PursuitParams() - ) - self.image_params: ImageParams = ( - image_params or ImageParams() - ) - self._prev_steer: float = 0.0 - self._last_result = None - self._last_pursuit_points: ( - tuple[tuple[float, float], tuple[float, float]] - | None - ) = None - - def compute( - self, frame: np.ndarray, - ) -> SteeringOutput: - """カメラ画像から2点パシュートで操舵量を計算する - - Args: - frame: グレースケールのカメラ画像 - - Returns: - 計算された操舵量 - """ - p = self.params - - # 線検出 - result = detect_line(frame, self.image_params) - self._last_result = result - - if not result.detected or result.row_centers is None: - self._last_pursuit_points = None - return SteeringOutput( - throttle=0.0, steer=0.0, - ) - - centers = result.row_centers - - # 有効な点(NaN でない行)を抽出 - valid = ~np.isnan(centers) - ys = np.where(valid)[0].astype(float) - xs = centers[valid] - - if len(ys) < 2: - self._last_pursuit_points = None - return SteeringOutput( - throttle=0.0, steer=0.0, - ) - - # Theil-Sen 直線近似 - slope, intercept = theil_sen_fit(ys, xs) - - center_x = config.FRAME_WIDTH / 2.0 - h = len(centers) - - # 直線上の 2 点の x 座標を取得 - near_y = h * p.near_ratio - far_y = h * p.far_ratio - near_x = slope * near_y + intercept - far_x = slope * far_y + intercept - - # 目標点を保持(デバッグ表示用) - self._last_pursuit_points = ( - (near_x, near_y), - (far_x, far_y), - ) - - # 各点の偏差(正: 線が左にある → 右に曲がる) - near_err = (center_x - near_x) / center_x - far_err = (center_x - far_x) / center_x - - # 操舵量 - steer = p.k_near * near_err + p.k_far * far_err - steer = max(-1.0, min(1.0, steer)) - - # レートリミッター - delta = steer - self._prev_steer - max_delta = p.max_steer_rate - delta = max(-max_delta, min(max_delta, delta)) - steer = self._prev_steer + delta - - # 速度制御(2点の x 差でカーブ度合いを判定) - curve = abs(near_x - far_x) / center_x - throttle = p.max_throttle - p.speed_k * curve - throttle = max(0.0, throttle) - - self._prev_steer = steer - - return SteeringOutput( - throttle=throttle, steer=steer, - ) - - def reset(self) -> None: - """内部状態をリセットする""" - self._prev_steer = 0.0 - self._last_result = None - self._last_pursuit_points = None - reset_valley_tracker() - - @property - def last_detect_result(self): - """直近の線検出結果を取得する""" - return self._last_result - - @property - def last_pursuit_points( - self, - ) -> ( - tuple[tuple[float, float], tuple[float, float]] - | None - ): - """直近の2点パシュート目標点を取得する - - Returns: - ((near_x, near_y), (far_x, far_y)) または None - """ - return self._last_pursuit_points diff --git a/src/pc/steering/recovery.py b/src/pc/steering/recovery.py deleted file mode 100644 index dfa88d9..0000000 --- a/src/pc/steering/recovery.py +++ /dev/null @@ -1,110 +0,0 @@ -""" -recovery -コースアウト復帰のパラメータと判定ロジックを定義するモジュール -黒線を一定時間検出できなかった場合に, -最後に検出した方向へ旋回しながら走行して復帰する -""" - -import time -from dataclasses import dataclass - -from pc.steering.base import SteeringOutput - - -@dataclass -class RecoveryParams: - """コースアウト復帰のパラメータ - - Attributes: - enabled: 復帰機能の有効/無効 - timeout_sec: 線を見失ってから復帰動作を開始するまでの時間 - steer_amount: 復帰時の操舵量(0.0 ~ 1.0) - throttle: 復帰時の速度(負: 後退,正: 前進) - """ - enabled: bool = True - timeout_sec: float = 0.5 - steer_amount: float = 0.5 - throttle: float = -0.3 - - -class RecoveryController: - """コースアウト復帰の判定と操舵量算出を行うクラス - - 自動操縦中にフレームごとに呼び出し, - 線検出の成否を記録する.一定時間検出できなかった場合に - 復帰用の操舵量を返す - """ - - def __init__( - self, - params: RecoveryParams | None = None, - ) -> None: - self.params: RecoveryParams = ( - params or RecoveryParams() - ) - self._last_detected_time: float = 0.0 - self._last_error_sign: float = 0.0 - self._is_recovering: bool = False - - def reset(self) -> None: - """内部状態をリセットする - - 自動操縦の開始時に呼び出す - """ - self._last_detected_time = time.time() - self._last_error_sign = 0.0 - self._is_recovering = False - - def update( - self, - detected: bool, - position_error: float = 0.0, - ) -> SteeringOutput | None: - """検出結果を記録し,復帰が必要なら操舵量を返す - - 毎フレーム呼び出す.線が検出できている間は内部状態を - 更新して None を返す.検出できない時間が timeout_sec を - 超えたら復帰用の SteeringOutput を返す - - Args: - detected: 線が検出できたか - position_error: 検出時の位置偏差(正: 線が左) - - Returns: - 復帰操舵量,または None(通常走行を継続) - """ - if not self.params.enabled: - return None - - now = time.time() - - if detected: - self._last_detected_time = now - if position_error != 0.0: - self._last_error_sign = ( - 1.0 if position_error > 0 else -1.0 - ) - self._is_recovering = False - return None - - # 線を見失ってからの経過時間を判定 - elapsed = now - self._last_detected_time - if elapsed < self.params.timeout_sec: - return None - - # 復帰モード: 最後に検出した方向へ旋回 - # position_error > 0(線が左)→ 左へ旋回(steer < 0) - self._is_recovering = True - steer = ( - -self._last_error_sign - * self.params.steer_amount - ) - return SteeringOutput( - throttle=self.params.throttle, - steer=steer, - ) - - @property - def is_recovering(self) -> bool: - """現在復帰動作中かどうかを返す""" - return self._is_recovering diff --git a/src/pc/steering/ts_pd_control.py b/src/pc/steering/ts_pd_control.py deleted file mode 100644 index 67f3bc5..0000000 --- a/src/pc/steering/ts_pd_control.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -ts_pd_control -Theil-Sen 直線近似による PD 制御モジュール -行中心点に Theil-Sen 直線をフィッティングし, -位置偏差・傾き・微分項から操舵量を算出する -""" - -import time -from dataclasses import dataclass - -import numpy as np - -from common import config -from pc.steering.base import SteeringBase, SteeringOutput -from pc.vision.fitting import theil_sen_fit -from pc.vision.line_detector import ( - ImageParams, - detect_line, - reset_valley_tracker, -) - - -@dataclass -class TsPdParams: - """Theil-Sen PD 制御のパラメータ - - Attributes: - kp: 位置偏差ゲイン - kh: 傾き(Theil-Sen slope)ゲイン - kd: 微分ゲイン - max_steer_rate: 1フレームあたりの最大操舵変化量 - max_throttle: 直線での最大速度 - speed_k: 傾きベースの減速係数 - """ - kp: float = 0.5 - kh: float = 0.3 - kd: float = 0.1 - max_steer_rate: float = 0.1 - max_throttle: float = 0.4 - speed_k: float = 2.0 - - -class TsPdControl(SteeringBase): - """Theil-Sen 直線近似による PD 制御クラス - - 行中心点から Theil-Sen 直線近似を行い, - 画像下端での位置偏差と直線の傾きから PD 制御で操舵量を計算する - """ - - def __init__( - self, - params: TsPdParams | None = None, - image_params: ImageParams | None = None, - ) -> None: - self.params: TsPdParams = ( - params or TsPdParams() - ) - self.image_params: ImageParams = ( - image_params or ImageParams() - ) - self._prev_error: float = 0.0 - self._prev_time: float = 0.0 - self._prev_steer: float = 0.0 - self._last_result = None - self._last_fit_line: ( - tuple[float, float] | None - ) = None - - def compute( - self, frame: np.ndarray, - ) -> SteeringOutput: - """カメラ画像から Theil-Sen PD 制御で操舵量を計算する - - Args: - frame: グレースケールのカメラ画像 - - Returns: - 計算された操舵量 - """ - p = self.params - - # 線検出 - result = detect_line(frame, self.image_params) - self._last_result = result - - if not result.detected or result.row_centers is None: - self._last_fit_line = None - return SteeringOutput( - throttle=0.0, steer=0.0, - ) - - centers = result.row_centers - - # 有効な点(NaN でない行)を抽出 - valid = ~np.isnan(centers) - ys = np.where(valid)[0].astype(float) - xs = centers[valid] - - if len(ys) < 2: - self._last_fit_line = None - return SteeringOutput( - throttle=0.0, steer=0.0, - ) - - # Theil-Sen 直線近似: x = slope * y + intercept - slope, intercept = theil_sen_fit(ys, xs) - self._last_fit_line = (slope, intercept) - - center_x = config.FRAME_WIDTH / 2.0 - h = len(centers) - - # 画像下端での位置偏差 - bottom_x = slope * (h - 1) + intercept - position_error = (center_x - bottom_x) / center_x - - # 操舵量: P 項(位置偏差)+ Heading 項(傾き) - # 符号反転: 偏差正(線が左)→ steer 負(左へ曲がる) - error = -(p.kp * position_error + p.kh * slope) - - # 時間差分の計算 - now = time.time() - dt = ( - now - self._prev_time - if self._prev_time > 0 - else 0.033 - ) - dt = max(dt, 0.001) - - # D 項(微分項) - derivative = (error - self._prev_error) / dt - steer = error + p.kd * derivative - - # 操舵量のクランプ - steer = max(-1.0, min(1.0, steer)) - - # レートリミッター - delta = steer - self._prev_steer - max_delta = p.max_steer_rate - delta = max(-max_delta, min(max_delta, delta)) - steer = self._prev_steer + delta - - # 速度制御(傾きベース: 傾きが大きい → カーブ → 減速) - throttle = p.max_throttle - p.speed_k * abs(slope) - throttle = max(0.0, throttle) - - # 状態の更新 - self._prev_error = error - self._prev_time = now - self._prev_steer = steer - - return SteeringOutput( - throttle=throttle, steer=steer, - ) - - def reset(self) -> None: - """内部状態をリセットする""" - self._prev_error = 0.0 - self._prev_time = 0.0 - self._prev_steer = 0.0 - self._last_result = None - self._last_fit_line = None - reset_valley_tracker() - - @property - def last_detect_result(self): - """直近の線検出結果を取得する""" - return self._last_result - - @property - def last_fit_line( - self, - ) -> tuple[float, float] | None: - """直近の Theil-Sen 直線近似結果を取得する - - Returns: - (slope, intercept) または None - """ - return self._last_fit_line diff --git a/src/pc/vision/detectors/__init__.py b/src/pc/vision/detectors/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/src/pc/vision/detectors/__init__.py +++ /dev/null diff --git a/src/pc/vision/detectors/blackhat.py b/src/pc/vision/detectors/blackhat.py deleted file mode 100644 index bc80920..0000000 --- a/src/pc/vision/detectors/blackhat.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -blackhat -案A: Black-hat 中心型の線検出 -Black-hat 変換で背景より暗い構造を直接抽出し, -固定閾値 + 距離変換 + 行ごと中心抽出で検出する -""" - -import cv2 -import numpy as np - -from pc.vision.line_detector import ImageParams, LineDetectResult -from pc.vision.line_detector import fit_row_centers -from pc.vision.morphology import ( - apply_dist_mask, - apply_iso_closing, - apply_width_filter, -) - - -def detect_blackhat( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """案A: Black-hat 中心型""" - # Black-hat 変換(暗い構造の抽出) - bh_k = params.blackhat_ksize | 1 - bh_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (bh_k, bh_k), - ) - blackhat = cv2.morphologyEx( - frame, cv2.MORPH_BLACKHAT, bh_kernel, - ) - - # ガウシアンブラー - blur_k = params.blur_size | 1 - blurred = cv2.GaussianBlur( - blackhat, (blur_k, blur_k), 0, - ) - - # 固定閾値(Black-hat 後は線が白) - _, binary = cv2.threshold( - blurred, params.binary_thresh, 255, - cv2.THRESH_BINARY, - ) - - # 等方クロージング + 距離変換マスク + 幅フィルタ - binary = apply_iso_closing( - binary, params.iso_close_size, - ) - binary = apply_dist_mask( - binary, params.dist_thresh, - ) - if params.width_near > 0 and params.width_far > 0: - binary = apply_width_filter( - binary, - params.width_near, - params.width_far, - params.width_tolerance, - ) - - # 行ごと中心抽出 + フィッティング - return fit_row_centers( - binary, params.min_line_width, - median_ksize=params.median_ksize, - neighbor_thresh=params.neighbor_thresh, - residual_thresh=params.residual_thresh, - ) diff --git a/src/pc/vision/detectors/current.py b/src/pc/vision/detectors/current.py deleted file mode 100644 index 12dacba..0000000 --- a/src/pc/vision/detectors/current.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -current -現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング -""" - -import cv2 -import numpy as np - -from pc.vision.line_detector import ( - DETECT_Y_END, - DETECT_Y_START, - MIN_FIT_PIXELS, - ImageParams, - LineDetectResult, - build_result, - no_detection, -) - - -def detect_current( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング""" - # CLAHE でコントラスト強調 - clahe = cv2.createCLAHE( - clipLimit=params.clahe_clip, - tileGridSize=( - params.clahe_grid, - params.clahe_grid, - ), - ) - enhanced = clahe.apply(frame) - - # ガウシアンブラー - blur_k = params.blur_size | 1 - blurred = cv2.GaussianBlur( - enhanced, (blur_k, blur_k), 0, - ) - - # 固定閾値で二値化(黒線を白に反転) - _, binary = cv2.threshold( - blurred, params.binary_thresh, 255, - cv2.THRESH_BINARY_INV, - ) - - # オープニング(孤立ノイズ除去) - if params.open_size >= 3: - open_k = params.open_size | 1 - open_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (open_k, open_k), - ) - binary = cv2.morphologyEx( - binary, cv2.MORPH_OPEN, open_kernel, - ) - - # 横方向クロージング(途切れ補間) - if params.close_width >= 3: - close_h = max(params.close_height | 1, 1) - close_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, - (params.close_width, close_h), - ) - binary = cv2.morphologyEx( - binary, cv2.MORPH_CLOSE, close_kernel, - ) - - # 全ピクセルフィッティング - region = binary[DETECT_Y_START:DETECT_Y_END, :] - ys_local, xs = np.where(region > 0) - - if len(xs) < MIN_FIT_PIXELS: - return no_detection(binary) - - ys = ys_local + DETECT_Y_START - coeffs = np.polyfit(ys, xs, 2) - return build_result(coeffs, binary) diff --git a/src/pc/vision/detectors/dual_norm.py b/src/pc/vision/detectors/dual_norm.py deleted file mode 100644 index 46a9a6d..0000000 --- a/src/pc/vision/detectors/dual_norm.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -dual_norm -案B: 二重正規化型の線検出 -背景除算で照明勾配を除去し, -適応的閾値で局所ムラにも対応する二重防壁構成 -""" - -import cv2 -import numpy as np - -from pc.vision.line_detector import ImageParams, LineDetectResult -from pc.vision.line_detector import fit_row_centers -from pc.vision.morphology import ( - apply_dist_mask, - apply_iso_closing, - apply_staged_closing, - apply_width_filter, -) - - -def detect_dual_norm( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """案B: 二重正規化型""" - # 背景除算正規化 - bg_k = params.bg_blur_ksize | 1 - bg = cv2.GaussianBlur( - frame, (bg_k, bg_k), 0, - ) - normalized = ( - frame.astype(np.float32) * 255.0 - / (bg.astype(np.float32) + 1.0) - ) - normalized = np.clip( - normalized, 0, 255, - ).astype(np.uint8) - - # 適応的閾値(ガウシアン,BINARY_INV) - block = max(params.adaptive_block | 1, 3) - binary = cv2.adaptiveThreshold( - normalized, 255, - cv2.ADAPTIVE_THRESH_GAUSSIAN_C, - cv2.THRESH_BINARY_INV, - block, params.adaptive_c, - ) - - # 固定閾値との AND(有効時のみ) - if params.global_thresh > 0: - _, global_mask = cv2.threshold( - normalized, params.global_thresh, - 255, cv2.THRESH_BINARY_INV, - ) - binary = cv2.bitwise_and(binary, global_mask) - - # 段階クロージング or 等方クロージング - if params.stage_min_area > 0: - binary = apply_staged_closing( - binary, - params.stage_close_small, - params.stage_min_area, - params.stage_close_large, - ) - else: - binary = apply_iso_closing( - binary, params.iso_close_size, - ) - - # 距離変換マスク + 幅フィルタ - binary = apply_dist_mask( - binary, params.dist_thresh, - ) - if params.width_near > 0 and params.width_far > 0: - binary = apply_width_filter( - binary, - params.width_near, - params.width_far, - params.width_tolerance, - ) - - # 行ごと中心抽出 + フィッティング - return fit_row_centers( - binary, params.min_line_width, - median_ksize=params.median_ksize, - neighbor_thresh=params.neighbor_thresh, - residual_thresh=params.residual_thresh, - ) diff --git a/src/pc/vision/detectors/robust.py b/src/pc/vision/detectors/robust.py deleted file mode 100644 index ac55d35..0000000 --- a/src/pc/vision/detectors/robust.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -robust -案C: 最高ロバスト型の線検出 -Black-hat + 適応的閾値の二重正規化に加え, -RANSAC で外れ値を除去する最もロバストな構成 -""" - -import cv2 -import numpy as np - -from pc.vision.line_detector import ImageParams, LineDetectResult -from pc.vision.line_detector import fit_row_centers -from pc.vision.morphology import ( - apply_dist_mask, - apply_iso_closing, - apply_width_filter, -) - - -def detect_robust( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """案C: 最高ロバスト型""" - # Black-hat 変換 - bh_k = params.blackhat_ksize | 1 - bh_kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (bh_k, bh_k), - ) - blackhat = cv2.morphologyEx( - frame, cv2.MORPH_BLACKHAT, bh_kernel, - ) - - # 適応的閾値(BINARY: Black-hat 後は線が白) - block = max(params.adaptive_block | 1, 3) - binary = cv2.adaptiveThreshold( - blackhat, 255, - cv2.ADAPTIVE_THRESH_GAUSSIAN_C, - cv2.THRESH_BINARY, - block, -params.adaptive_c, - ) - - # 等方クロージング + 距離変換マスク + 幅フィルタ - binary = apply_iso_closing( - binary, params.iso_close_size, - ) - binary = apply_dist_mask( - binary, params.dist_thresh, - ) - if params.width_near > 0 and params.width_far > 0: - binary = apply_width_filter( - binary, - params.width_near, - params.width_far, - params.width_tolerance, - ) - - # 行ごと中央値抽出 + RANSAC フィッティング - return fit_row_centers( - binary, params.min_line_width, - use_median=True, - ransac_thresh=params.ransac_thresh, - ransac_iter=params.ransac_iter, - median_ksize=params.median_ksize, - neighbor_thresh=params.neighbor_thresh, - residual_thresh=params.residual_thresh, - ) diff --git a/src/pc/vision/detectors/valley.py b/src/pc/vision/detectors/valley.py deleted file mode 100644 index ce14b52..0000000 --- a/src/pc/vision/detectors/valley.py +++ /dev/null @@ -1,328 +0,0 @@ -""" -valley -案D: 谷検出+追跡型の線検出 -各行の輝度信号から谷(暗い領域)を直接検出し, -時系列追跡で安定性を確保する.二値化を使用しない -""" - -import cv2 -import numpy as np - -from common import config -from pc.vision.fitting import clean_and_fit -from pc.vision.line_detector import ( - DETECT_Y_END, - DETECT_Y_START, - MIN_FIT_ROWS, - ImageParams, - LineDetectResult, - build_result, - no_detection, -) - - -class ValleyTracker: - """谷検出の時系列追跡を管理するクラス - - 前フレームの多項式係数を保持し,予測・平滑化・ - 検出失敗時のコースティングを提供する - """ - - def __init__(self) -> None: - self._prev_coeffs: np.ndarray | None = None - self._smoothed_coeffs: np.ndarray | None = None - self._frames_lost: int = 0 - - def predict_x(self, y: float) -> float | None: - """前フレームの多項式から x 座標を予測する - - Args: - y: 画像の y 座標 - - Returns: - 予測 x 座標(履歴なしの場合は None) - """ - if self._smoothed_coeffs is None: - return None - return float(np.poly1d(self._smoothed_coeffs)(y)) - - def update( - self, - coeffs: np.ndarray, - alpha: float, - ) -> np.ndarray: - """検出成功時に状態を更新する - - EMA で多項式係数を平滑化し,更新後の係数を返す - - Args: - coeffs: 今フレームのフィッティング係数 - alpha: EMA 係数(1.0 で平滑化なし) - - Returns: - 平滑化後の多項式係数 - """ - self._frames_lost = 0 - if self._smoothed_coeffs is None: - self._smoothed_coeffs = coeffs.copy() - else: - self._smoothed_coeffs = ( - alpha * coeffs - + (1.0 - alpha) * self._smoothed_coeffs - ) - self._prev_coeffs = self._smoothed_coeffs.copy() - return self._smoothed_coeffs - - def coast( - self, max_frames: int, - ) -> LineDetectResult | None: - """検出失敗時に予測結果を返す - - Args: - max_frames: 予測を継続する最大フレーム数 - - Returns: - 予測による結果(継続不可の場合は None) - """ - if self._smoothed_coeffs is None: - return None - self._frames_lost += 1 - if self._frames_lost > max_frames: - return None - # 予測でデバッグ用二値画像は空にする - h = config.FRAME_HEIGHT - w = config.FRAME_WIDTH - blank = np.zeros((h, w), dtype=np.uint8) - return build_result(self._smoothed_coeffs, blank) - - def reset(self) -> None: - """追跡状態をリセットする""" - self._prev_coeffs = None - self._smoothed_coeffs = None - self._frames_lost = 0 - - -_valley_tracker = ValleyTracker() - - -def reset_valley_tracker() -> None: - """谷検出の追跡状態をリセットする""" - _valley_tracker.reset() - - -def _find_row_valley( - row: np.ndarray, - min_depth: int, - expected_width: float, - width_tolerance: float, - predicted_x: float | None, - max_deviation: int, -) -> tuple[float, float] | None: - """1行の輝度信号から最適な谷を検出する - - Args: - row: スムージング済みの1行輝度信号 - min_depth: 最小谷深度 - expected_width: 期待線幅(px,0 で幅フィルタ無効) - width_tolerance: 幅フィルタの上限倍率 - predicted_x: 追跡による予測 x 座標(None で無効) - max_deviation: 予測からの最大許容偏差 - - Returns: - (谷の中心x, 谷の深度) または None - """ - n = len(row) - if n < 5: - return None - - signal = row.astype(np.float32) - - # 極小値を検出(前後より小さい点) - left = signal[:-2] - center = signal[1:-1] - right = signal[2:] - minima_mask = (center <= left) & (center <= right) - minima_indices = np.where(minima_mask)[0] + 1 - - if len(minima_indices) == 0: - return None - - best: tuple[float, float] | None = None - best_score = -1.0 - - for idx in minima_indices: - val = signal[idx] - - # 左の肩を探す - left_shoulder = idx - for i in range(idx - 1, -1, -1): - if signal[i] < signal[i + 1]: - break - left_shoulder = i - # 右の肩を探す - right_shoulder = idx - for i in range(idx + 1, n): - if signal[i] < signal[i - 1]: - break - right_shoulder = i - - # 谷の深度(肩の平均 - 谷底) - shoulder_avg = ( - signal[left_shoulder] + signal[right_shoulder] - ) / 2.0 - depth = shoulder_avg - val - if depth < min_depth: - continue - - # 谷の幅 - width = right_shoulder - left_shoulder - center_x = (left_shoulder + right_shoulder) / 2.0 - - # 幅フィルタ - if expected_width > 0: - max_w = expected_width * width_tolerance - min_w = expected_width / width_tolerance - if width > max_w or width < min_w: - continue - - # 予測との偏差チェック - if predicted_x is not None: - if abs(center_x - predicted_x) > max_deviation: - continue - - # スコア: 深度優先,予測がある場合は近さも考慮 - score = float(depth) - if predicted_x is not None: - dist = abs(center_x - predicted_x) - score += max(0.0, max_deviation - dist) - - if score > best_score: - best_score = score - best = (center_x, float(depth)) - - return best - - -def _build_valley_binary( - shape: tuple[int, int], - centers_y: list[int], - centers_x: list[float], -) -> np.ndarray: - """谷検出結果からデバッグ用二値画像を生成する - - Args: - shape: 出力画像の (高さ, 幅) - centers_y: 検出行の y 座標リスト - centers_x: 検出行の中心 x 座標リスト - - Returns: - デバッグ用二値画像 - """ - binary = np.zeros(shape, dtype=np.uint8) - half_w = 3 - w = shape[1] - for y, cx in zip(centers_y, centers_x): - x0 = max(0, int(cx) - half_w) - x1 = min(w, int(cx) + half_w + 1) - binary[y, x0:x1] = 255 - return binary - - -def detect_valley( - frame: np.ndarray, params: ImageParams, -) -> LineDetectResult: - """案D: 谷検出+追跡型""" - h, w = frame.shape[:2] - - # 行ごとにガウシアン平滑化するため画像全体をブラー - gauss_k = params.valley_gauss_ksize | 1 - blurred = cv2.GaussianBlur( - frame, (gauss_k, 1), 0, - ) - - # 透視補正の期待幅を計算するための準備 - use_width = ( - params.width_near > 0 and params.width_far > 0 - ) - detect_h = DETECT_Y_END - DETECT_Y_START - denom = max(detect_h - 1, 1) - - centers_y: list[int] = [] - centers_x: list[float] = [] - depths: list[float] = [] - - for y in range(DETECT_Y_START, DETECT_Y_END): - row = blurred[y] - - # 期待幅の計算 - if use_width: - t = (DETECT_Y_END - 1 - y) / denom - expected_w = float(params.width_far) + ( - float(params.width_near) - - float(params.width_far) - ) * t - else: - expected_w = 0.0 - - # 予測 x 座標 - predicted_x = _valley_tracker.predict_x( - float(y), - ) - - result = _find_row_valley( - row, - params.valley_min_depth, - expected_w, - params.width_tolerance, - predicted_x, - params.valley_max_deviation, - ) - if result is not None: - centers_y.append(y) - centers_x.append(result[0]) - depths.append(result[1]) - - # デバッグ用二値画像 - debug_binary = _build_valley_binary( - (h, w), centers_y, centers_x, - ) - - if len(centers_y) < MIN_FIT_ROWS: - # 検出失敗 → コースティングを試みる - coasted = _valley_tracker.coast( - params.valley_coast_frames, - ) - if coasted is not None: - coasted.binary_image = debug_binary - return coasted - return no_detection(debug_binary) - - cy = np.array(centers_y, dtype=np.float64) - cx = np.array(centers_x, dtype=np.float64) - w_arr = np.array(depths, dtype=np.float64) - - # ロバストフィッティング(深度を重みに使用) - coeffs = clean_and_fit( - cy, cx, - median_ksize=params.median_ksize, - neighbor_thresh=params.neighbor_thresh, - residual_thresh=params.residual_thresh, - weights=w_arr, - ransac_thresh=params.ransac_thresh, - ransac_iter=params.ransac_iter, - ) - if coeffs is None: - coasted = _valley_tracker.coast( - params.valley_coast_frames, - ) - if coasted is not None: - coasted.binary_image = debug_binary - return coasted - return no_detection(debug_binary) - - # EMA で平滑化 - smoothed = _valley_tracker.update( - coeffs, params.valley_ema_alpha, - ) - - return build_result(smoothed, debug_binary) diff --git a/src/pc/vision/fitting.py b/src/pc/vision/fitting.py deleted file mode 100644 index 34813d5..0000000 --- a/src/pc/vision/fitting.py +++ /dev/null @@ -1,209 +0,0 @@ -""" -fitting -直線・曲線近似の共通ユーティリティモジュール -Theil-Sen 推定,RANSAC,外れ値除去付きフィッティングを提供する -""" - -import numpy as np - -# フィッティングに必要な最小行数 -MIN_FIT_ROWS: int = 10 - -# 近傍外れ値除去の設定 -NEIGHBOR_HALF_WINDOW: int = 3 -NEIGHBOR_FILTER_PASSES: int = 3 - -# 残差ベース反復除去の最大回数 -RESIDUAL_REMOVAL_ITERATIONS: int = 5 - - -def theil_sen_fit( - y: np.ndarray, - x: np.ndarray, -) -> tuple[float, float]: - """Theil-Sen 推定で直線 x = slope * y + intercept を求める - - 全ペアの傾きの中央値を使い,外れ値に強い直線近似を行う - - Args: - y: y 座標の配列(行番号) - x: x 座標の配列(各行の中心) - - Returns: - (slope, intercept) のタプル - """ - n = len(y) - slopes = [] - for i in range(n): - for j in range(i + 1, n): - dy = y[j] - y[i] - if dy != 0: - slopes.append((x[j] - x[i]) / dy) - - if len(slopes) == 0: - return 0.0, float(np.median(x)) - - slope = float(np.median(slopes)) - intercept = float(np.median(x - slope * y)) - return slope, intercept - - -def ransac_polyfit( - ys: np.ndarray, xs: np.ndarray, - degree: int, n_iter: int, thresh: float, -) -> np.ndarray | None: - """RANSAC で外れ値を除去して多項式フィッティング - - Args: - ys: y 座標配列 - xs: x 座標配列 - degree: 多項式の次数 - n_iter: 反復回数 - thresh: 外れ値判定閾値(ピクセル) - - Returns: - 多項式係数(フィッティング失敗時は None) - """ - n = len(ys) - sample_size = degree + 1 - if n < sample_size: - return None - - best_coeffs: np.ndarray | None = None - best_inliers = 0 - rng = np.random.default_rng() - - for _ in range(n_iter): - idx = rng.choice(n, sample_size, replace=False) - coeffs = np.polyfit(ys[idx], xs[idx], degree) - poly = np.poly1d(coeffs) - residuals = np.abs(xs - poly(ys)) - n_inliers = int(np.sum(residuals < thresh)) - if n_inliers > best_inliers: - best_inliers = n_inliers - best_coeffs = coeffs - - # インライアで再フィッティング - if best_coeffs is not None: - poly = np.poly1d(best_coeffs) - inlier_mask = np.abs(xs - poly(ys)) < thresh - if np.sum(inlier_mask) >= sample_size: - best_coeffs = np.polyfit( - ys[inlier_mask], - xs[inlier_mask], - degree, - ) - - return best_coeffs - - -def clean_and_fit( - cy: np.ndarray, - cx: np.ndarray, - median_ksize: int, - neighbor_thresh: float, - residual_thresh: float = 0.0, - weights: np.ndarray | None = None, - ransac_thresh: float = 0.0, - ransac_iter: int = 0, -) -> np.ndarray | None: - """外れ値除去+重み付きフィッティングを行う - - 全検出手法で共通に使えるロバストなフィッティング - (1) 移動メディアンフィルタでスパイクを平滑化 - (2) 近傍中央値からの偏差で外れ値を除去(複数パス) - (3) 重み付き最小二乗(または RANSAC)でフィッティング - (4) 残差ベースの反復除去で外れ値を最終除去 - - Args: - cy: 中心点の y 座標配列 - cx: 中心点の x 座標配列 - median_ksize: 移動メディアンのカーネルサイズ(0 で無効) - neighbor_thresh: 近傍外れ値除去の閾値 px(0 で無効) - residual_thresh: 残差除去の閾値 px(0 で無効) - weights: 各点の信頼度(None で均等) - ransac_thresh: RANSAC 閾値(0 以下で無効) - ransac_iter: RANSAC 反復回数 - - Returns: - 多項式係数(フィッティング失敗時は None) - """ - if len(cy) < MIN_FIT_ROWS: - return None - - cx_clean = cx.copy() - mask = np.ones(len(cy), dtype=bool) - - # (1) 移動メディアンフィルタ - if median_ksize >= 3: - k = median_ksize | 1 - half = k // 2 - for i in range(len(cx_clean)): - lo = max(0, i - half) - hi = min(len(cx_clean), i + half + 1) - cx_clean[i] = float(np.median(cx[lo:hi])) - - # (2) 近傍外れ値除去(複数パス) - if neighbor_thresh > 0: - half_n = NEIGHBOR_HALF_WINDOW - for _ in range(NEIGHBOR_FILTER_PASSES): - new_mask = np.ones(len(cx_clean), dtype=bool) - for i in range(len(cx_clean)): - if not mask[i]: - continue - lo = max(0, i - half_n) - hi = min(len(cx_clean), i + half_n + 1) - neighbors = cx_clean[lo:hi][mask[lo:hi]] - if len(neighbors) == 0: - new_mask[i] = False - continue - local_med = float(np.median(neighbors)) - if abs(cx_clean[i] - local_med) > neighbor_thresh: - new_mask[i] = False - if np.array_equal(mask, mask & new_mask): - break - mask = mask & new_mask - - cy = cy[mask] - cx_clean = cx_clean[mask] - if weights is not None: - weights = weights[mask] - - if len(cy) < MIN_FIT_ROWS: - return None - - # (3) フィッティング - if ransac_thresh > 0 and ransac_iter > 0: - coeffs = ransac_polyfit( - cy, cx_clean, 2, ransac_iter, ransac_thresh, - ) - elif weights is not None: - coeffs = np.polyfit(cy, cx_clean, 2, w=weights) - else: - coeffs = np.polyfit(cy, cx_clean, 2) - - if coeffs is None: - return None - - # (4) 残差ベースの反復除去 - if residual_thresh > 0: - for _ in range(RESIDUAL_REMOVAL_ITERATIONS): - poly = np.poly1d(coeffs) - residuals = np.abs(cx_clean - poly(cy)) - inlier = residuals < residual_thresh - if np.all(inlier): - break - if np.sum(inlier) < MIN_FIT_ROWS: - break - cy = cy[inlier] - cx_clean = cx_clean[inlier] - if weights is not None: - weights = weights[inlier] - if weights is not None: - coeffs = np.polyfit( - cy, cx_clean, 2, w=weights, - ) - else: - coeffs = np.polyfit(cy, cx_clean, 2) - - return coeffs diff --git a/src/pc/vision/line_detector.py b/src/pc/vision/line_detector.py deleted file mode 100644 index 7155d16..0000000 --- a/src/pc/vision/line_detector.py +++ /dev/null @@ -1,368 +0,0 @@ -""" -line_detector -カメラ画像から黒線の位置を検出するモジュール -複数の検出手法を切り替えて使用できる - -公開 API: - ImageParams, LineDetectResult, detect_line, - reset_valley_tracker, DETECT_METHODS -""" - -from dataclasses import dataclass - -import cv2 -import numpy as np - -from common import config -from pc.vision.fitting import clean_and_fit - -# 検出領域の y 範囲(画像全体) -DETECT_Y_START: int = 0 -DETECT_Y_END: int = config.FRAME_HEIGHT - -# フィッティングに必要な最小数 -MIN_FIT_PIXELS: int = 50 -MIN_FIT_ROWS: int = 10 - -# 検出手法の定義(キー: 識別子,値: 表示名) -DETECT_METHODS: dict[str, str] = { - "current": "現行(CLAHE + 固定閾値)", - "blackhat": "案A(Black-hat 中心)", - "dual_norm": "案B(二重正規化)", - "robust": "案C(最高ロバスト)", - "valley": "案D(谷検出+追跡)", -} - - -@dataclass -class ImageParams: - """二値化パラメータ - - Attributes: - method: 検出手法の識別子 - clahe_clip: CLAHE のコントラスト増幅上限 - clahe_grid: CLAHE の局所領域分割数 - blur_size: ガウシアンブラーのカーネルサイズ(奇数) - binary_thresh: 二値化の閾値 - open_size: オープニングのカーネルサイズ - close_width: クロージングの横幅 - close_height: クロージングの高さ - blackhat_ksize: Black-hat のカーネルサイズ - bg_blur_ksize: 背景除算のブラーカーネルサイズ - global_thresh: 固定閾値(0 で無効,適応的閾値との AND) - adaptive_block: 適応的閾値のブロックサイズ - adaptive_c: 適応的閾値の定数 C - iso_close_size: 等方クロージングのカーネルサイズ - dist_thresh: 距離変換の閾値 - min_line_width: 行ごと中心抽出の最小線幅 - stage_close_small: 段階クロージング第1段のサイズ - stage_min_area: 孤立除去の最小面積(0 で無効) - stage_close_large: 段階クロージング第2段のサイズ(0 で無効) - ransac_thresh: RANSAC の外れ値判定閾値 - ransac_iter: RANSAC の反復回数 - width_near: 画像下端での期待線幅(px,0 で無効) - width_far: 画像上端での期待線幅(px,0 で無効) - width_tolerance: 幅フィルタの上限倍率 - median_ksize: 中心点列の移動メディアンフィルタサイズ(0 で無効) - neighbor_thresh: 近傍外れ値除去の閾値(px,0 で無効) - residual_thresh: 残差反復除去の閾値(px,0 で無効) - valley_gauss_ksize: 谷検出の行ごとガウシアンカーネルサイズ - valley_min_depth: 谷として認識する最小深度 - valley_max_deviation: 追跡予測からの最大許容偏差(px) - valley_coast_frames: 検出失敗時の予測継続フレーム数 - valley_ema_alpha: 多項式係数の指数移動平均係数 - """ - - # 検出手法 - method: str = "current" - - # 現行手法パラメータ - clahe_clip: float = 2.0 - clahe_grid: int = 8 - blur_size: int = 5 - binary_thresh: int = 80 - open_size: int = 5 - close_width: int = 25 - close_height: int = 3 - - # 案A/C: Black-hat - blackhat_ksize: int = 45 - - # 案B: 背景除算 - bg_blur_ksize: int = 101 - global_thresh: int = 0 # 固定閾値(0 で無効) - - # 案B/C: 適応的閾値 - adaptive_block: int = 51 - adaptive_c: int = 10 - - # 案A/B/C: 後処理 - iso_close_size: int = 15 - dist_thresh: float = 3.0 - min_line_width: int = 3 - - # 案B: 段階クロージング - stage_close_small: int = 5 # 第1段: 小クロージングサイズ - stage_min_area: int = 0 # 孤立除去の最小面積(0 で無効) - stage_close_large: int = 0 # 第2段: 大クロージングサイズ(0 で無効) - - # 案C: RANSAC - ransac_thresh: float = 5.0 - ransac_iter: int = 50 - - # ロバストフィッティング(全手法共通) - median_ksize: int = 7 - neighbor_thresh: float = 10.0 - residual_thresh: float = 8.0 - - # 透視補正付き幅フィルタ(0 で無効) - width_near: int = 0 - width_far: int = 0 - width_tolerance: float = 1.8 - - # 案D: 谷検出+追跡 - valley_gauss_ksize: int = 15 - valley_min_depth: int = 15 - valley_max_deviation: int = 40 - valley_coast_frames: int = 3 - valley_ema_alpha: float = 0.7 - - -@dataclass -class LineDetectResult: - """線検出の結果を格納するデータクラス - - Attributes: - detected: 線が検出できたか - position_error: 画像下端での位置偏差(-1.0~+1.0) - heading: 線の傾き(dx/dy,画像下端での値) - curvature: 線の曲率(d²x/dy²) - poly_coeffs: 多項式の係数(描画用,未検出時は None) - row_centers: 各行の線中心 x 座標(index=行番号, - NaN=その行に線なし,未検出時は None) - binary_image: 二値化後の画像(デバッグ用) - """ - - detected: bool - position_error: float - heading: float - curvature: float - poly_coeffs: np.ndarray | None - row_centers: np.ndarray | None - binary_image: np.ndarray | None - - -# ── 公開 API ────────────────────────────────────── - - -def detect_line( - frame: np.ndarray, - params: ImageParams | None = None, -) -> LineDetectResult: - """画像から黒線の位置を検出する - - params.method に応じて検出手法を切り替える - - Args: - frame: グレースケールのカメラ画像 - params: 二値化パラメータ(None でデフォルト) - - Returns: - 線検出の結果 - """ - if params is None: - params = ImageParams() - - method = params.method - if method == "blackhat": - from pc.vision.detectors.blackhat import ( - detect_blackhat, - ) - return detect_blackhat(frame, params) - if method == "dual_norm": - from pc.vision.detectors.dual_norm import ( - detect_dual_norm, - ) - return detect_dual_norm(frame, params) - if method == "robust": - from pc.vision.detectors.robust import ( - detect_robust, - ) - return detect_robust(frame, params) - if method == "valley": - from pc.vision.detectors.valley import ( - detect_valley, - ) - return detect_valley(frame, params) - - from pc.vision.detectors.current import ( - detect_current, - ) - return detect_current(frame, params) - - -def reset_valley_tracker() -> None: - """谷検出の追跡状態をリセットする""" - from pc.vision.detectors.valley import ( - reset_valley_tracker as _reset, - ) - _reset() - - -# ── 共通結果構築(各検出器から使用) ────────────── - - -def no_detection( - binary: np.ndarray, -) -> LineDetectResult: - """未検出の結果を返す""" - return LineDetectResult( - detected=False, - position_error=0.0, - heading=0.0, - curvature=0.0, - poly_coeffs=None, - row_centers=None, - binary_image=binary, - ) - - -def _extract_row_centers( - binary: np.ndarray, -) -> np.ndarray | None: - """二値画像の最大連結領域から各行の線中心を求める - - Args: - binary: 二値画像 - - Returns: - 各行の中心 x 座標(NaN=その行に線なし), - 最大領域が見つからない場合は None - """ - h, w = binary.shape[:2] - num_labels, labels, stats, _ = ( - cv2.connectedComponentsWithStats(binary) - ) - - if num_labels <= 1: - return None - - # 背景(ラベル 0)を除いた最大領域を取得 - areas = stats[1:, cv2.CC_STAT_AREA] - largest_label = int(np.argmax(areas)) + 1 - - # 最大領域のマスク - mask = (labels == largest_label).astype(np.uint8) - - # 各行の左右端から中心を計算 - centers = np.full(h, np.nan) - for y in range(h): - row = mask[y] - cols = np.where(row > 0)[0] - if len(cols) > 0: - centers[y] = (cols[0] + cols[-1]) / 2.0 - - return centers - - -def build_result( - coeffs: np.ndarray, - binary: np.ndarray, - row_centers: np.ndarray | None = None, -) -> LineDetectResult: - """多項式係数から LineDetectResult を構築する - - row_centers が None の場合は binary から自動抽出する - """ - poly = np.poly1d(coeffs) - center_x = config.FRAME_WIDTH / 2.0 - - # 画像下端での位置偏差 - x_bottom = poly(DETECT_Y_END) - position_error = (center_x - x_bottom) / center_x - - # 傾き: dx/dy(画像下端での値) - poly_deriv = poly.deriv() - heading = float(poly_deriv(DETECT_Y_END)) - - # 曲率: d²x/dy² - poly_deriv2 = poly_deriv.deriv() - curvature = float(poly_deriv2(DETECT_Y_END)) - - # row_centers が未提供なら binary から抽出 - if row_centers is None: - row_centers = _extract_row_centers(binary) - - return LineDetectResult( - detected=True, - position_error=position_error, - heading=heading, - curvature=curvature, - poly_coeffs=coeffs, - row_centers=row_centers, - binary_image=binary, - ) - - -def fit_row_centers( - binary: np.ndarray, - min_width: int, - use_median: bool = False, - ransac_thresh: float = 0.0, - ransac_iter: int = 0, - median_ksize: int = 0, - neighbor_thresh: float = 0.0, - residual_thresh: float = 0.0, -) -> LineDetectResult: - """行ごとの中心点に多項式をフィッティングする - - 各行の白ピクセルの中心(平均または中央値)を1点抽出し, - ロバスト前処理の後にフィッティングする. - 幅の変動に強く,各行が等しく寄与する - - Args: - binary: 二値画像 - min_width: 線として認識する最小ピクセル数 - use_median: True の場合は中央値を使用 - ransac_thresh: RANSAC 閾値(0 以下で無効) - ransac_iter: RANSAC 反復回数 - median_ksize: 移動メディアンのカーネルサイズ - neighbor_thresh: 近傍外れ値除去の閾値 px - residual_thresh: 残差反復除去の閾値 px - - Returns: - 線検出の結果 - """ - region = binary[DETECT_Y_START:DETECT_Y_END, :] - centers_y: list[float] = [] - centers_x: list[float] = [] - - for y_local in range(region.shape[0]): - xs = np.where(region[y_local] > 0)[0] - if len(xs) < min_width: - continue - y = float(y_local + DETECT_Y_START) - centers_y.append(y) - if use_median: - centers_x.append(float(np.median(xs))) - else: - centers_x.append(float(np.mean(xs))) - - if len(centers_y) < MIN_FIT_ROWS: - return no_detection(binary) - - cy = np.array(centers_y) - cx = np.array(centers_x) - - coeffs = clean_and_fit( - cy, cx, - median_ksize=median_ksize, - neighbor_thresh=neighbor_thresh, - residual_thresh=residual_thresh, - ransac_thresh=ransac_thresh, - ransac_iter=ransac_iter, - ) - if coeffs is None: - return no_detection(binary) - - return build_result(coeffs, binary) diff --git a/src/pc/vision/morphology.py b/src/pc/vision/morphology.py deleted file mode 100644 index 9b7edd3..0000000 --- a/src/pc/vision/morphology.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -morphology -二値画像の形態学的処理ユーティリティモジュール -""" - -import cv2 -import numpy as np - - -def apply_iso_closing( - binary: np.ndarray, size: int, -) -> np.ndarray: - """等方クロージングで穴を埋める - - Args: - binary: 二値画像 - size: カーネルサイズ - - Returns: - クロージング後の二値画像 - """ - if size < 3: - return binary - k = size | 1 - kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (k, k), - ) - return cv2.morphologyEx( - binary, cv2.MORPH_CLOSE, kernel, - ) - - -def apply_staged_closing( - binary: np.ndarray, - small_size: int, - min_area: int, - large_size: int, -) -> np.ndarray: - """段階クロージング: 小穴埋め → 孤立除去 → 大穴埋め - - Args: - binary: 二値画像 - small_size: 第1段クロージングのカーネルサイズ - min_area: 孤立領域除去の最小面積(0 で無効) - large_size: 第2段クロージングのカーネルサイズ(0 で無効) - - Returns: - 処理後の二値画像 - """ - # 第1段: 小さいクロージングで近接ピクセルをつなぐ - result = apply_iso_closing(binary, small_size) - - # 孤立領域の除去 - if min_area > 0: - contours, _ = cv2.findContours( - result, cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE, - ) - mask = np.zeros_like(result) - for cnt in contours: - if cv2.contourArea(cnt) >= min_area: - cv2.drawContours( - mask, [cnt], -1, 255, -1, - ) - result = mask - - # 第2段: 大きいクロージングで中抜けを埋める - result = apply_iso_closing(result, large_size) - - return result - - -def apply_width_filter( - binary: np.ndarray, - width_near: int, - width_far: int, - tolerance: float, -) -> np.ndarray: - """透視補正付き幅フィルタで広がりすぎた行を除外する - - 各行の期待線幅を線形補間で算出し, - 実際の幅が上限(期待幅 × tolerance)を超える行をマスクする - - Args: - binary: 二値画像 - width_near: 画像下端での期待線幅(px) - width_far: 画像上端での期待線幅(px) - tolerance: 上限倍率 - - Returns: - 幅フィルタ適用後の二値画像 - """ - result = binary.copy() - h = binary.shape[0] - denom = max(h - 1, 1) - - for y_local in range(h): - xs = np.where(binary[y_local] > 0)[0] - if len(xs) == 0: - continue - # 画像下端(近距離)ほど t=1,上端(遠距離)ほど t=0 - t = (h - 1 - y_local) / denom - expected = float(width_far) + ( - float(width_near) - float(width_far) - ) * t - max_w = expected * tolerance - actual_w = int(xs[-1]) - int(xs[0]) + 1 - if actual_w > max_w: - result[y_local] = 0 - - return result - - -def apply_dist_mask( - binary: np.ndarray, thresh: float, -) -> np.ndarray: - """距離変換で中心部のみを残す - - Args: - binary: 二値画像 - thresh: 距離の閾値(ピクセル) - - Returns: - 中心部のみの二値画像 - """ - if thresh <= 0: - return binary - dist = cv2.distanceTransform( - binary, cv2.DIST_L2, 5, - ) - _, mask = cv2.threshold( - dist, thresh, 255, cv2.THRESH_BINARY, - ) - return mask.astype(np.uint8) diff --git a/src/pc/vision/overlay.py b/src/pc/vision/overlay.py index f5bb0a5..f8f2604 100644 --- a/src/pc/vision/overlay.py +++ b/src/pc/vision/overlay.py @@ -9,8 +9,8 @@ import cv2 import numpy as np -from pc.vision.fitting import theil_sen_fit -from pc.vision.line_detector import LineDetectResult +from common.vision.fitting import theil_sen_fit +from common.vision.line_detector import LineDetectResult # 描画色の定義 (BGR) COLOR_LINE: tuple = (0, 255, 0) @@ -19,6 +19,7 @@ COLOR_ROW_CENTER: tuple = (0, 165, 255) COLOR_THEIL_SEN: tuple = (255, 0, 255) COLOR_PURSUIT: tuple = (255, 255, 0) +COLOR_INTERSECTION: tuple = (0, 0, 255) # パシュート目標点の描画半径 PURSUIT_POINT_RADIUS: int = 2 @@ -57,6 +58,7 @@ tuple[tuple[float, float], tuple[float, float]] | None ) = None, + is_intersection: bool = False, ) -> np.ndarray: """カメラ映像にオーバーレイを描画する @@ -66,6 +68,7 @@ flags: 表示項目のフラグ pursuit_points: 2点パシュートの目標点 ((near_x, near_y), (far_x, far_y)) + is_intersection: 十字路と判定されているか Returns: オーバーレイ描画済みの画像 @@ -73,6 +76,13 @@ display = frame.copy() h, w = display.shape[:2] + # 十字路判定の表示 + if is_intersection: + cv2.rectangle( + display, (0, 0), (w - 1, h - 1), + COLOR_INTERSECTION, 1, + ) + if result is None: return display diff --git a/src/pi/comm/zmq_client.py b/src/pi/comm/zmq_client.py index f553511..f12ee99 100644 --- a/src/pi/comm/zmq_client.py +++ b/src/pi/comm/zmq_client.py @@ -1,10 +1,12 @@ """ zmq_client Pi 側の ZMQ 通信を担当するモジュール -画像の送信と操舵量の受信を行う +テレメトリ送信(画像+検出結果+操舵量)と +コマンド受信(モード切替・パラメータ更新・手動操作)を行う """ import json +import struct import time import cv2 @@ -17,77 +19,162 @@ class PiZmqClient: """Pi 側の ZMQ 通信クライアント - 画像送信(PUB)と操舵量受信(SUB)の2チャネルを管理する + テレメトリ送信(PUB)とコマンド受信(SUB)の2チャネルを管理する """ def __init__(self) -> None: self._context = zmq.Context() - self._image_socket: zmq.Socket | None = None - self._control_socket: zmq.Socket | None = None - self._last_receive_time: float = 0.0 + self._telemetry_socket: zmq.Socket | None = None + self._command_socket: zmq.Socket | None = None def start(self) -> None: """通信ソケットを初期化して接続する""" - # 画像送信ソケット(PUB,PC へ画像を送信) - self._image_socket = self._context.socket(zmq.PUB) - self._image_socket.setsockopt(zmq.CONFLATE, 1) - self._image_socket.connect(config.image_connect_address()) + # テレメトリ送信ソケット(PUB,PC へ画像+状態を送信) + self._telemetry_socket = self._context.socket( + zmq.PUB, + ) + self._telemetry_socket.setsockopt(zmq.CONFLATE, 1) + self._telemetry_socket.connect( + config.image_connect_address(), + ) - # 操舵量受信ソケット(SUB,PC からの操舵量を受信) - self._control_socket = self._context.socket(zmq.SUB) - self._control_socket.setsockopt(zmq.CONFLATE, 1) - self._control_socket.setsockopt_string(zmq.SUBSCRIBE, "") - self._control_socket.connect(config.control_connect_address()) + # コマンド受信ソケット(SUB,PC からのコマンドを受信) + self._command_socket = self._context.socket( + zmq.SUB, + ) + self._command_socket.setsockopt(zmq.CONFLATE, 1) + self._command_socket.setsockopt_string( + zmq.SUBSCRIBE, "", + ) + self._command_socket.connect( + config.control_connect_address(), + ) - self._last_receive_time = time.time() + def send_telemetry( + self, + frame: np.ndarray, + throttle: float, + steer: float, + detected: bool, + position_error: float, + heading: float, + is_intersection: bool, + is_recovering: bool, + intersection_available: bool, + compute_ms: float, + fps: float, + binary_image: np.ndarray | None = None, + ) -> None: + """テレメトリ(JSON ヘッダ + JPEG 画像)を送信する - def send_image(self, frame: np.ndarray) -> None: - """画像を JPEG 圧縮して送信する + メッセージ形式: + 4 バイト: JSON 長(uint32 LE) + N バイト: JSON テレメトリ + 残り: JPEG 画像(カメラ映像) + (binary_image がある場合はさらに続く) Args: frame: カメラから取得した画像の NumPy 配列 + throttle: 現在の throttle 出力 + steer: 現在の steer 出力 + detected: 線が検出できたか + position_error: 位置偏差 + heading: 線の傾き + is_intersection: 十字路と判定されたか + is_recovering: 復帰動作中か + intersection_available: 十字路分類器が利用可能か + compute_ms: 操舵計算時間(ミリ秒) + fps: Pi 側の処理 FPS + binary_image: 二値画像(None で省略) """ - if self._image_socket is None: + if self._telemetry_socket is None: return - _, encoded = cv2.imencode( + + telemetry: dict = { + "v": config.TELEMETRY_VERSION, + "ts": time.time(), + "throttle": throttle, + "steer": steer, + "detected": detected, + "pos_error": position_error, + "heading": heading, + "is_intersection": is_intersection, + "is_recovering": is_recovering, + "intersection_available": ( + intersection_available + ), + "compute_ms": compute_ms, + "fps": fps, + } + + # JSON ヘッダをエンコード + json_bytes = json.dumps(telemetry).encode("utf-8") + json_len = struct.pack(" tuple[float, float] | None: - """操舵量を非ブロッキングで受信する + # 二値画像を JPEG 圧縮(ある場合) + bin_bytes = b"" + if binary_image is not None: + _, bin_encoded = cv2.imencode( + ".jpg", + binary_image, + [ + cv2.IMWRITE_JPEG_QUALITY, + config.JPEG_QUALITY_BINARY, + ], + ) + bin_bytes = bin_encoded.tobytes() + + # メッセージ: JSON長 + JSON + CAM長 + CAM + BIN + msg = ( + json_len + json_bytes + + cam_len + cam_bytes + + bin_bytes + ) + self._telemetry_socket.send(msg, zmq.NOBLOCK) + + def receive_command(self) -> dict | None: + """PC からのコマンドを非ブロッキングで受信する + + コマンド形式(JSON): + mode: "auto" | "manual" | "stop" + throttle: float(手動モード時のみ) + steer: float(手動モード時のみ) + steering_method: "pd" | "pursuit" | "ts_pd" + image_params: dict(二値化パラメータ更新,省略可) + pd_params: dict(PD 制御パラメータ,省略可) + pursuit_params: dict(Pursuit パラメータ,省略可) + steering_params: dict(TS-PD パラメータ,省略可) + recovery_params: dict(復帰パラメータ,省略可) + intersection_enabled: bool(省略可) + intersection_throttle: float(省略可) Returns: - (throttle, steer) のタプル,受信データがない場合は None + コマンド辞書,受信データがない場合は None """ - if self._control_socket is None: + if self._command_socket is None: return None try: - data = self._control_socket.recv(zmq.NOBLOCK) - payload = json.loads(data.decode("utf-8")) - self._last_receive_time = time.time() - return (payload["throttle"], payload["steer"]) + data = self._command_socket.recv(zmq.NOBLOCK) + return json.loads(data.decode("utf-8")) except zmq.Again: return None - def is_timeout(self) -> bool: - """操舵量の受信がタイムアウトしたか判定する - - Returns: - タイムアウトしていれば True - """ - elapsed = time.time() - self._last_receive_time - return elapsed > config.CONTROL_TIMEOUT_SEC - def stop(self) -> None: """通信ソケットを閉じる""" - if self._image_socket is not None: - self._image_socket.close() - self._image_socket = None - if self._control_socket is not None: - self._control_socket.close() - self._control_socket = None + if self._telemetry_socket is not None: + self._telemetry_socket.close() + self._telemetry_socket = None + if self._command_socket is not None: + self._command_socket.close() + self._command_socket = None self._context.term() diff --git a/src/pi/main.py b/src/pi/main.py index 17b7630..3b8f27b 100644 --- a/src/pi/main.py +++ b/src/pi/main.py @@ -1,14 +1,26 @@ """ main Pi 側アプリケーションのエントリーポイント -カメラ画像の送信と操舵量の受信・モーター制御を行う +カメラ画像の取得・画像処理・操舵量計算・モーター制御を +すべて Pi 上で完結させる +PC にはテレメトリ(画像+状態)を送信し, +PC からはコマンド(モード切替・パラメータ更新)を受信する """ +import dataclasses import time +from typing import Any from pi.camera.capture import CameraCapture from pi.comm.zmq_client import PiZmqClient from pi.motor.driver import MotorDriver +from common import config +from common.steering.base import SteeringBase, SteeringOutput +from common.steering.pd_control import PdControl +from common.steering.pursuit_control import PursuitControl +from common.steering.recovery import RecoveryController +from common.steering.ts_pd_control import TsPdControl +from common.vision.intersection import IntersectionClassifier def main() -> None: @@ -17,28 +29,213 @@ zmq_client = PiZmqClient() motor = MotorDriver() + # 操舵制御(3手法) + pd_control = PdControl() + pursuit_control = PursuitControl() + ts_pd_control = TsPdControl() + controllers: dict[str, SteeringBase] = { + "pd": pd_control, + "pursuit": pursuit_control, + "ts_pd": ts_pd_control, + } + steering: SteeringBase = ts_pd_control + steering_method = "ts_pd" + + recovery = RecoveryController() + + # 十字路分類器(遅延読み込み) + intersection_clf = IntersectionClassifier() + intersection_enabled = False + intersection_throttle = 0.3 + + # モード管理 + mode = "stop" # "auto", "manual", "stop" + manual_throttle = 0.0 + manual_steer = 0.0 + + # FPS 計測用 + frame_count = 0 + fps_start = time.time() + current_fps = 0.0 + compute_ms_avg = 0.0 + compute_ms_sum = 0.0 + compute_count = 0 + compute_start = time.time() + try: camera.start() zmq_client.start() motor.start() - print("Pi: カメラ・通信・モーターを開始") + print("Pi: カメラ・通信・モーターを開始(自律モード)") while True: - # カメラ画像を取得して送信 - frame = camera.capture() - zmq_client.send_image(frame) + # ── コマンド受信 ──────────────────────── + cmd = zmq_client.receive_command() + if cmd is not None: + # 制御手法の切り替え + new_method = cmd.get("steering_method") + if ( + new_method is not None + and new_method != steering_method + ): + if new_method in controllers: + steering_method = new_method + steering = controllers[ + steering_method + ] + else: + steering_method = "ts_pd" + steering = controllers["ts_pd"] + print( + f"Pi: 制御手法変更 → " + f"{steering_method}" + ) - # 操舵量を受信してモーターに反映 - control = zmq_client.receive_control() - if control is not None: - throttle, steer = control + _apply_command( + cmd, controllers, recovery, + intersection_clf, + ) + # モード更新 + if "mode" in cmd: + new_mode = cmd["mode"] + if new_mode != mode: + if new_mode == "auto": + steering.reset() + recovery.reset() + mode = new_mode + print(f"Pi: モード変更 → {mode}") + # 手動操作値 + if mode == "manual": + manual_throttle = cmd.get( + "throttle", 0.0, + ) + manual_steer = cmd.get("steer", 0.0) + # 十字路設定 + if "intersection_enabled" in cmd: + intersection_enabled = cmd[ + "intersection_enabled" + ] + if "intersection_throttle" in cmd: + intersection_throttle = cmd[ + "intersection_throttle" + ] + + # ── カメラ画像取得 ────────────────────── + frame = camera.capture() + frame_count += 1 + + # ── 操舵量決定 ───────────────────────── + throttle = 0.0 + steer = 0.0 + detected = False + position_error = 0.0 + heading = 0.0 + is_intersection = False + is_recovering = False + binary_image = None + + if mode == "auto": + # 線検出 + 制御 + t_start = time.time() + output = steering.compute(frame) + det = steering.last_detect_result + + detected = ( + det is not None and det.detected + ) + position_error = ( + det.position_error + if detected and det is not None + else 0.0 + ) + heading = ( + det.heading + if detected and det is not None + else 0.0 + ) + binary_image = ( + det.binary_image + if det is not None else None + ) + + # 十字路判定 + if ( + intersection_enabled + and intersection_clf.available + and det is not None + and det.binary_image is not None + ): + is_intersection = ( + intersection_clf.predict( + det.binary_image, + ) + ) + if is_intersection: + output = SteeringOutput( + throttle=intersection_throttle, + steer=0.0, + ) + + # コースアウト復帰 + recovery_output = recovery.update( + detected, position_error, + ) + if recovery_output is not None: + output = recovery_output + is_recovering = recovery.is_recovering + + throttle = output.throttle + steer = output.steer + compute_ms_sum += ( + (time.time() - t_start) * 1000.0 + ) + compute_count += 1 + ce = time.time() - compute_start + if ce >= 1.0: + compute_ms_avg = ( + compute_ms_sum / compute_count + ) + compute_ms_sum = 0.0 + compute_count = 0 + compute_start = time.time() + + elif mode == "manual": + throttle = manual_throttle + steer = manual_steer + + # mode == "stop" なら throttle=0, steer=0 + + # ── モーター制御 ─────────────────────── + if mode == "stop": + motor.stop() + else: motor.set_drive(throttle, steer) - # タイムアウト時はモーター停止 - if zmq_client.is_timeout(): - motor.stop() + # ── FPS 計測 ─────────────────────────── + elapsed = time.time() - fps_start + if elapsed >= config.LOG_INTERVAL_SEC: + current_fps = frame_count / elapsed + print(f"Pi: FPS={current_fps:.1f}") + frame_count = 0 + fps_start = time.time() - time.sleep(0.01) + # ── テレメトリ送信 ───────────────────── + zmq_client.send_telemetry( + frame=frame, + throttle=throttle, + steer=steer, + detected=detected, + position_error=position_error, + heading=heading, + is_intersection=is_intersection, + is_recovering=is_recovering, + intersection_available=( + intersection_clf.available + ), + compute_ms=compute_ms_avg, + fps=current_fps, + binary_image=binary_image, + ) except KeyboardInterrupt: print("\nPi: 終了") @@ -48,5 +245,93 @@ zmq_client.stop() +def _safe_update_dataclass( + target: Any, + updates: dict[str, Any], +) -> None: + """dataclass のフィールドを型チェック付きで更新する + + Args: + target: 更新対象の dataclass インスタンス + updates: フィールド名と値の辞書 + """ + field_names = { + f.name for f in dataclasses.fields(target) + } + for key, value in updates.items(): + if key not in field_names: + continue + current = getattr(target, key) + expected = type(current) + # int フィールドに float が来た場合は変換を許容 + if expected is int and isinstance(value, float): + value = int(value) + elif expected is float and isinstance(value, int): + value = float(value) + elif not isinstance(value, expected): + print( + f"Pi: パラメータ型エラー " + f"{key}: 期待={expected.__name__}, " + f"実際={type(value).__name__}" + ) + continue + setattr(target, key, value) + + +def _apply_command( + cmd: dict, + controllers: dict[str, SteeringBase], + recovery: RecoveryController, + intersection_clf: IntersectionClassifier, +) -> None: + """コマンドからパラメータ更新を適用する + + Args: + cmd: 受信したコマンド辞書 + controllers: 制御手法名とインスタンスの辞書 + recovery: 復帰制御クラス + intersection_clf: 十字路分類器 + """ + # 画像処理パラメータの更新(全制御クラスに反映) + if "image_params" in cmd: + ip = cmd["image_params"] + for ctrl in controllers.values(): + _safe_update_dataclass( + ctrl.image_params, ip, + ) + + # 制御手法固有パラメータの更新 + _PARAM_KEYS: dict[str, str] = { + "pd": "pd_params", + "pursuit": "pursuit_params", + "ts_pd": "steering_params", + } + for name, cmd_key in _PARAM_KEYS.items(): + if cmd_key in cmd and name in controllers: + _safe_update_dataclass( + controllers[name].params, + cmd[cmd_key], + ) + + # 復帰パラメータの更新 + if "recovery_params" in cmd: + _safe_update_dataclass( + recovery.params, + cmd["recovery_params"], + ) + + # 十字路分類器の遅延読み込み + if ( + cmd.get("intersection_enabled", False) + and not intersection_clf.available + ): + print("Pi: 十字路分類器を読み込み中...") + intersection_clf.load() + if intersection_clf.available: + print("Pi: 十字路分類器の読み込み完了") + else: + print("Pi: 十字路分類器が見つかりません") + + if __name__ == "__main__": main() diff --git a/tests/test_fitting.py b/tests/test_fitting.py index 5791fbf..5230d5c 100644 --- a/tests/test_fitting.py +++ b/tests/test_fitting.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from pc.vision.fitting import ( +from common.vision.fitting import ( MIN_FIT_ROWS, clean_and_fit, ransac_polyfit, diff --git a/tests/test_line_detector.py b/tests/test_line_detector.py index ce186c7..c05c2fb 100644 --- a/tests/test_line_detector.py +++ b/tests/test_line_detector.py @@ -4,7 +4,7 @@ import pytest from common import config -from pc.vision.line_detector import ( +from common.vision.line_detector import ( ImageParams, LineDetectResult, build_result, diff --git a/tests/test_morphology.py b/tests/test_morphology.py index 1cff783..01bb276 100644 --- a/tests/test_morphology.py +++ b/tests/test_morphology.py @@ -4,7 +4,7 @@ import pytest from common import config -from pc.vision.morphology import ( +from common.vision.morphology import ( apply_dist_mask, apply_iso_closing, apply_staged_closing, diff --git a/tests/test_params.py b/tests/test_params.py index d7bf6fa..8c3b1b2 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -5,8 +5,8 @@ import pytest -from pc.steering.pd_control import PdParams -from pc.vision.line_detector import ImageParams +from common.steering.pd_control import PdParams +from common.vision.line_detector import ImageParams class TestAutoParams: @@ -36,9 +36,10 @@ max_steer_rate=0.05, max_throttle=0.6, speed_k=0.4, ) - save_control(params, "blackhat") - loaded, method = load_control() + save_control(params, "blackhat", "pd") + loaded, method, steering = load_control() assert method == "blackhat" + assert steering == "pd" assert loaded.kp == pytest.approx(1.0) assert loaded.kh == pytest.approx(0.5) assert loaded.max_throttle == pytest.approx(0.6) @@ -46,8 +47,9 @@ def test_load_control_missing_file(self) -> None: """ファイルがない場合はデフォルト値を返す""" from pc.steering.auto_params import load_control - params, method = load_control() + params, method, steering = load_control() assert method == "current" + assert steering == "pd" assert params.kp == PdParams().kp def test_save_load_detect_params( diff --git a/tests/test_steering.py b/tests/test_steering.py new file mode 100644 index 0000000..396d834 --- /dev/null +++ b/tests/test_steering.py @@ -0,0 +1,331 @@ +"""操舵量計算モジュールのテスト""" + +import numpy as np +import pytest + +from common import config +from common.steering.base import SteeringOutput +from common.steering.pd_control import PdControl, PdParams +from common.steering.pursuit_control import ( + PursuitControl, + PursuitParams, +) +from common.steering.ts_pd_control import ( + TsPdControl, + TsPdParams, +) +from common.vision.line_detector import ImageParams + + +@pytest.fixture() +def _center_line_params() -> ImageParams: + """テスト画像用に調整した検出パラメータ""" + return ImageParams( + method="current", + clahe_grid=2, blur_size=3, + open_size=1, close_width=3, + close_height=1, + ) + + +class TestSteeringBase: + """SteeringBase の共通ロジックのテスト""" + + def test_compute_returns_steering_output( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """compute は SteeringOutput を返す""" + ctrl = PdControl( + image_params=_center_line_params, + ) + output = ctrl.compute(straight_line_image) + assert isinstance(output, SteeringOutput) + + def test_steer_within_range( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """steer は -1.0 ~ +1.0 の範囲内""" + ctrl = PdControl( + image_params=_center_line_params, + ) + for _ in range(10): + output = ctrl.compute(straight_line_image) + assert -1.0 <= output.steer <= 1.0 + + def test_throttle_non_negative( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """throttle は 0 以上""" + ctrl = PdControl( + image_params=_center_line_params, + ) + output = ctrl.compute(straight_line_image) + assert output.throttle >= 0.0 + + def test_rate_limiter_clamps_steer_change( + self, + _center_line_params: ImageParams, + ) -> None: + """レートリミッターが操舵変化量を制限する""" + max_rate = 0.05 + params = PdParams( + kp=5.0, kh=5.0, kd=0.0, + max_steer_rate=max_rate, + ) + ctrl = PdControl( + params=params, + image_params=_center_line_params, + ) + + # 大きく左にオフセットした画像を作成 + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + img = np.full((h, w), 200, dtype=np.uint8) + img[:, 2:5] = 30 # 左端に線 + + output = ctrl.compute(img) + # 初回は prev_steer=0 なので変化量が制限される + assert abs(output.steer) <= max_rate + 1e-9 + + def test_no_line_returns_zero( + self, + blank_image: np.ndarray, + ) -> None: + """線が検出できない場合は throttle=0, steer=0""" + ctrl = PdControl() + output = ctrl.compute(blank_image) + assert output.throttle == 0.0 + assert output.steer == 0.0 + + def test_last_detect_result_updated( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """compute 後に last_detect_result が更新される""" + ctrl = PdControl( + image_params=_center_line_params, + ) + assert ctrl.last_detect_result is None + ctrl.compute(straight_line_image) + assert ctrl.last_detect_result is not None + + def test_reset_clears_state( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """reset 後に内部状態がクリアされる""" + ctrl = PdControl( + image_params=_center_line_params, + ) + ctrl.compute(straight_line_image) + ctrl.reset() + assert ctrl.last_detect_result is None + assert ctrl._prev_steer == 0.0 + + +class TestPdControl: + """PD 制御のテスト""" + + def test_center_line_small_steer( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """中央の線に対して操舵量が小さい""" + ctrl = PdControl( + image_params=_center_line_params, + ) + output = ctrl.compute(straight_line_image) + assert abs(output.steer) < 0.3 + + def test_left_line_steers_positive( + self, + _center_line_params: ImageParams, + ) -> None: + """左にある線に対して正方向に操舵する""" + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + img = np.full((h, w), 200, dtype=np.uint8) + img[:, 5:8] = 30 # 左寄りの線 + + ctrl = PdControl( + params=PdParams( + kp=1.0, kh=0.0, kd=0.0, + max_steer_rate=1.0, + ), + image_params=_center_line_params, + ) + output = ctrl.compute(img) + assert output.steer > 0.0 + + def test_speed_decreases_with_curvature( + self, + _center_line_params: ImageParams, + ) -> None: + """曲率が大きいほど速度が下がる""" + params = PdParams(max_throttle=0.5, speed_k=0.3) + ctrl = PdControl( + params=params, + image_params=_center_line_params, + ) + # 曲がった線を作成 + h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH + img = np.full((h, w), 200, dtype=np.uint8) + for y in range(h): + x = int(w / 2 + 5 * (y / h - 0.5) ** 2 * w) + x = max(0, min(w - 1, x)) + img[y, max(0, x - 1):min(w, x + 2)] = 30 + + output = ctrl.compute(img) + assert output.throttle <= params.max_throttle + + +class TestPursuitControl: + """2点パシュート制御のテスト""" + + def test_center_line_small_steer( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """中央の線に対して操舵量が小さい""" + ctrl = PursuitControl( + image_params=_center_line_params, + ) + output = ctrl.compute(straight_line_image) + assert abs(output.steer) < 0.3 + + def test_no_line_returns_zero( + self, + blank_image: np.ndarray, + ) -> None: + """線が検出できない場合は停止""" + ctrl = PursuitControl() + output = ctrl.compute(blank_image) + assert output.throttle == 0.0 + assert output.steer == 0.0 + + +class TestTsPdControl: + """Theil-Sen PD 制御のテスト""" + + def test_center_line_small_steer( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """中央の線に対して操舵量が小さい""" + ctrl = TsPdControl( + image_params=_center_line_params, + ) + output = ctrl.compute(straight_line_image) + assert abs(output.steer) < 0.3 + + def test_no_line_returns_zero( + self, + blank_image: np.ndarray, + ) -> None: + """線が検出できない場合は停止""" + ctrl = TsPdControl() + output = ctrl.compute(blank_image) + assert output.throttle == 0.0 + assert output.steer == 0.0 + + def test_reset_clears_derivative_state( + self, + straight_line_image: np.ndarray, + _center_line_params: ImageParams, + ) -> None: + """reset で微分項の状態がクリアされる""" + ctrl = TsPdControl( + image_params=_center_line_params, + ) + ctrl.compute(straight_line_image) + ctrl.reset() + assert ctrl._prev_error == 0.0 + assert ctrl._prev_time == 0.0 + + +class TestSafeUpdateDataclass: + """_safe_update_dataclass のテスト + + pi.main は picamera2 に依存するため直接 import できない. + 同等のロジックを common モジュールの dataclass で検証する + """ + + @staticmethod + def _safe_update_dataclass( + target: object, + updates: dict, + ) -> None: + """pi.main._safe_update_dataclass と同等のロジック""" + import dataclasses + field_names = { + f.name for f in dataclasses.fields(target) + } + for key, value in updates.items(): + if key not in field_names: + continue + current = getattr(target, key) + expected = type(current) + if expected is int and isinstance( + value, float, + ): + value = int(value) + elif expected is float and isinstance( + value, int, + ): + value = float(value) + elif not isinstance(value, expected): + continue + setattr(target, key, value) + + def test_updates_valid_fields(self) -> None: + """正しい型のフィールドを更新できる""" + params = PdParams() + self._safe_update_dataclass( + params, {"kp": 2.0}, + ) + assert params.kp == 2.0 + + def test_ignores_unknown_fields(self) -> None: + """存在しないフィールドは無視する""" + params = PdParams() + original_kp = params.kp + self._safe_update_dataclass( + params, {"unknown_field": 999}, + ) + assert params.kp == original_kp + + def test_rejects_wrong_type(self) -> None: + """型が一致しない場合は更新しない""" + params = PdParams() + original_kp = params.kp + self._safe_update_dataclass( + params, {"kp": "not_a_number"}, + ) + assert params.kp == original_kp + + def test_int_to_float_conversion(self) -> None: + """int を float フィールドに渡すと変換される""" + params = PdParams() + self._safe_update_dataclass( + params, {"kp": 2}, + ) + assert params.kp == 2.0 + assert isinstance(params.kp, float) + + def test_float_to_int_conversion(self) -> None: + """float を int フィールドに渡すと変換される""" + params = ImageParams() + self._safe_update_dataclass( + params, {"binary_thresh": 100.0}, + ) + assert params.binary_thresh == 100 + assert isinstance(params.binary_thresh, int)