import argparse
import os
import re

import cv2
import numpy as np

# mmdet / mmpose
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.structures import merge_data_samples
from mmpose.utils import adapt_mmdet_pipeline
from PIL import Image, ImageDraw


def extract_keypoints_rtmpose(pose_results):
    """Select the pose instance with the highest mean keypoint visibility.

    Args:
        pose_results: Iterable of pose data samples (e.g. from
            ``inference_topdown``); each sample exposes ``pred_instances``
            whose elements carry ``keypoints`` and ``keypoints_visible``.

    Returns:
        numpy.ndarray | None: The best instance's keypoints with shape
        ``(num_kpts, 2)``, or ``None`` when nothing usable was detected
        (empty results, or every instance has zero mean visibility).
    """
    if not pose_results:
        return None

    best_instance = None
    best_score = 0.0
    for result in pose_results:
        for instance in result.pred_instances:
            score = float(np.mean(instance.keypoints_visible))
            # Strict '>' with a 0.0 floor: fully-invisible instances are ignored.
            if score > best_score:
                best_score = score
                best_instance = instance

    if best_instance is None:
        return None

    kpts = np.asarray(best_instance.keypoints)
    # keypoints may arrive batched as (1, num_kpts, 2) or flat as
    # (num_kpts, 2); the original unconditional [0] returned a single
    # (x, y) pair in the flat case. Normalize both to (num_kpts, 2).
    return kpts[0] if kpts.ndim == 3 else kpts


def pillow_draw_circle(draw, center, radius, fill=None, outline=None, width=1):
    """Draw a circle of the given radius centered at ``center=(x, y)``.

    Coordinates are truncated to int before the bounding box is built, so
    sub-pixel centers are snapped toward zero.
    """
    cx = int(center[0])
    cy = int(center[1])
    bbox = [(cx - radius, cy - radius), (cx + radius, cy + radius)]
    draw.ellipse(bbox, fill=fill, outline=outline, width=width)


def draw_glow_marker(draw, center, main_color, radius=15):
    """Draw a filled disc in ``main_color`` ringed by a white outline,
    giving the marker a glowing look."""
    cx, cy = int(center[0]), int(center[1])
    # White ring sits 3 px outside the filled disc.
    ring_radius = radius + 3
    pillow_draw_circle(
        draw, (cx, cy), ring_radius, fill=None, outline=(255, 255, 255), width=2
    )
    # Solid interior in the requested color.
    pillow_draw_circle(draw, (cx, cy), radius, fill=main_color, outline=None, width=0)


def natural_sort_key(filename):
    """Sort key for image filenames: numeric names first, ordered by their
    first embedded integer (then by name), non-numeric names after,
    alphabetically.

    The tagged tuples keep every key comparable; the original key mixed
    ``int`` and ``str`` returns, which raises TypeError under ``sorted``
    when a folder contains both kinds of filename. It also ran the same
    regex search twice per name.
    """
    match = re.search(r"\d+", filename)
    if match:
        return (0, int(match.group()), filename)
    return (1, filename)


def main():
    """Detect people (RTMDet), estimate poses (RTMPose), and save copies of
    every image in --img_dir with glowing shoulder/hip markers.

    Command-line arguments:
        --img_dir     input folder containing PNG/JPG images (required)
        --output_dir  folder for annotated output images (default: pose_out)
        --device      inference device, e.g. cuda / cpu (default: cuda)
    """
    parser = argparse.ArgumentParser(description="RTMpose: detect shoulders & hips.")
    parser.add_argument(
        "--img_dir",
        type=str,
        required=True,
        help="入力画像フォルダのパス（PNG/JPG 画像が格納されている）",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="pose_out",
        help="肩・腰描画後の画像を保存するフォルダ",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="推論に使用するデバイス (cuda / cpu など)",
    )

    # Fixed model configs/checkpoints, kept exactly as in the original setup.
    det_config = "modules/rtmpose/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
    det_checkpoint = (
        "models/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
    )
    pose_config = (
        "modules/rtmpose/configs/body_2d_keypoint/rtmpose/body8/"
        "rtmpose-l_8xb256-420e_body8-256x192.py"
    )
    pose_checkpoint = "models/rtmpose/rtmpose-l_simcc-aic-coco_pt-aic-coco_420e-256x192-f016ffe0_20230126.pth"

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # (A) Initialize the person detector (RTMDet).
    print("[INFO] Initializing RTDet with config:")
    print("  det_config:", det_config)
    print("  det_checkpoint:", det_checkpoint)
    detector = init_detector(det_config, det_checkpoint, device=args.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    # (B) Initialize the pose estimator (RTMPose).
    print("[INFO] Initializing RTMpose with config:")
    print("  pose_config:", pose_config)
    print("  pose_checkpoint:", pose_checkpoint)
    pose_estimator = init_pose_estimator(
        pose_config, pose_checkpoint, device=args.device
    )

    # Collect image files in natural (numeric-aware) order.
    images = sorted(
        (
            f
            for f in os.listdir(args.img_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ),
        key=natural_sort_key,
    )
    if not images:
        print(f"[WARNING] 指定フォルダに画像が見つかりません: {args.img_dir}")
        return

    # Shoulder/hip marker color carried over from the original code.
    pose_color_rgb = (33, 95, 154)

    for img_name in images:
        img_path = os.path.join(args.img_dir, img_name)
        frame_bgr = cv2.imread(img_path)
        if frame_bgr is None:
            print(f"[WARNING] 画像の読み込みに失敗: {img_path}")
            continue

        # (1) Person detection.
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        det_result = inference_detector(detector, frame_rgb)
        pred_instance = det_result.pred_instances.cpu().numpy()

        # Keep bboxes with label == 0 (person) and score > 0.3.
        bboxes = np.concatenate(
            (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
        )
        person_bboxes = bboxes[
            np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)
        ]

        # NMS (argument format can differ between mmpose versions).
        keep_idx = nms(person_bboxes, 0.3)
        person_bboxes = person_bboxes[keep_idx, :4]

        if len(person_bboxes) == 0:
            # No person detected: save the frame untouched.
            cv2.imwrite(os.path.join(args.output_dir, img_name), frame_bgr)
            continue

        # (2) Pose estimation with RTMPose.
        pose_results = inference_topdown(pose_estimator, frame_rgb, person_bboxes)

        # (3) Pick the most-visible instance's keypoints.
        keypoints = extract_keypoints_rtmpose(pose_results)
        if keypoints is None:
            cv2.imwrite(os.path.join(args.output_dir, img_name), frame_bgr)
            continue

        # COCO indices: left_shoulder=5, right_shoulder=6,
        # left_hip=11, right_hip=12.
        targets = (keypoints[5], keypoints[6], keypoints[11], keypoints[12])

        # (4) Draw the shoulder/hip markers with Pillow.
        pil_img = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_img)
        # Visibility checks are intentionally omitted; all four points
        # are drawn as-is.
        for pt in targets:
            draw_glow_marker(draw, (pt[0], pt[1]), pose_color_rgb, radius=15)

        # (5) Convert back to BGR and save.
        out_img_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(args.output_dir, img_name), out_img_bgr)

    print("All done. Results saved to:", args.output_dir)


# Script entry point: run the shoulder/hip annotation pipeline.
if __name__ == "__main__":
    main()
