# preprocess_videos_to_pt.py

from pathlib import Path
import argparse
import torch
import torchvision.transforms.functional as F
from torchcodec.decoders import SimpleVideoDecoder
import torchvision
import numpy as np

def video_to_npy(video_path: Path, npy_path: Path, size=(360, 640)):
    """Decode a video, optionally resize every frame, and save it as a .npy array.

    Args:
        video_path: Path to the input video file (``str`` or ``Path``).
        npy_path: Destination file (``str`` or ``Path``). Missing parent
            directories are created. Note that ``np.save`` appends ``.npy``
            when the name does not already end with that suffix.
        size: Target ``(height, width)`` passed to ``F.resize``, or ``None``
            to keep the decoded resolution.
    """
    video_path = Path(video_path)
    npy_path = Path(npy_path)
    # Same as video_to_pt: np.save does not create missing directories, so an
    # --out_dir that mirrors nested subfolders would otherwise fail here.
    npy_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"[INFO] Decoding {video_path} ...")
    dec = SimpleVideoDecoder(str(video_path))
    clip = dec[:]  # [T, C, H, W], uint8

    print(f"[INFO] Original clip shape: {clip.shape}, dtype={clip.dtype}")
    # Batched resize over all frames in one call (no Python loop).
    if size is not None:
        clip = F.resize(
            clip,
            size,  # (H, W)
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )  # still uint8

    print(f"[INFO] Resized clip shape: {clip.shape}, dtype={clip.dtype}")

    # Convert to NumPy so the saved file can later be opened lazily with
    # np.load(path, mmap_mode="r").
    arr = clip.numpy()  # (T, C, H, W), uint8
    np.save(npy_path, arr)
    print("[OK] Saved:", npy_path, arr.shape, arr.dtype)


def video_to_pt(
    video_path: Path,
    pt_path: Path,
    size=(360, 640),
):
    """Decode a video, optionally resize every frame, and save it with ``torch.save``.

    Args:
        video_path: Path to the input video file (``str`` or ``Path``).
        pt_path: Destination ``.pt`` file (``str`` or ``Path``). Missing parent
            directories are created.
        size: Target ``(height, width)`` passed to ``F.resize``, or ``None``
            to keep the decoded resolution.

    The saved object is the frame tensor itself (not wrapped in a dict).
    """
    video_path = Path(video_path)
    pt_path = Path(pt_path)
    pt_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"[INFO] Decoding {video_path} ...")
    dec = SimpleVideoDecoder(str(video_path))
    clip = dec[:]   # [T, C, H, W], uint8
    # no dec.close() in torchcodec 0.1.x

    print(f"[INFO] Original clip shape: {clip.shape}, dtype={clip.dtype}")

    # Batched resize (no Python loop), stays uint8
    # F.resize supports [B, C, H, W]
    if size is not None:
        clip = F.resize(
            clip,
            size,  # (H, W)
            antialias=True,
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        )

    print(f"[INFO] Resized clip shape: {clip.shape}, dtype={clip.dtype}")

    # Save the tensor as-is. The zipfile-based format has been torch.save's
    # default since PyTorch 1.6 (the flag only makes that explicit). It does
    # NOT compress: tensor bytes are stored raw, so a uint8 clip costs
    # roughly T*C*H*W bytes on disk.
    torch.save(clip, pt_path, _use_new_zipfile_serialization=True)
    print(f"[OK] Saved {pt_path}")


def main():
    """Command-line entry point: convert every matching video under --video_dir.

    Each video is decoded, optionally resized to ``--height`` x ``--width``,
    and written next to the source video (or at the mirrored relative path
    under ``--out_dir``) as a ``.pt`` file (``video_to_pt``) or a ``.npy``
    file (``video_to_npy``), selected with ``--pt_npy``.

    Invalid argument combinations terminate via ``parser.error`` (SystemExit 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_dir", type=str, required=True,
                        help="Directory containing the .mp4 videos")
    parser.add_argument("--ext", type=str, default=".mp4",
                        help="Video extension (default: .mp4)")
    parser.add_argument("--out_dir", type=str, default=None,
                        help="Where to store .pt files; default = same folder as video")
    parser.add_argument("--height", type=int, default=None,
                        help="Resize height; must be given together with --width")
    parser.add_argument("--width", type=int, default=None,
                        help="Resize width; must be given together with --height")
    # Restrict to the two supported writers: previously any other string was
    # routed to np.save, which appends ".npy" (e.g. "clip.npz.npy").
    parser.add_argument("--pt_npy", type=str, default="pt", choices=("pt", "npy"),
                        help="Output format: 'pt' (torch.save) or 'npy' (np.save)")
    args = parser.parse_args()

    video_dir = Path(args.video_dir)
    if not video_dir.is_dir():
        parser.error(f"--video_dir {video_dir} is not a directory")

    # Resizing needs both dimensions; a lone --height or --width used to be
    # silently ignored, producing full-resolution output without warning.
    if (args.height is None) != (args.width is None):
        parser.error("--height and --width must be given together")
    if args.height is not None and (args.height <= 0 or args.width <= 0):
        parser.error("--height and --width must be positive")
    size = (args.height, args.width) if args.height is not None else None

    print(f"resize: {size}")

    out_dir = Path(args.out_dir) if args.out_dir is not None else None
    out_ext = args.pt_npy

    # Only regular files: rglob would also match a directory named "*.mp4".
    videos = sorted(p for p in video_dir.rglob(f"*{args.ext}") if p.is_file())

    print(f"[INFO] Found {len(videos)} videos in {video_dir} with ext {args.ext}")

    for video in videos:
        if out_dir is None:
            out_path = video.with_suffix(f".{out_ext}")
        else:
            # Mirror the input directory tree under out_dir.
            rel = video.relative_to(video_dir)
            out_path = (out_dir / rel).with_suffix(f".{out_ext}")
        # Create the destination folder up front so both writers can rely on it.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if out_ext == "pt":
            video_to_pt(video, out_path, size=size)
        else:
            video_to_npy(video, out_path, size=size)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
