import cv2
import torch
import torchvision.transforms as T
import numpy as np
from pathlib import Path
import time
from tqdm import tqdm
import argparse
import sys
# Ensure the project root is importable; print its position if already present.
# BUG FIX: the original bare `except:` swallowed *every* exception (including
# KeyboardInterrupt); only ValueError — raised by list.index on a miss — is expected.
try:
    print(sys.path.index("D:\\Users\\user\\Downloads\\Research"))
except ValueError:
    # Not on sys.path yet -> append it so `Models` can be imported below.
    sys.path.append("D:\\Users\\user\\Downloads\\Research")
from Models import RARP_NVB_ResNet50
def loadTensor(img, device, transform):
    """Convert an HWC image array (as read by cv2) into a batched model input.

    Args:
        img: numpy array of shape (H, W, C); BGR uint8 from cv2 in this script.
        device: torch device to place the tensor on.
        transform: callable applied to the CHW float tensor (e.g. normalization).

    Returns:
        torch.Tensor of shape (1, C, H, W), ready for a batched forward pass.
    """
    imgTensor = torch.from_numpy(img).to(device).float()
    imgTensor = imgTensor.permute(2, 0, 1)  # HWC -> CHW
    imgTensor = transform(imgTensor)
    # IDIOM FIX: the original used repeat(1, 1, 1, 1) to grow a leading batch
    # dimension; unsqueeze(0) states that intent directly (same result, no copy).
    return imgTensor.unsqueeze(0)
def seconds_to_hms(seconds):
    """Format a duration given in seconds as an H:MM:SS string."""
    # divmod peels minutes off the seconds, then hours off the minutes.
    total_minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f'{int(hours)}:{int(minutes):02}:{int(secs):02}'
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Find the video frame most similar to a target image via ResNet-50 embeddings."
    )
    parser.add_argument("-i", "--Input", type=str, help="Path to the input video.")
    parser.add_argument("-o", "--Output", type=str, default="output.mp4",
                        help="Output path (currently unused; kept for CLI compatibility).")
    parser.add_argument("-t", "--Target", type=str, help="Path to the target image.")
    args = parser.parse_args()

    torch.set_float32_matmul_precision('medium')
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Per-channel normalization; stats presumably computed on the training set
    # in BGR channel order (cv2 default) — TODO confirm against training code.
    transforms = torch.nn.Sequential(
        T.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)

    # Load the checkpoint and strip the final FC layer so the forward pass
    # yields feature embeddings instead of class logits.
    RN50Model = RARP_NVB_ResNet50.load_from_checkpoint(Path("D:/Users/user/Downloads/Research/log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt")).to(device)
    RN50Model.model.fc = torch.nn.Identity()
    RN50Model.eval()

    start_time = time.time()
    with torch.no_grad():
        # Embed the target image once up front.
        img = cv2.imread(str(Path(args.Target)), cv2.IMREAD_COLOR)
        img = cv2.resize(img, (224, 224))
        img = loadTensor(img, device, transforms)
        img = RN50Model(img).squeeze()

        cap = cv2.VideoCapture(str(Path(args.Input)))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Video FPS: {video_fps:.2f}, Total frames: {total_frames}, Video length: {seconds_to_hms(total_frames / video_fps)}")

        # BUG FIX: the original computed the sampling stride from `fps` — the
        # instantaneous *processing* throughput, which is undefined on the first
        # iteration (the NameError was silently eaten by `except Exception`),
        # jitters with machine load afterwards, and can truncate to 0 (silent
        # ZeroDivisionError per frame). Derive a constant stride from the
        # video's own frame rate instead, guarded to at least 1.
        frame_step = max(1, int(video_fps * 0.10))

        pbar = tqdm(
            total=total_frames,
            bar_format="{desc}",  # only render the description
            desc="aFPS: 0.00"     # initial text
        )
        prev = time.perf_counter()
        best_match = 0.0
        frame_counter = 0
        best_match_frame = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_counter += 1
            try:
                # Sample roughly every tenth of a second of video.
                if frame_counter % frame_step != 0:
                    continue
                frame = cv2.resize(frame, (224, 224))
                frame = loadTensor(frame, device, transforms)
                frame = RN50Model(frame).squeeze()
                # Higher cosine similarity == closer match to the target embedding.
                similarity = torch.nn.functional.cosine_similarity(img, frame, dim=0).item()
                if similarity > best_match:
                    best_match = similarity
                    best_match_frame = frame_counter
            except Exception:
                # Best-effort scan: skip frames that fail to decode or embed.
                continue
            finally:
                # Throughput/progress bookkeeping runs for every frame, sampled or not.
                now = time.perf_counter()
                fps = 1.0 / (now - prev)
                prev = now
                pbar.set_description(
                    f"aFPS: {fps:.2f}, timestamp {seconds_to_hms(frame_counter / video_fps)}; "
                    f"best Match. {best_match_frame} / {best_match:.2f}; "
                    f"Aprox timestamp {seconds_to_hms(best_match_frame / video_fps)}"
                )
                pbar.update(1)
        pbar.close()
        cap.release()
    end_time = time.time()
    print(f"{seconds_to_hms(end_time - start_time)} seconds")