# Yolov5WhiteLineDetection / detect_wl_with_calc_insertion.py
import argparse
import os
import platform
import shutil
import time
from pathlib import Path
import csv
from datetime import datetime

import cv2
import ffmpeg
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from numpy import random

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import (
    check_img_size, non_max_suppression, apply_classifier, scale_coords,
    xyxy2xywh, plot_one_box, strip_optimizer, set_logging)
from utils.torch_utils import select_device, load_classifier, time_synchronized
from utils.calc_insertion_utils import plot_part_line, calc_mm_per_pixel, calc_match_list, calc_direction_vec, \
    sort_center_points, calc_valid_length, show_formula_mat, eval_interval_score, eval_linear_score, \
    eval_series_score

import socket
# Destination address for the UDP datagrams sent from detect().
HOST = "127.0.0.1"
PORT = 50007
# Module-level counter; detect() increments it once per insertion value sent.
it = 0

# Datagram (UDP) socket used by detect() to send each measured insertion value.
client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

def detect(save_img=False):
    """Run detection on ``opt.source`` and track the insertion length per frame.

    Configuration is read from the module-level ``opt`` namespace (output
    folder, source, weights, thresholds, ...). For every frame the detected
    boxes are drawn, their center points are collected, and -- when the
    frame passes the linear/interval score checks -- the helper functions
    from ``utils.calc_insertion_utils`` are used to compute ``valid_mm_length``.
    Each computed value is:

    * sent (negated, as UTF-8 text) over the module-level UDP socket
      ``client`` to ``(HOST, PORT)``,
    * appended with a timestamp to ``insertion_log.csv`` in the current
      working directory,
    * added to the running ``total_insertion_length``.

    Keyboard handling in the OpenCV window: ``q`` quits, ``r`` resets the
    running total (or, on a frame rejected by the score checks, the stored
    previous mm-per-pixel value).

    Args:
        save_img: Whether to save annotated images/videos to ``opt.output``.
            Forced to ``True`` when the source is not a webcam/stream.

    Raises:
        StopIteration: When the user presses ``q``. This is kept from the
            original implementation so caller-visible behaviour is unchanged;
            the log file and video writer are closed before it propagates.
    """
    global it
    out, source, weights, view_img, save_txt, imgsz = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = source == '0' or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')

    # Insertion-tracking state carried across frames.
    cur_center_points = []
    prev_center_points = []
    match_list = []
    prev_mm_per_pixel = None  # only consumed by the currently disabled eval_series_score check
    mm_per_pixel = None
    total_insertion_length = 0.0
    scope_direction_vec = None

    vid_path, vid_writer = None, None

    # The log handle has its own name so the per-detection label file opened
    # below can never shadow it, and it is closed in ``finally`` on every exit path.
    log_file = open("insertion_log.csv", "w")
    try:
        csv_writer = csv.writer(log_file, lineterminator='\n')
        csv_writer.writerow(["format %Y%m%d%H%M%S%f"])

        # Initialize
        set_logging()
        device = select_device(opt.device)
        if os.path.exists(out):
            shutil.rmtree(out)  # delete output folder
        os.makedirs(out)  # make new output folder
        half = device.type != 'cpu'  # half precision only supported on CUDA

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
        if half:
            model.half()  # to FP16

        # Second-stage classifier
        classify = False
        if classify:
            modelc = load_classifier(name='resnet101', n=2)  # initialize
            modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
            modelc.to(device).eval()

        # Set Dataloader
        if webcam:
            view_img = True
            cudnn.benchmark = True  # set True to speed up constant image size inference
            dataset = LoadStreams(source, img_size=imgsz)
        else:
            save_img = True
            dataset = LoadImages(source, img_size=imgsz)

        # Get names and colors
        names = model.module.names if hasattr(model, 'module') else model.names
        colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

        # Run inference
        t0 = time.time()
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
        for path, img, im0s, vid_cap in dataset:
            # cv2.imwrite(f"img_log/{datetime.now().strftime('%Y%m%d%H%M%S%f')}.jpg", img[0].transpose(1, 2, 0)[:, :, [2, 1, 0]])
            img = torch.from_numpy(img).to(device)
            img = img.half() if half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            t1 = time_synchronized()
            pred = model(img, augment=opt.augment)[0]

            # Apply NMS
            pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes,
                                       agnostic=opt.agnostic_nms)
            t2 = time_synchronized()

            # Apply Classifier
            if classify:
                pred = apply_classifier(pred, modelc, img, im0s)

            # Process detections
            for i, det in enumerate(pred):  # detections per image
                if webcam:  # batch_size >= 1
                    p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
                else:
                    p, s, im0 = path, '', im0s

                save_path = str(Path(out) / Path(p).name)
                txt_path = str(Path(out) / Path(p).stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
                s += '%gx%g ' % img.shape[2:]  # print string
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                if det is not None and len(det):
                    # Center points are rebuilt from scratch for every frame with detections.
                    cur_center_points = []

                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                    # Print results
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += '%g %ss, ' % (n, names[int(c)])  # add to string

                    # Write results
                    for *xyxy, conf, cls in reversed(det):
                        if save_txt:  # Write to file
                            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                            with open(txt_path + '.txt', 'a') as txt_file:
                                txt_file.write(('%g ' * 5 + '\n') % (cls, *xywh))  # label format

                        if save_img or view_img:  # Add bbox to image
                            label = '%s %.2f' % (names[int(cls)], conf)
                            plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
                        cur_center_points.append((int((xyxy[0] + xyxy[2]) / 2), int((xyxy[1] + xyxy[3]) / 2)))

                    # --- added section (insertion tracking) begins ---
                    for point in cur_center_points:
                        cv2.circle(im0, point, 2, (255, 0, 0), 2)

                    # With two or more points, derive the mm-per-pixel scale.
                    if 2 <= len(cur_center_points):
                        cur_center_points = sort_center_points(cur_center_points)
                        plot_part_line(im0, cur_center_points)

                        mm_per_pixel = calc_mm_per_pixel(cur_center_points)

                    # is_continue = eval_linear_score(cur_center_points) and \
                    #     eval_interval_score(cur_center_points) and \
                    #     eval_series_score(mm_per_pixel, prev_mm_per_pixel, thresh_val=0.7)

                    is_continue = eval_linear_score(cur_center_points) and \
                        eval_interval_score(cur_center_points, thresh_val=0.7)

                    if not is_continue:
                        # Frame rejected by the score checks: show it, handle keys, and
                        # skip both the insertion update and the saving for this detection.
                        cv2.imshow(p, im0)
                        key = cv2.waitKey(1)
                        if key == ord('q'):
                            raise StopIteration
                        elif key == ord('r'):
                            print("reset series score")
                            prev_mm_per_pixel = mm_per_pixel
                        continue

                    # If previous points exist, look for matching points.
                    if len(prev_center_points) != 0 and len(cur_center_points) != 0:
                        match_list = calc_match_list(cur_center_points, prev_center_points)

                    # With a length scale and a non-empty match list, compute the insertion amount.
                    if mm_per_pixel is not None and len(match_list) != 0:
                        # With two or more points, update the endoscope direction.
                        if 2 <= len(cur_center_points):
                            scope_direction_vec = calc_direction_vec(cur_center_points)

                        # NOTE(review): scope_direction_vec can still be None here if no frame has
                        # had two or more points yet -- confirm calc_valid_length accepts that.
                        valid_mm_length = calc_valid_length(match_list, scope_direction_vec, mm_per_pixel)

                        # TODO: temporarily, for experiments, the negated value is sent as the insertion distance
                        client.sendto(str(-valid_mm_length).encode('utf-8'), (HOST, PORT))
                        csv_writer.writerow([datetime.now().strftime("%Y%m%d%H%M%S%f"), str(valid_mm_length)])

                        # print("send{}".format(valid_mm_length))
                        it += 1
                        total_insertion_length += valid_mm_length

                    # Reset to an empty list (not None) so the len() check above stays
                    # valid on a later frame where no new match list is computed.
                    match_list = []
                    show_formula_mat(total_insertion_length)
                    prev_center_points = cur_center_points

                prev_mm_per_pixel = mm_per_pixel
                # --- added section (insertion tracking) ends ---

                # --- added section (keyboard handling / display) begins ---
                key = cv2.waitKey(1)
                if key == ord("q"):  # q to quit
                    raise StopIteration
                elif key == ord('r'):
                    print("reset total")
                    total_insertion_length = 0.0
                cv2.imshow(p, im0)
                # --- added section (keyboard handling / display) ends ---

                # Print time (inference + NMS)
                # print('%sDone. (%.3fs)' % (s, t2 - t1))

                # Stream results
                if view_img:
                    cv2.imshow(p, im0)
                    if cv2.waitKey(1) == ord('q'):  # q to quit
                        raise StopIteration

                # Save results (image with detections)
                if save_img:
                    if dataset.mode == 'images':
                        cv2.imwrite(save_path, im0)
                    else:
                        if vid_path != save_path:  # new video
                            vid_path = save_path
                            if isinstance(vid_writer, cv2.VideoWriter):
                                vid_writer.release()  # release previous video writer

                            fourcc = 'mp4v'  # output video codec
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                            vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
                        vid_writer.write(im0)

        if save_txt or save_img:
            print('Results saved to %s' % Path(out))
            if platform.system() == 'Darwin' and not opt.update:  # MacOS
                os.system('open ' + save_path)

        print('Done. (%.3fs)' % (time.time() - t0))
    finally:
        # Close the CSV log first so it is flushed even if releasing the writer fails.
        log_file.close()
        if isinstance(vid_writer, cv2.VideoWriter):
            vid_writer.release()  # finalize the last output video


if __name__ == '__main__':
    # Build the command-line interface; detect() reads the parsed values
    # from the module-level ``opt`` namespace.
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # Model and input/output locations
    add_arg('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    add_arg('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
    add_arg('--output', type=str, default='inference/output', help='output folder')  # output folder

    # Inference settings
    add_arg('--img-size', type=int, default=640, help='inference size (pixels)')
    add_arg('--conf-thres', type=float, default=0.4, help='object confidence threshold')
    add_arg('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    add_arg('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    # Output / behaviour switches
    add_arg('--view-img', action='store_true', help='display results')
    add_arg('--save-txt', action='store_true', help='save results to *.txt')
    add_arg('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    add_arg('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    add_arg('--augment', action='store_true', help='augmented inference')
    add_arg('--update', action='store_true', help='update all models')

    opt = parser.parse_args()
    print(opt)

    # Inference only: gradients are never needed.
    with torch.no_grad():
        if not opt.update:
            detect()
        else:
            # update all models (to fix SourceChangeWarning): run detection once
            # per stock checkpoint, then strip the optimizer from that checkpoint.
            for checkpoint in ('yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt'):
                opt.weights = checkpoint
                detect()
                strip_optimizer(opt.weights)