import argparse
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt, strip_optimizer
from utils.plots import colors, output_to_keypoint, plot_one_box_kpt
from utils.torch_utils import select_device
def _draw_poses(detections, canvas, names, hide_labels, hide_conf, line_thickness):
    """Draw every detected pose (box, label and keypoints) onto ``canvas`` in place.

    Args:
        detections: Output of ``non_max_suppression_kpt``, one tensor per image.
            Each row is sliced as ``[x1, y1, x2, y2, conf, cls, *keypoints]``.
        canvas: BGR uint8 image passed to ``plot_one_box_kpt`` for drawing.
        names: Class-name lookup indexed by integer class id.
        hide_labels: If True, boxes are drawn without a text label.
        hide_conf: If True, the label is the class name only (no confidence).
        line_thickness: Box line thickness in pixels.
    """
    for pose in detections:  # detections for one image of the batch
        if not len(pose):  # nothing detected in this image
            continue
        for c in pose[:, 5].unique():
            n = (pose[:, 5] == c).sum()  # detections of this class
            print("No of Objects in Current Frame : {}".format(n))
        # Iterate whole rows so every box is drawn with its *own* keypoints.
        # The previous enumerate(reversed(...)) + pose[det_index, 6:] paired
        # box i from the end with keypoints i from the start.
        for det in reversed(pose):
            *xyxy, conf, cls = det[:6]
            kpts = det[6:]
            c = int(cls)
            if hide_labels:
                label = None
            elif hide_conf:
                label = names[c]
            else:
                label = f"{names[c]} {conf:.2f}"
            plot_one_box_kpt(
                xyxy,
                canvas,
                label=label,
                color=colors(c, True),
                line_thickness=line_thickness,
                kpt_label=True,
                kpts=kpts,
                steps=3,
                orig_shape=canvas.shape[:2],
            )


def run(
    poseweights="yolov7-w6-pose.pt",
    source="football1.mp4",
    device="cpu",
    view_img=True,
    save_conf=False,
    line_thickness=3,
    hide_labels=False,
    hide_conf=True,
):
    """Run YOLOv7 pose estimation on a video or webcam and save an annotated copy.

    Each frame is handled in the same way:
    - it is converted to RGB, letterboxed (stride 64) and passed through the model;
    - the predictions are filtered with keypoint-aware NMS (conf 0.25, IoU 0.65);
    - boxes and keypoints are drawn on the frame;
    - the frame is written to ``f"{source}_keypoint.mp4"``.
    Per-frame processing time and FPS are collected and plotted with
    ``plot_fps_time_comparision`` once the stream ends.

    Args:
        poseweights: Weights path (or list of paths) forwarded to ``attempt_load``.
        source: Video path, or a numeric string such as ``"0"`` for a webcam index.
        device: Device string forwarded to ``select_device`` (e.g. ``"cpu"``, ``"0"``).
        view_img: Show annotated frames in an OpenCV window while processing.
        save_conf: Accepted for CLI compatibility; not used here.
        line_thickness: Bounding-box line thickness in pixels.
        hide_labels: Draw boxes without text labels.
        hide_conf: Omit the confidence value from labels.

    Raises:
        SystemExit: If the video source cannot be opened.
    """
    frame_count = 0  # number of frames processed
    total_fps = 0  # running sum of per-frame FPS, used for the average
    time_list = []  # per-frame processing time in seconds
    fps_list = []  # per-frame FPS

    # Use the function's own arguments, not a module-level ``opt``.
    device = select_device(device)
    model = attempt_load(poseweights, map_location=device)
    model.eval()
    names = model.module.names if hasattr(model, "module") else model.names

    # A purely numeric source such as "0" selects a webcam by index.
    cap = cv2.VideoCapture(int(source) if source.isnumeric() else source)
    if not cap.isOpened():
        print("Error while trying to read video. Please check path again")
        raise SystemExit()

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    # Keep the source frame rate so the output plays at the original speed.
    # Some backends (e.g. webcams) report 0, so fall back to the old 30 fps.
    src_fps = cap.get(cv2.CAP_PROP_FPS)
    out_fps = src_fps if src_fps > 0 else 30
    to_tensor = transforms.ToTensor()
    # The writer is created on the first frame, once the letterboxed size is
    # known. No frame is consumed just for sizing.
    out = None

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of stream or read failure
                break
            print("Frame {} Processing".format(frame_count + 1))

            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            rgb = letterbox(rgb, frame_width, stride=64, auto=True)[0]
            batch = to_tensor(rgb).unsqueeze(0).to(device).float()

            start_time = time.perf_counter()
            with torch.no_grad():
                output_data, _ = model(batch)
                output_data = non_max_suppression_kpt(
                    output_data,
                    0.25,  # confidence threshold
                    0.65,  # IoU threshold
                    nc=model.yaml["nc"],  # number of classes
                    nkpt=model.yaml["nkpt"],  # number of keypoints
                    kpt_label=True,
                )
            # Draw on the uint8 letterboxed frame directly. Rebuilding it from
            # the normalized tensor (x / 255 * 255 -> uint8) can truncate values.
            im0 = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            _draw_poses(output_data, im0, names, hide_labels, hide_conf, line_thickness)
            elapsed = time.perf_counter() - start_time

            fps = 1 / max(elapsed, 1e-9)
            total_fps += fps
            frame_count += 1
            fps_list.append(fps)  # per-frame FPS, paired with time_list
            time_list.append(elapsed)

            if out is None:
                out_height, out_width = im0.shape[:2]
                out = cv2.VideoWriter(
                    f"{source}_keypoint.mp4",
                    cv2.VideoWriter_fourcc(*"mp4v"),
                    out_fps,
                    (out_width, out_height),
                )
            if view_img:
                cv2.imshow("YOLOv7 Pose Estimation Demo", im0)
                cv2.waitKey(1)  # 1 millisecond
            out.write(im0)
    finally:
        # Always release capture and writer; an unreleased VideoWriter may
        # leave the mp4 unfinalized.
        cap.release()
        if out is not None:
            out.release()
        if view_img:
            cv2.destroyAllWindows()

    if frame_count == 0:
        print("No frames were processed.")
        return
    avg_fps = total_fps / frame_count
    print(f"Average FPS: {avg_fps:.3f}")
    plot_fps_time_comparision(time_list=time_list, fps_list=fps_list)
