import cv2
import numpy as np
from pathlib import Path
import time
from tqdm import tqdm
import argparse
import pandas as pd

# Global OpenCV perceptual-hash engine (DCT-based pHash; 8 bytes per image).
P_HASH = cv2.img_hash.PHash_create()
# Lookup table: POP_COUNT[b] = number of set bits in byte b, used to turn a
# XOR of hash bytes into a Hamming distance without per-bit loops.
POP_COUNT = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

def p_hash_bytes(Gray_img):
    """Return the pHash of a grayscale image as a flat 1-D uint8 array (a copy)."""
    # flatten() always returns a copy, matching reshape(-1).copy().
    return P_HASH.compute(Gray_img).flatten()

def grid_phash(gray_img, grid=(3, 3)):
    """Compute a concatenated pHash over a grid of tiles of *gray_img*.

    The image is split into grid[0] rows x grid[1] columns of equal tiles
    (any remainder pixels on the right/bottom edges are dropped) and each
    tile's pHash bytes are concatenated into one long hash vector.

    Args:
        gray_img: 2-D grayscale image (H, W).
        grid: (rows, cols) tile layout.

    Returns:
        1-D uint8 array of length rows * cols * 8.
    """
    gh, gw = grid
    H, W = gray_img.shape
    th, tw = H // gh, W // gw
    parts = []

    for r in range(gh):
        # BUG FIX: inner loop previously iterated range(gh) (row count),
        # which is wrong for non-square grids; columns must use gw.
        for c in range(gw):
            tile = gray_img[r*th:(r+1)*th, c*tw:(c+1)*tw]
            parts.append(p_hash_bytes(tile))

    return np.concatenate(parts, axis=0)

def hamming(a, b):
    """Hamming distance (number of differing bits) between two byte hash arrays."""
    diff_bytes = np.bitwise_xor(a, b)
    bit_counts = POP_COUNT[diff_bytes]
    return int(np.sum(bit_counts))

def mse(img1, img2):
    """Mean squared error between two equally-shaped images.

    BUG FIX: the original used cv2.subtract, which saturates uint8 results
    at 0 — every pixel where img2 > img1 contributed zero error, so the
    score depended on argument order and underestimated the difference.
    Differences are now computed in float64, which is symmetric and exact.

    Args:
        img1, img2: arrays of identical shape (any dimensionality).

    Returns:
        The mean of the squared per-element differences, as a float.
    """
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    return float(np.mean(diff ** 2))

def centerCrop(img: np.ndarray, size: tuple = None, pct_size: float = 0.3):
    """Crop a centered rectangle out of *img*.

    FIX: the original annotated `img: np` — that is the numpy *module*,
    not a type; corrected to np.ndarray (behavior is unchanged).

    Args:
        img: 2-D (or HxWxC) image array.
        size: (height, width) of the crop. If None, the crop keeps
            (1 - pct_size) of each dimension (default: keep 70%).
        pct_size: fraction of each dimension to remove when size is None.

    Returns:
        A view of img containing the centered size[0] x size[1] region.
    """
    if size is None:
        size = int(img.shape[0] * (1 - pct_size)), int(img.shape[1] * (1 - pct_size))
    # Top-left corner so the crop is centered (integer division biases
    # toward the top-left for odd remainders).
    x, y = (img.shape[1] - size[1]) // 2, (img.shape[0] - size[0]) // 2

    return img[y:y+size[0], x:x+size[1]]

def removeBlackBorder(image):
   """Crop black letterbox borders off a BGR image.

   Masks pixels by HSV hue, keeps the largest contour of the remaining
   content and crops the *original* image to its bounding rectangle.

   Returns:
       (crop, (x, y, w, h)): the cropped image plus the bounding rect,
       so the same rect can be re-applied to subsequent video frames.
   """
   copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)
   h = copyImg[:,:,0]
   # Mask starts fully on; pixels whose hue lies inside (25, 175) are
   # zeroed out.  NOTE(review): this keys on hue only, not brightness —
   # presumably tuned for this footage's border color; confirm it holds.
   mask = np.ones(h.shape, dtype=np.uint8) * 255
   th = (25, 175)
   mask[(h > th[0]) & (h < th[1])] = 0
   copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
   resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
      
   image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
   # Threshold at 0: any non-black pixel becomes 255.
   _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
   # Morphological opening with a 15x15 rect removes small speckles before
   # contour extraction.
   kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
   morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
   # findContours returns 2 or 3 values depending on the OpenCV version;
   # pick the contour list either way.
   contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
   contours = contours[0] if len(contours) == 2 else contours[1]
   # Raises ValueError if no contour was found (fully masked image).
   bigCont = max(contours, key=cv2.contourArea)
   x, y, w, h = cv2.boundingRect(bigCont)
   crop = image[y : y + h, x : x + w]
   return crop, (x, y, w, h)

