# DetectWhiteLinesByYOLOv5 / main.py
from timeit import default_timer as timer

import cv2
import numpy as np
from openvino.inference_engine import IECore
import torch

from utils.general import non_max_suppression, scale_coords
from utils.plots import Annotator

# Camera device index passed to cv2.VideoCapture.
WEBCAM_PORT = 0
# Square input side length the YOLOv5 model expects (pixels).
IMG_SIZE = 640
# Minimum detection confidence kept after NMS.
CONF_THRESH = 0.6
# When True, overlay frame count and FPS on the output window.
write_fps = True
# OpenVINO IR model topology and weights files.
MODEL_XML = r"yolov5s_640x640_opt.xml"
MODEL_WEIGHTS = r"yolov5s_640x640_opt.bin"


# yolov5に入力するIMG_SIZExIMG_SIZEで背景が0パディングされている形式に変更
def convert_to_yolov5format_img(frame, img_size=None):
    """Letterbox *frame* onto a square canvas and resize for YOLOv5 input.

    The frame is centred on a square canvas filled with gray value 114
    (YOLOv5's conventional padding color), preserving the aspect ratio,
    then resized to (img_size, img_size).

    Args:
        frame: HxWx3 uint8 image (BGR, as read by cv2).
        img_size: Output side length in pixels. Defaults to the
            module-level IMG_SIZE when omitted (backward compatible).

    Returns:
        (img_size, img_size, 3) uint8 image.
    """
    if img_size is None:
        img_size = IMG_SIZE

    height, width = frame.shape[:2]
    side = max(height, width)
    # Offsets that centre the frame on the square canvas; one of the
    # two is always zero (only the shorter dimension gets padded).
    y0 = (side - height) // 2
    x0 = (side - width) // 2

    canvas = np.full((side, side, 3), 114, np.uint8)
    canvas[y0:y0 + height, x0:x0 + width] = frame
    return cv2.resize(canvas, (img_size, img_size), interpolation=cv2.INTER_CUBIC)


class DispFps:
    """Overlays a running frame counter (top-left) and an FPS readout
    (top-right) onto video frames, in place."""

    def __init__(self):
        # Overlay box geometry and text style.
        self.__width = 160
        self.__height = 40
        self.__font_size = 1.0
        self.__font_width = 1
        self.__font_style = cv2.FONT_HERSHEY_COMPLEX
        self.__font_color = (255, 255, 255)
        self.__background_color = (0, 0, 0)

        # Total frames seen since construction.
        self.__frame_count = 0

        # FPS state: frames counted in the current one-second window,
        # accumulated elapsed time, and the last timestamp sampled.
        self.__accum_time = 0
        self.__curr_fps = 0
        self.__prev_time = timer()
        self.__str = "FPS: "

    def __calc(self):
        """Advance the frame counter; republish the FPS string once per second."""
        self.__frame_count += 1

        # Accumulate wall-clock time between calls; whenever a full
        # second has elapsed, the frames counted in that window become
        # the displayed FPS and the window restarts.
        self.__curr_time = timer()
        self.__exec_time = self.__curr_time - self.__prev_time
        self.__prev_time = self.__curr_time
        self.__accum_time = self.__accum_time + self.__exec_time
        self.__curr_fps = self.__curr_fps + 1
        if self.__accum_time > 1:
            self.__accum_time = self.__accum_time - 1
            self.__str = "FPS: " + str(self.__curr_fps)
            self.__curr_fps = 0

    # NOTE: parameter renamed from `str` to `text` — the old name
    # shadowed the `str` builtin.
    def __disp(self, frame, text, x1, y1, x2, y2):
        """Draw *text* inside a filled background rectangle on *frame*."""
        cv2.rectangle(frame, (x1, y1), (x2, y2), self.__background_color, -1)
        cv2.putText(frame, text, (x1 + 5, y2 - 5), self.__font_style,
                    self.__font_size, self.__font_color, self.__font_width)

    def disp(self, frame):
        """Update counters and draw both overlays onto *frame* in place."""
        self.__calc()
        # Frame count in the top-left corner.
        self.__disp(frame, str(self.__frame_count), 0, 0, x2=self.__width, y2=self.__height)
        # FPS readout in the top-right corner.
        screen_width = int(frame.shape[1])
        self.__disp(frame, self.__str, screen_width - self.__width, 0, screen_width, self.__height)

def main():
    """Run webcam white-line detection with a YOLOv5 IR model on OpenVINO CPU.

    Reads frames from the configured camera, letterboxes them to the
    model input size, runs inference + NMS, draws the detections and an
    optional FPS overlay, and shows the result until 'q' is pressed or
    the camera stops delivering frames.
    """
    ie = IECore()
    net = ie.read_network(model=MODEL_XML, weights=MODEL_WEIGHTS)

    input_layer = next(iter(net.input_info))
    print(f"input layout: {net.input_info[input_layer].layout}")
    print(f"input precision: {net.input_info[input_layer].precision}")
    print(f"input shape: {net.input_info[input_layer].tensor_desc.dims}")

    config = {"CPU_THREADS_NUM": "8"}
    exec_net = ie.load_network(network=net, device_name="CPU", config=config)

    # Use the configured camera port (was hard-coded to 0, which
    # silently ignored WEBCAM_PORT).
    cap = cv2.VideoCapture(WEBCAM_PORT)
    disp_fps = DispFps()

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Letterbox to the square input the model expects.
            src = convert_to_yolov5format_img(frame)

            # HWC/BGR uint8 -> NCHW/RGB float32 in [0, 1]; the [::-1]
            # on the channel axis swaps BGR to RGB.
            input_data = np.expand_dims(np.transpose(src, (2, 0, 1))[::-1], 0).astype(np.float32) / 255.0

            result = torch.tensor(exec_net.infer({input_layer: input_data})["output"])
            # NMS: confidence CONF_THRESH, IoU 0.45, class 0 only, <= 10 boxes.
            pred = non_max_suppression(result, CONF_THRESH, 0.45, 0, False, max_det=10)
            for det in pred:
                annotator = Annotator(src, line_width=3, example=str(["white_line"]))
                if len(det):
                    # Rescale boxes from the network input size back to src.
                    det[:, :4] = scale_coords(input_data.shape[2:], det[:, :4], src.shape).round()
                    for *xyxy, conf, cls in reversed(det):
                        annotator.box_label(xyxy, "white_line")

            if write_fps:
                disp_fps.disp(src)

            cv2.imshow("detect", src)

            if cv2.waitKey(1) == ord("q"):
                break
    finally:
        # Release the camera and GUI resources even on error or exit —
        # the original leaked both.
        cap.release()
        cv2.destroyAllWindows()


# Script entry point: run detection only when executed directly.
if __name__ == '__main__':
    main()