# RARP / Video3D / extractFrame_DL.py
# Author: delAguila (20 May) — extract the video frame that best matches a
# target image using deep-learning (ResNet-50) feature embeddings.
import cv2
import torch
import torchvision.transforms as T
import numpy as np
from pathlib import Path
import time
from tqdm import tqdm
import argparse
import sys

# Make the local research package importable. `list.index` raises ValueError
# (not a bare catch-all) when the path is absent, in which case we append it.
_RESEARCH_DIR = "D:\\Users\\user\\Downloads\\Research"
try:
    print(sys.path.index(_RESEARCH_DIR))
except ValueError:
    sys.path.append(_RESEARCH_DIR)

from Models import RARP_NVB_ResNet50

def loadTensor(img, device, transform):
    """Convert an HxWxC image array into a normalized 4-D batch tensor.

    Parameters
    ----------
    img : numpy.ndarray
        Image in HWC layout (as produced by cv2.imread / cv2.resize).
    device : torch.device
        Device the tensor is moved to before the transform runs.
    transform : callable
        Transform applied to the CHW float tensor (e.g. normalization).

    Returns
    -------
    torch.Tensor
        Tensor of shape (1, C, H, W) on `device`.
    """
    imgTensor = torch.from_numpy(img).to(device).float()
    imgTensor = imgTensor.permute(2, 0, 1)  # HWC -> CHW
    imgTensor = transform(imgTensor)
    # unsqueeze(0) adds the batch axis as a view; clearer and cheaper than
    # the original repeat(1, 1, 1, 1), which copied the data.
    return imgTensor.unsqueeze(0)

def seconds_to_hms(seconds):
    """Format a duration given in seconds as an H:MM:SS string."""
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f'{int(hours)}:{int(minutes):02}:{int(secs):02}'

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Scan a video for the frame whose ResNet-50 embedding best "
                    "matches a target image (cosine similarity)."
    )

    parser.add_argument("-i", "--Input", type=str, help="Path to the input video")
    parser.add_argument("-o", "--Output", type=str, default="output.mp4",
                        help="Output path (currently unused)")
    parser.add_argument("-t", "--Target", type=str, help="Path to the target image")

    args = parser.parse_args()

    torch.set_float32_matmul_precision('medium')
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Channel-wise mean/std normalization (values match the model's training
    # statistics; images stay in OpenCV's BGR channel order, 0-255 scale).
    transforms = torch.nn.Sequential(
        #T.Resize((256,256), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
        #T.CenterCrop(224),
        T.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)

    RN50Model = RARP_NVB_ResNet50.load_from_checkpoint(Path("D:/Users/user/Downloads/Research/log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt")).to(device)
    RN50Model.model.fc = torch.nn.Identity()  # drop classifier head -> raw feature embeddings
    RN50Model.eval()

    start_time = time.time()
    with torch.no_grad():
        # Embed the target image once, up front.
        img = cv2.imread(str(Path(args.Target)), cv2.IMREAD_COLOR)
        if img is None:
            sys.exit(f"Could not read target image: {args.Target}")
        img = cv2.resize(img, (224, 224))
        img = loadTensor(img, device, transforms)
        img = RN50Model(img).squeeze()

        cap = cv2.VideoCapture(str(Path(args.Input)))
        if not cap.isOpened():
            sys.exit(f"Could not open video: {args.Input}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Video FPS: {video_fps:.2f}, Total frames: {total_frames}, Video length: {seconds_to_hms(total_frames / video_fps)}")

        # BUG FIX: the stride previously used the *processing* FPS measured in
        # the loop ("fps") — undefined on the first frame (NameError) and often
        # < 10, making int(fps*0.10) == 0 and raising ZeroDivisionError in the
        # modulus; both were silently eaten by a broad except.  The intent is
        # to sample one frame per ~0.1 s of *video* time, so compute the
        # stride once from video_fps and clamp it to at least 1.
        sample_stride = max(1, int(video_fps * 0.10))

        pbar = tqdm(
            total=total_frames,
            bar_format="{desc}",      # only render the description
            desc="aFPS: 0.00"          # initial text
        )

        prev = time.perf_counter()
        best_match = 0
        frame_counter = 0
        best_match_frame = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_counter += 1

            # Only embed every `sample_stride`-th frame; all other frames are
            # decoded and counted but skipped.
            if frame_counter % sample_stride == 0:
                frame = cv2.resize(frame, (224, 224))
                frame = loadTensor(frame, device, transforms)
                frame = RN50Model(frame).squeeze()
                similarity = torch.nn.functional.cosine_similarity(img, frame, dim=0)

                if similarity > best_match:
                    best_match = similarity
                    best_match_frame = frame_counter

            # Progress/rate bookkeeping runs for every decoded frame
            # (matching the original's finally-block behavior).
            now = time.perf_counter()
            proc_fps = 1.0 / (now - prev)
            prev = now

            pbar.set_description(f"aFPS: {proc_fps:.2f}, timestamp {seconds_to_hms(frame_counter / video_fps)}; best Match. {best_match_frame} / {best_match:.2f}; Aprox timestamp {seconds_to_hms(best_match_frame / video_fps)}")
            pbar.update(1)

        pbar.close()
        cap.release()

    end_time = time.time()
    print(f"{seconds_to_hms(end_time - start_time)} seconds")