def normalization(img, mean = 0.458971, std = 0.225609):
    """Grayscale *img*, scale to [0, 1], standardize with (mean, std), then blur.

    Returns a float32 single-channel image smoothed with a 5x5 Gaussian
    (sigma = 1).  The input BGR image is not modified.
    """
    gray = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2GRAY)
    scaled = gray.astype(np.float32) / 255.0
    standardized = (scaled - mean) / std
    return cv2.GaussianBlur(standardized, (5, 5), 1)

def seconds_to_hms(seconds):
    """Format a (non-negative) number of seconds as 'H:MM:SS'."""
    total = int(seconds)
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f'{hours}:{minutes:02}:{secs:02}'



def fft_cross_correlation(img1, img2):
    """Cross-correlate two same-sized single-channel images via the FFT.

    Computes IDFT(F1 * conj(F2)) after zero-padding both inputs to an
    optimal DFT size of at least twice each dimension, so the circular
    correlation equals the linear one (no wrap-around).

    Returns:
        (corr, sim): the fftshifted correlation surface, and its peak
        normalized by the product of the inputs' L2 norms (<= 1 by
        Cauchy-Schwarz; ~1 means near-identical up to a shift).
    """
    h, w = img1.shape
    
    # getOptimalDFTSize picks a fast (factorable) FFT length >= 2*dim.
    H, W = cv2.getOptimalDFTSize(2*h), cv2.getOptimalDFTSize(2*w)
    pad1 = np.zeros((H, W), dtype=np.float32)
    pad2 = np.zeros((H, W), dtype=np.float32)
    pad1[:h, :w] = img1
    pad2[:h, :w] = img2
    
    dtf1 = cv2.dft(pad1, flags=cv2.DFT_COMPLEX_OUTPUT)
    dtf2 = cv2.dft(pad2, flags=cv2.DFT_COMPLEX_OUTPUT)
    
    # conjB=True makes this F1 * conj(F2): correlation, not convolution.
    cross_power = cv2.mulSpectrums(dtf1, dtf2, 0, conjB=True)
    
    corr = cv2.idft(cross_power, flags=cv2.DFT_SCALE | cv2.DFT_REAL_OUTPUT)
    
    # Center the zero-shift peak for easier inspection of the surface.
    corr = np.fft.fftshift(corr)
    
    max_val = corr.max()
    n_img1 = np.linalg.norm(img1)
    n_img2 = np.linalg.norm(img2)
    sim = max_val / (n_img1 * n_img2)
    
    return corr, sim