def parse_opt(argv=None):
    """Build the command-line parser and parse the options for ``run``.

    Every option's destination name matches a keyword parameter of ``run``,
    so the result can be forwarded with ``run(**vars(opt))``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse parse ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``poseweights``, ``source``, ``device``,
        ``view_img``, ``save_conf``, ``line_thickness``, ``hide_labels`` and
        ``hide_conf``.
    """
    parser = argparse.ArgumentParser()
    # With nargs="+" the value is a list when the flag is given; when the flag
    # is omitted, the str default is used unchanged.
    parser.add_argument(
        "--poseweights",
        nargs="+",
        type=str,
        default="yolov7-w6-pose.pt",
        help="model path(s)",
    )
    parser.add_argument(
        "--source", type=str, default="football1.mp4", help="video/0 for webcam"
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="cpu/0,1,2,3(gpu)"
    )
    parser.add_argument("--view-img", action="store_true", help="display results")
    # There is no --save-txt option in this script; the old help text pointed to one.
    parser.add_argument(
        "--save-conf",
        action="store_true",
        help="save confidences (accepted for compatibility; currently unused)",
    )
    parser.add_argument(
        "--line-thickness", default=3, type=int, help="bounding box thickness (pixels)"
    )
    parser.add_argument(
        "--hide-labels", default=False, action="store_true", help="hide labels"
    )
    parser.add_argument(
        "--hide-conf", default=False, action="store_true", help="hide confidences"
    )
    return parser.parse_args(argv)
# Plot the FPS vs. time comparison graph and save it to an image file.
def plot_fps_time_comparision(
    time_list, fps_list, save_path="FPS_and_Time_Comparision_pose_estimate.png"
):
    """Plot FPS values against processing times and save the graph as an image.

    Args:
        time_list: Processing times in seconds, plotted on the x axis.
        fps_list: FPS values plotted on the y axis; must have the same length
            as ``time_list``.
        save_path: Output image path. Defaults to the previously hard-coded name.
    """
    fig = plt.figure()
    try:
        plt.xlabel("Time (s)")
        plt.ylabel("FPS")
        plt.title("FPS and Time Comparision Graph")
        plt.plot(time_list, fps_list, "b", label="FPS & Time")
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed, so
        # repeated calls would otherwise accumulate open figures.
        plt.close(fig)
# Entry point: run pose estimation with the parsed command-line options.
def main(opt):
    """Forward every parsed command-line option to ``run`` as a keyword argument."""
    options = vars(opt)
    run(**options)
if __name__ == "__main__":
    # Parse the CLI options into a module-level ``opt``. Keep the global name:
    # code elsewhere in this file may read ``opt`` directly.
    opt = parse_opt()
    # NOTE(review): strip_optimizer is imported from utils.general and its
    # signature is not visible here. It is called with (device, weights), while
    # upstream YOLOv7's version takes (f, s) - confirm this repo's variant.
    # It presumably rewrites the weights file on disk on every run. When
    # --poseweights is passed, ``opt.poseweights`` is a list (nargs="+");
    # verify strip_optimizer accepts a list.
    strip_optimizer(opt.device, opt.poseweights)
    main(opt)