import decord
from pathlib import Path
import torch
import cv2
import numpy as np
from inflateDEMO import I3DResNet50
import torchvision.transforms as T
import sys
from tqdm import tqdm
import ffmpeg
import argparse
from nested_lookup import nested_lookup
import gc
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter

# Make the local research checkout importable. list.index raises ValueError
# when the path is absent, so only then do we append it; a bare `except:`
# would also have swallowed KeyboardInterrupt/SystemExit.
try:
    print (sys.path.index("D:\\Users\\user\\Downloads\\Research"))
except ValueError:
    sys.path.append("D:\\Users\\user\\Downloads\\Research")
    
from Models import RARP_NVB_ResNet50

def trim_output(inputVideo:str, outputVideo:str, timeStart:str, timeEnd:str):
    """Cut the [timeStart, timeEnd] span of inputVideo into outputVideo.

    Uses stream copy (no re-encoding); timeStart/timeEnd are timestamp
    strings such as those produced by seconds_to_hms. An existing
    outputVideo is overwritten.
    """
    clip = ffmpeg.input(inputVideo, ss=timeStart, to=timeEnd)
    clip = clip.output(outputVideo, c="copy")
    clip.run(overwrite_output=True)

def extract_frames_ffmpeg(video_path, start_frame, end_frame, width=None, height=None, fps=30):
    """Decode frames [start_frame, end_frame) of *video_path* into a NumPy array.

    Args:
        video_path: path to the input video file.
        start_frame: index of the first frame to extract.
        end_frame: index one past the last frame to extract.
        width, height: frame dimensions; when both are given the result is
            reshaped to (num_frames, height, width, 3), otherwise a flat
            uint8 buffer is returned.
        fps: frame rate used to convert frame indices to timestamps.

    Returns:
        np.ndarray of raw RGB24 pixel data.
    """
    # ffmpeg's `t` option is a DURATION, not an absolute end time; the
    # original passed end_frame / fps and therefore decoded too many frames.
    duration = (end_frame - start_frame) / fps
    stdout, _ = (
        ffmpeg
        .input(video_path, ss=start_frame / fps, t=duration)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24')
        .run(capture_stdout=True)
    )

    # Reinterpret the raw byte stream as a pixel array.
    video = np.frombuffer(stdout, np.uint8)
    if width and height:
        video = video.reshape((-1, height, width, 3))  # (num_frames, H, W, C)

    # Drop our reference to the raw buffer promptly; chunks can be large.
    del stdout
    gc.collect()
    return video

def find_by_codex(data, target_codex):
    """Return the first dict in *data* whose 'codec_name' equals *target_codex*.

    Entries without a 'codec_name' key (e.g. ffprobe data streams) are
    skipped instead of raising KeyError. Returns None when nothing matches.
    """
    return next((d for d in data if d.get('codec_name') == target_codex), None)

def ffmpegVideoInfo (VideoPath:Path):
    """Probe *VideoPath* with ffprobe and return (fps, (width, height), total_frames, None).

    The trailing None keeps the tuple shape identical to decordVideoInfo,
    which returns the reader object in that slot.

    Raises:
        Exception: when the file contains no h264 video stream.
    """
    from fractions import Fraction  # safe parsing of "num/den" frame rates

    video_info = ffmpeg.probe(str(VideoPath.absolute()))
    video_info = find_by_codex(video_info["streams"], "h264")
    if not video_info:
        raise Exception("No H264 Codex found")
    # avg_frame_rate looks like "30000/1001"; Fraction parses it without
    # resorting to eval() on external tool output.
    fps = float(Fraction(nested_lookup("avg_frame_rate", video_info)[0]))
    w = int(nested_lookup("width", video_info)[0])
    h = int(nested_lookup("height", video_info)[0])
    total_frames = int(nested_lookup("nb_frames", video_info)[0])

    return (fps, (w, h), total_frames, None)

def decordVideoInfo (VideoPath:Path):
    """Open *VideoPath* with decord and return (fps, None, total_frames, reader).

    The None placeholder keeps the tuple shape identical to ffmpegVideoInfo,
    which puts the (width, height) pair in that slot.
    """
    decord.bridge.set_bridge('native')
    reader = decord.VideoReader(str(VideoPath.absolute()))
    return (reader.get_avg_fps(), None, len(reader), reader)