def find_frame_in_video(Target, Input, Patience:int = 60, Remove_Blackbar:bool = False, Forze_resize:bool = False, Sim_Check:bool = True):
    """Scan a video for the frame most similar to a target image.

    Samples roughly one frame per second (step = round(fps)), compares an
    8x8 grid pHash of each sampled frame against the target's hash via
    Hamming distance (lower = more similar) and tracks the best match.

    Args:
        Target: path to the query image file.
        Input: path to the video file.
        Patience: number of sampled frames tolerated past the similarity
            threshold before stopping early (None disables early stop).
        Remove_Blackbar: crop black borders from image and frames; the
            crop rect is computed once on frame 0 and reused.
        Forze_resize: force both image and frames to 512x512.
        Sim_Check: unused.  # NOTE(review): dead parameter — confirm intent.

    Returns:
        (best_match_seconds, 0, best_hamming_distance) — the middle value
        is a constant placeholder for an error metric.
    """
    Target = Path(Target)
    Input = Path(Input)
    
    start_time = time.time()
    img = cv2.imread(str(Target), cv2.IMREAD_COLOR)
    if Remove_Blackbar:
        img, _ = removeBlackBorder(img)
    #img = normalization(img)  
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    H, W = img.shape
    
    # Query signature: pHash of each tile in an 8x8 grid, concatenated.
    query_img_hash = grid_phash(img, (8, 8))

    cap = cv2.VideoCapture(str(Input))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    print(f"{Input.name} Video FPS: {video_fps:.2f}, Total frames: {total_frames}, Video length: {seconds_to_hms(total_frames / video_fps)}")

    # Progress bar renders only its description text (custom status line).
    pbar = tqdm(
        total=total_frames,
        bar_format="{desc}",      # only render the description
        desc="aFPS: 0.00"          # initial text
    )
    
    # History of timestamps (seconds) at which the best match improved;
    # the last entry is the final answer.
    best_match_segs = []
    
    prev = time.perf_counter()
    best_match = 10000          # sentinel: larger than any possible distance
    frameCouter = 0
    best_match_frame = 0
    ROI_rect = None             # black-border crop rect, set on frame 0
    resizeFrame = False         # True when frames (not the image) get resized
    # NOTE(review): sim is an integer Hamming distance, so `sim >= 0.90`
    # holds for any imperfect match — the patience counter advances on
    # nearly every sampled frame.  Confirm the intended threshold scale.
    Max_Sim = 0.90
    early_Stop_count = 0
    sim = 0
    sim_str = ""
    prev_sim = 0
    while cap.isOpened():
        if early_Stop_count == Patience:
            break
        
        # Seek explicitly to the next sampled frame (~1 frame per second).
        cap.set(cv2.CAP_PROP_POS_FRAMES, frameCouter)
        ret, frame = cap.read()
        if not ret:
            break
        
        try:
            if (frameCouter == 0):
                # One-time setup: crop rect and size reconciliation between
                # the query image and the video frames.
                if Remove_Blackbar:
                    frame, ROI_rect = removeBlackBorder(frame)
                img_size = img.shape[0] * img.shape[1]
                frame_size = frame.shape[0] * frame.shape[1]
                
                # Shrink the larger side so both are compared at one size.
                if (img_size > frame_size):
                    img = cv2.resize(img, (frame.shape[1], frame.shape[0]) if not Forze_resize else (512, 512), interpolation=cv2.INTER_CUBIC)
                else:
                    resizeFrame = True
                    
            if Remove_Blackbar and frameCouter > 0:
                # Reuse the frame-0 crop rect for every later frame.
                x, y, w, h = ROI_rect
                frame = frame[y : y + h, x : x + w]
                
            if resizeFrame or Forze_resize:
                frame = cv2.resize(frame, (W, H) if not Forze_resize else (512, 512))
                
            #frame = normalization(frame)  
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frame_hash = grid_phash(frame, (8, 8))
            sim = hamming(query_img_hash, frame_hash)           
            #sim = cv2.matchTemplate(img, frame, cv2.TM_CCORR_NORMED)[0][0]
            #sim = cv2.norm(img, frame, cv2.NORM_L2)
            
            if sim < best_match:
                best_match = sim
                best_match_frame = frameCouter
                best_match_segs.append(best_match_frame / video_fps)
                
                    
            early_Stop_count += 1 if (Patience is not None) and (sim >= Max_Sim) else 0
        except Exception as e:
            # NOTE(review): broad swallow — a decode/crop failure on one
            # frame is skipped silently; the finally block still advances.
            continue
        finally:
            # Advance one second of video and refresh the status line.
            frameCouter += round(video_fps)
            now = time.perf_counter()
            fps = 1.0 / (now - prev)
            prev = now
            
            if frameCouter > total_frames:
                frameCouter = total_frames
            
            pbar.set_description(f"Progress: {(frameCouter * 100 / total_frames):.2f}% {frameCouter} / {total_frames}, Time: {seconds_to_hms(frameCouter / video_fps)} FPS: {fps:.2f} best match {best_match:.6f}/{sim:.6f} Aprox timestamp {seconds_to_hms(best_match_frame / video_fps)}")
            pbar.update(1)
        
    pbar.close()
    cap.release()
    
    end_time = time.time()
    print(f"{seconds_to_hms(end_time - start_time)} seconds")
        
    # Assumes at least one frame matched; raises IndexError otherwise.
    return best_match_segs[-1], 0, best_match

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument("-f", "--File", type=str, required=True)
    parser.add_argument("-o", "--Output", type=str, default="output.csv")
    # BUG FIX: `type=bool` is broken in argparse — bool("False") is True and
    # any non-empty string parses as True.  Plain presence flags are the
    # correct form; defaults remain False as before.
    parser.add_argument("--Sim_Check", action="store_true")
    parser.add_argument("-p", "--Patience", type=int, default=None, help="Seconds analyzed after the threshold is reached")
    parser.add_argument("-br", "--Remove_Blackbar", action="store_true")
    parser.add_argument("-r", "--Forze_resize", action="store_true")

    args = parser.parse_args()

    outputFile = Path(args.Output)

    # Input CSV: one row per case with 'case', 'video_path', 'image_path'.
    file_csv = pd.read_csv(Path(args.File))

    if outputFile.exists():
        # Resume: append to the existing results file.
        output_file = pd.read_csv(outputFile)
    else:
        # BUG FIX: start from an empty frame — the previous dummy row of
        # zeros was written into every fresh output CSV.
        output_file = pd.DataFrame(
            columns=["case", "video_path", "image_path",
                     "best_frame_seg", "error", "sim"]
        )

    for i, row in file_csv.iterrows():
        if pd.isnull(row["image_path"]):
            print(f"case {row['case']} skipped")
            continue

        print(f"case {row['case']}")
        # Match .mp4 in any case combination, recursively.
        for videos in sorted(Path(row["video_path"]).glob("**/*.[mM][pP]4")):
            try:
                res = find_frame_in_video(row["image_path"], videos, args.Patience, args.Remove_Blackbar, args.Forze_resize, args.Sim_Check)

                new_row = pd.DataFrame({
                    "case": [row["case"]],
                    "video_path": [videos],
                    "image_path": [row["image_path"]],
                    "best_frame_seg": [res[0]],
                    "error": [res[1]],
                    "sim": [res[2]]
                })

                output_file = pd.concat([output_file, new_row], ignore_index=True)
            except Exception as e:
                # Best-effort: report and keep going with the next video.
                print(e)
                continue

        # Persist after each case so progress survives a crash.
        output_file.to_csv(outputFile, index=False)
            
            
        

    