# Newer
# Older
# RARP / Video3D / AutoSnippet.py
# @delAguila delAguila on 22 Nov 2024 4 KB init commit
import decord
from pathlib import Path
import torch
import cv2
import numpy as np
from inflateDEMO import I3DResNet50
import torchvision.transforms as T
import sys
from tqdm import tqdm
import argparse
import ffmpeg

# Make the sibling "Clasification" project importable so that its ``Models``
# module (which provides RARP_NVB_ResNet50) can be imported below.
# NOTE(review): this is a machine-specific absolute path; consider moving it to
# configuration or an environment variable.
_CLASSIFICATION_DIR = "d:\\Users\\user\\Documents\\postata\\RARP\\Clasification"
# Membership test instead of ``sys.path.index`` + bare ``except:``: same effect
# (append only when missing) without swallowing unrelated exceptions.
if _CLASSIFICATION_DIR not in sys.path:
    sys.path.append(_CLASSIFICATION_DIR)

from Models import RARP_NVB_ResNet50

def seconds_to_hms(seconds):
    """Format a duration in seconds as ``H:MM:SS``.

    Fractional parts are truncated (not rounded). Hours are not zero-padded,
    and minutes/seconds are padded to two digits.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted string, e.g. ``3661 -> "1:01:01"``.
    """
    hh = int(seconds // 3600)
    mm = int(seconds % 3600 // 60)  # minutes left after removing whole hours
    ss = int(seconds % 60)
    return "{}:{:02}:{:02}".format(hh, mm, ss)

# Chunk lengths (in seconds) whose counts are printed for information only.
_REPORTED_CHUNK_SECONDS = (15, 30, 35, 40)


def _parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``BaseLib``, ``Input``, ``Output``, ``Target``,
        ``Chunk`` and ``BaseModel`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Find the fixed-length chunk of a video whose embedding is most "
                    "similar to a target frame and export it as a separate clip."
    )
    # NOTE(review): --BaseLib is kept for CLI compatibility, but nothing reads it.
    parser.add_argument("--BaseLib", default="ffmpeg", type=str)
    parser.add_argument("-i", "--Input", type=str, required=True,
                        help="Path of the long input video.")
    parser.add_argument("-o", "--Output", type=str, required=True,
                        help="Path of the exported clip.")
    parser.add_argument("-t", "--Target", type=str, required=True,
                        help="Image of the frame to look for.")
    parser.add_argument("-c", "--Chunk", type=int, default=15,
                        help="Length of each searched chunk, in seconds.")
    parser.add_argument("-b", "--BaseModel", type=str,
                        default="../log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt",
                        help="Checkpoint of the 2D RARP_NVB_ResNet50 model.")
    args = parser.parse_args(argv)
    if args.Chunk <= 0:
        parser.error("--Chunk must be a positive number of seconds")
    return args


def _build_transforms():
    """Return the resize/crop/normalize pipeline shared by the target frame and the video frames."""
    # Per-channel statistics on the 0-255 scale; they are applied to tensors whose
    # channels are in BGR order (cv2 image / channel-reversed decord frames).
    mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    return T.Compose([
        T.Resize((256, 256), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.Normalize(mean, std),
    ])


def _load_target(path, transforms):
    """Read the target image with OpenCV and return it as a ``(1, 3, 224, 224)`` float tensor.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    image = cv2.imread(str(Path(path)), cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports unreadable/missing files by returning None, not by raising.
        raise FileNotFoundError(f"Could not read target image: {path}")
    tensor = torch.from_numpy(image).permute(2, 0, 1).float()  # HWC -> CHW
    return transforms(tensor).unsqueeze(0)  # add the batch dimension


def _report_video(vr, fps):
    """Print the frame rate, the duration, and how many chunks of several lengths fit in the video."""
    segs = len(vr) / fps
    print(f"FPS:{fps}")
    print(f"Video Length: {segs} seg.")
    print(f"Video Length: {round(segs * fps)} frames.")
    for seconds in _REPORTED_CHUNK_SECONDS:
        print(f"Chunks of {seconds} seg: {segs // seconds} chunks")


def _find_best_chunk(vr, chunk_size, inflated_model, target_embedding, transforms, device):
    """Score consecutive chunks of ``chunk_size`` frames against ``target_embedding``.

    Must be called under ``torch.no_grad()``.

    Returns:
        ``(frame_range, best_similarity, similarities)``. ``frame_range`` is the
        ``(start, end)`` frame pair (end exclusive) of the most similar chunk, or
        ``None`` if the video has no frames. ``similarities`` holds one cosine
        similarity per chunk, in video order.
    """
    best_sim = float("-inf")  # any real similarity (even negative) beats this
    best_range = None
    similarities = []
    total_frames = len(vr)

    for start_idx in tqdm(range(0, total_frames, chunk_size)):
        # NOTE(review): the last chunk may be much shorter than chunk_size; confirm
        # that the inflated model accepts very short clips.
        end_idx = min(start_idx + chunk_size, total_frames)
        # Reverse the channel axis (decord decodes RGB) so the clip is BGR like the cv2 target.
        chunk_frames = vr.get_batch(range(start_idx, end_idx)).asnumpy()[..., ::-1].copy()

        frames = torch.from_numpy(chunk_frames).to(device).permute(0, 3, 1, 2).float()
        frames = transforms(frames)
        clip = frames.permute(1, 0, 2, 3).unsqueeze(0)  # (T, C, H, W) -> (1, C, T, H, W)

        embedding = inflated_model(clip).squeeze()
        cos_sim = torch.nn.functional.cosine_similarity(embedding, target_embedding, dim=0).item()
        similarities.append(cos_sim)

        if cos_sim > best_sim:
            print(cos_sim)
            best_sim = cos_sim
            best_range = (start_idx, end_idx)

    return best_range, best_sim, similarities


def _export_clip(video_path, output_path, frame_range, use_cuda):
    """Cut ``frame_range`` (end frame exclusive) out of ``video_path`` with ffmpeg and write ``output_path``."""
    start_frame, end_frame = frame_range
    # Only ask ffmpeg for CUDA decoding when CUDA is actually available to this script.
    input_kwargs = {"hwaccel": "cuda"} if use_cuda else {}
    source = ffmpeg.input(str(video_path), **input_kwargs)
    # trim keeps the source timestamps; setpts rebases them so the clip starts at t=0.
    clip = source.trim(start_frame=start_frame, end_frame=end_frame).setpts("PTS-STARTPTS")
    ffmpeg.run(ffmpeg.output(clip, output_path))  # , vcodec='h264_nvenc'


def main(argv=None):
    """Find the chunk of ``--Input`` most similar to ``--Target`` and export it to ``--Output``."""
    args = _parse_args(argv)

    # Two separate copies of the checkpoint: one is inflated to 3D, the other stays
    # 2D and has its head replaced (presumably so that neither change leaks into the other).
    rn50_for_inflation = RARP_NVB_ResNet50.load_from_checkpoint(args.BaseModel)
    rn50_for_target = RARP_NVB_ResNet50.load_from_checkpoint(args.BaseModel)

    video_path = Path(args.Input).absolute()
    decord.bridge.set_bridge('native')
    vr = decord.VideoReader(str(video_path))
    fps = vr.get_avg_fps()
    _report_video(vr, fps)

    transforms = _build_transforms()
    target = _load_target(args.Target, transforms)

    torch.set_float32_matmul_precision('medium')
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Replacing fc with Identity makes both networks return features instead of logits.
    inflated_model = I3DResNet50(rn50_for_inflation.model).to(device)
    inflated_model.fc = torch.nn.Identity()
    inflated_model.eval()

    rn50_for_target.model.fc = torch.nn.Identity()
    rn50_for_target.to(device)
    rn50_for_target.eval()

    # Honour --Chunk (previously ignored); max() guards against round(fps) == 0.
    chunk_size = max(1, args.Chunk * round(fps))

    with torch.no_grad():
        target_embedding = rn50_for_target(target.to(device)).squeeze()
        frame_range, _best_sim, _similarities = _find_best_chunk(
            vr, chunk_size, inflated_model, target_embedding, transforms, device
        )

    if frame_range is None:
        raise SystemExit(f"No chunk could be scored; is the video empty? ({video_path})")

    print(seconds_to_hms(frame_range[0] / fps), seconds_to_hms(frame_range[1] / fps))
    _export_clip(video_path, args.Output, frame_range, device.type == "cuda")


if __name__ == "__main__":
    main()