def plotSimGraph(Name:str = "", total_frames=0, chunk_size=15, ListSim=None):
    """Plot per-chunk cosine similarity against chunk start frame.

    Overlays a Savitzky-Golay smoothed curve and marks the best-scoring
    chunk with a red cross.

    Args:
        Name: label shown in the figure title.
        total_frames: total frame count of the scanned video.
        chunk_size: stride in frames between chunk start indices.
        ListSim: per-chunk similarities; plain floats or 0-dim torch
            tensors (possibly on GPU), one per chunk start.

    Raises:
        ValueError: if ListSim is None or empty.
    """
    if ListSim is None or len(ListSim) == 0:
        raise ValueError("ListSim must contain at least one similarity value")

    chunks = range(0, total_frames, chunk_size)
    # Coerce every entry to a plain float so CUDA / 0-dim tensors are
    # handled uniformly before plotting.
    sim = torch.tensor([float(s) for s in ListSim])

    # savgol_filter requires window_length <= number of samples; fall back
    # to the raw curve when there are too few points to smooth.
    if len(sim) >= 5:
        smoothed_accuracy = savgol_filter(sim, window_length=5, polyorder=2)
    else:
        smoothed_accuracy = sim

    maxSim = sim.max()
    indexmax = int(sim.argmax())
    bestSim = chunks[indexmax]

    plt.figure(figsize=(15, 6))
    plt.plot(chunks, sim, marker='o', linestyle='-', color='b', label='Cosine similarity')
    plt.plot(chunks, smoothed_accuracy, color='g', linestyle='--', label='Smoothed Data')
    plt.title(f"frames vs. Sim. [{Name}]")
    plt.xlabel('frames')
    plt.ylabel('Cos Sim.')
    plt.grid(True)

    plt.scatter(bestSim, maxSim, zorder=5, marker="x", color='r', label=f'High Sim: {maxSim:.4f} at {bestSim} start frame')
    plt.legend()

    plt.show()


def seconds_to_hms(seconds):
    """Format a duration in seconds as an 'H:MM:SS' string.

    Hours are not zero-padded; minutes and seconds are. Fractional
    seconds are truncated toward zero.
    """
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f'{int(hours)}:{int(minutes):02}:{int(secs):02}'


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Decoding backend: "ffmpeg" (pipe-based extraction) or anything else for decord.
    parser.add_argument("--BaseLib", default="ffmpeg", type=str)
    # -i: video file to scan for the target frame.
    parser.add_argument("-i", "--Input", type=str)
    # -v: original video the matched time span is finally cut from.
    parser.add_argument("-v", "--OriginalVideo", type=str)
    parser.add_argument("-o", "--Output", type=str, default="output.mp4")
    # -t: image file containing the single frame to locate.
    parser.add_argument("-t", "--Target", type=str)
    # -c: chunk length in seconds (but see NOTE(review) where it is reassigned).
    parser.add_argument("-c", "--Chunk", type=int, default=15)
    # -b: checkpoint path for the RARP_NVB_ResNet50 base model.
    parser.add_argument("-b", "--BaseModel", type=str)

    args = parser.parse_args()

    torch.set_float32_matmul_precision('medium')
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Two copies of the same checkpoint: one is inflated into a 3D video
    # network below, the other embeds the single 2D target frame.
    RN50Model = RARP_NVB_ResNet50.load_from_checkpoint(args.BaseModel)
    RN50ModelToEval = RARP_NVB_ResNet50.load_from_checkpoint(args.BaseModel)

    # NOTE(review): mean/std are on a 0-255 scale, presumably per-channel
    # statistics of the (BGR) training data -- confirm against the training pipeline.
    mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    transforms = torch.nn.Sequential(
        T.Resize((256,256), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.Normalize(mean, std)
    ).to(device)

    # Target frame: cv2.imread yields HxWxC in BGR order; permute to CxHxW
    # for the torchvision transforms.
    frameToFind = cv2.imread(str(Path(args.Target)), cv2.IMREAD_COLOR)
    frameToFind = torch.tensor(frameToFind, device=device, dtype=torch.float32)
    frameToFind = frameToFind.permute(2, 0, 1)

    frameToFind = transforms(frameToFind)

    # repeat() with one extra leading factor of 1 acts like unsqueeze(0):
    # it adds the batch dimension, giving (1, C, H, W).
    frameToFind = frameToFind.repeat(1, 1, 1, 1)

    # Inflate the 2D ResNet into a 3D network and strip the classifier head
    # (fc -> Identity) so both models emit feature vectors instead of logits.
    InfalteModel = I3DResNet50(RN50Model.model).to(device)
    InfalteModel.fc = torch.nn.Identity()
    InfalteModel.eval()

    RN50ModelToEval.model.fc = torch.nn.Identity()
    RN50ModelToEval.to(device)
    RN50ModelToEval.eval()

    chunk_size = args.Chunk


    with torch.no_grad():
        # Embed the target frame once; squeeze to a flat feature vector.
        Doutput = RN50ModelToEval(frameToFind)
        Doutput = Doutput.squeeze()

        maxSim = 0

        ListSim = []
        maxListSim = []
        # Free the 2D model and the target tensor before decoding chunks.
        del RN50ModelToEval
        del frameToFind

        initFrame = None
        videoPathLong = Path(args.Input)

        # size is None on the decord path, but it is only consumed by the
        # ffmpeg branch inside the loop, so the two stay consistent.
        fps, size, total_frames, vr = ffmpegVideoInfo(videoPathLong) if args.BaseLib == "ffmpeg" else decordVideoInfo(videoPathLong)

        print (f"FPS:{fps}")
        segs = total_frames/fps
        print (f"Video Length: {segs} seg.")
        print (f"Video Length: {total_frames} frames.")
        print (f"Chuks of {chunk_size} seg: {int (segs//chunk_size)} chunks")

        # NOTE(review): chunk_size (seconds, from -c) is overwritten here with
        # the *number* of chunks, int(segs // chunk_size), and then used below
        # as a stride in FRAMES -- so the actual per-chunk span disagrees with
        # the printout above. Confirm whether int(chunk_size * fps) was intended.
        chunk_size = int (segs//chunk_size)

        for start_idx in tqdm(range(0, total_frames, chunk_size)):
            end_idx = min(start_idx + chunk_size, total_frames)

            chunk_frames = extract_frames_ffmpeg(str(videoPathLong.absolute()), start_idx, end_idx, width=size[0], height=size[1], fps=fps) \
                if args.BaseLib == "ffmpeg" else vr.get_batch(range(start_idx, end_idx)).asnumpy()

            # The ffmpeg pipe is rgb24 (decord presumably also RGB -- confirm);
            # reversing the channel axis gives BGR to match the cv2-loaded
            # target. copy() makes the reversed view contiguous for torch.
            chunk_frames = chunk_frames[..., ::-1].copy()

            frames = torch.from_numpy(chunk_frames).to(device)
            frames = frames.permute(0, 3, 1, 2)
            frames = frames.float()

            frames = transforms(frames)

            # Add a batch dim: (N,C,H,W) -> (1,N,C,H,W), then swap so channels
            # precede time: (1,C,N,H,W) -- the layout the inflated net consumes.
            frames = frames.repeat(1, 1, 1, 1, 1)
            frames = frames.permute(0, 2, 1, 3, 4)

            outPut = InfalteModel(frames)
            outPut = outPut.squeeze()

            # Similarity between the chunk embedding and the target embedding.
            cos_sim = torch.nn.functional.cosine_similarity(outPut, Doutput, dim=0)

            # Track the best-matching chunk (start/end frame indices) so far.
            if cos_sim > maxSim:
                print(cos_sim)
                maxSim = cos_sim
                initFrame = (start_idx, end_idx)
                maxListSim.append(cos_sim)

            del frames
            del chunk_frames
            gc.collect()

            ListSim.append(cos_sim)



    # NOTE(review): initFrame stays None if no chunk scored above 0, which
    # would raise TypeError here -- confirm that is acceptable.
    print(seconds_to_hms(initFrame[0]/fps), seconds_to_hms(initFrame[1]/fps))
    plotSimGraph(videoPathLong.name, total_frames, chunk_size, ListSim)
    trim_output(str(Path(args.OriginalVideo).absolute()), args.Output, seconds_to_hms(initFrame[0]/fps), seconds_to_hms(initFrame[1]/fps))