# Yolov5WhiteLineDetection / utils / calc_insertion_utils.py
import numpy as np
import cv2
import datetime

def split_xy(xy_s):
    """Split a sequence of (x, y) pairs into separate coordinate arrays.

    Args:
        xy_s: Iterable of indexable items; element 0 is taken as x and
            element 1 as y.

    Returns:
        Tuple ``(x, y)`` of NumPy arrays, one entry per input item.
    """
    # Materialise once so single-pass iterables are only consumed one time.
    points = list(xy_s)
    x_values = np.array([point[0] for point in points])
    y_values = np.array([point[1] for point in points])
    return x_values, y_values


def reg1dim(x, y):
    """Fit the straight line ``y = a * x + b`` by least squares.

    Args:
        x: NumPy array of x coordinates (``.sum()`` is called on it).
        y: NumPy array of y coordinates, same length as ``x``.

    Returns:
        Tuple ``(a, b)``: slope and intercept of the fitted line.
    """
    count = len(x)
    sum_x = x.sum()
    sum_y = y.sum()

    # Slope = (sum(xy) - sum(x)sum(y)/n) / (sum(x^2) - sum(x)^2/n)
    numerator = np.dot(x, y) - sum_y * sum_x / count
    denominator = (x ** 2).sum() - sum_x ** 2 / count
    slope = numerator / denominator

    intercept = (sum_y - slope * sum_x) / count
    return slope, intercept


def plot_line(im0, a, b):
    """Draw the line ``y = a * x + b`` onto an image, in place.

    If the line crosses the top edge (y = 0) inside the image width, it is
    drawn from the top edge to the bottom edge; otherwise it is drawn from
    the left edge to the right edge.

    Args:
        im0: Image array; ``shape[0]`` is used as height and ``shape[1]``
            as width.
        a: Slope of the line.
        b: Intercept of the line.
    """
    height, width = im0.shape[:2]
    color = (112, 164, 184)

    # The x-intercept at y = 0 is -b / a. The previous code tested -a / b,
    # which is unrelated to the endpoint it then drew and fails when b == 0.
    # A zero slope has no x-intercept, so it takes the left-to-right branch.
    if a != 0 and 0 <= int(-b / a) <= width:
        top_point = (int(-b / a), 0)
        bottom_point = (int((height - b) / a), height)
        cv2.line(im0, top_point, bottom_point, color, 2, lineType=cv2.LINE_AA)
    else:
        # Span the full image width rather than the height so that
        # non-square images are covered edge to edge.
        left_point = (0, int(b))
        right_point = (width, int(a * width + b))
        cv2.line(im0, left_point, right_point, color, 2, lineType=cv2.LINE_AA)


def plot_part_line(im0, center_points):
    """Draw a polyline through consecutive centre points, in place.

    Args:
        im0: Image to draw on.
        center_points: Sequence of points; each neighbouring pair is joined
            by a 1-pixel anti-aliased segment.
    """
    # TODO: room for improvement
    for start, end in zip(center_points, center_points[1:]):
        cv2.line(im0, start, end, (112, 164, 184), 1, lineType=cv2.LINE_AA)


def calc_mm_per_pixel(center_xys, real_wl_interval=50.):
    """Estimate the millimetres-per-pixel scale from consecutive points.

    For every neighbouring pair of points the ratio
    ``real_wl_interval / pixel_distance`` is computed, and the ratios are
    averaged.

    Args:
        center_xys: Sequence of (x, y) points. At least two are needed;
            exactly one point raises ``ZeroDivisionError``.
        real_wl_interval: Real-world distance between neighbouring points.
            Defaults to 50.

    Returns:
        Mean of the per-pair ratios.
    """
    ratio_sum = 0.0
    for start, end in zip(center_xys, center_xys[1:]):
        pixel_gap = np.linalg.norm(np.array(end) - np.array(start), ord=2)
        # Use the caller-supplied interval; it was previously ignored in
        # favour of a hard-coded 50.
        ratio_sum += real_wl_interval / pixel_gap
    return ratio_sum / (len(center_xys) - 1)


def _first_strict_min(values):
    """Return ``(index, value)`` of the first smallest entry of ``values``.

    Only strictly smaller values replace the current best, so ties keep the
    earliest index. Returns ``(-1, inf)`` when nothing is below infinity.
    """
    best_index, best_value = -1, float('inf')
    for index, value in enumerate(values):
        if value < best_value:
            best_index, best_value = index, value
    return best_index, best_value


def calc_match_list(peak_xy, prev_peak_xy, thresh_distance=30):
    """Pair previous-frame points with current-frame points by proximity.

    The shorter of the two lists drives the search: each of its points is
    paired with its nearest neighbour in the other list, measured by the sum
    of absolute coordinate differences. A pair is kept only if that distance
    is below ``thresh_distance``. Points in the searched list are not
    removed once matched, so one point may appear in several pairs.

    Args:
        peak_xy: Current points.
        prev_peak_xy: Previous points.
        thresh_distance: Exclusive upper bound on the accepted distance.

    Returns:
        List of ``[previous_point, current_point]`` pairs. The point from
        the driving list is a NumPy array; its partner is the original item
        from the other list.
    """
    pairs = []

    if len(prev_peak_xy) <= len(peak_xy):
        # Previous points drive the search over the current points.
        for previous in prev_peak_xy:
            previous = np.array(previous)
            gaps = [np.sum(np.abs(np.array(current) - previous)) for current in peak_xy]
            nearest, gap = _first_strict_min(gaps)
            if gap < thresh_distance:
                pairs.append([previous, peak_xy[nearest]])
        return pairs

    # Fewer current points: they drive the search over the previous points.
    for current in peak_xy:
        current = np.array(current)
        gaps = [np.sum(np.abs(current - np.array(previous))) for previous in prev_peak_xy]
        nearest, gap = _first_strict_min(gaps)
        if gap < thresh_distance:
            pairs.append([prev_peak_xy[nearest], current])
    return pairs


def calc_direction_vec(center_points):
    """Return the normalised mean step between consecutive points.

    Args:
        center_points: Sequence of (x, y) points. Fewer than two points
            raises ``ZeroDivisionError``.

    Returns:
        NumPy array of shape (2,): the average point-to-point displacement
        divided by its Euclidean length, or the zero vector when that
        average is zero.
    """
    steps = list(zip(center_points, center_points[1:]))
    dx = [nxt[0] - cur[0] for cur, nxt in steps]
    dy = [nxt[1] - cur[1] for cur, nxt in steps]

    mean_step = np.array([sum(dx) / len(dx), sum(dy) / len(dy)])
    length = np.linalg.norm(mean_step, ord=2, axis=-1, keepdims=True)
    # Avoid dividing by zero: a zero-length mean step is returned unchanged.
    length[length == 0] = 1
    return mean_step / length


def sort_center_points(center_points):
    """Sort points along whichever axis they are spread out on the most.

    Args:
        center_points: Non-empty sequence of (x, y) points.

    Returns:
        New list of the points sorted by x when the x range is at least as
        large as the y range, otherwise sorted by y.
    """
    x_s, y_s = split_xy(center_points)
    x_spread = max(x_s) - min(x_s)
    y_spread = max(y_s) - min(y_s)

    # Sort by the coordinate with the larger spread; ties go to x.
    axis = 0 if y_spread <= x_spread else 1
    return sorted(center_points, key=lambda point: point[axis])


def calc_valid_length(match_list, scope_dir_vec, mm_per_pixel):
    """Average the matched movements projected onto a direction, scaled.

    Args:
        match_list: Non-empty list of pairs; for each pair the movement is
            ``pair[1] - pair[0]``. An empty list raises
            ``ZeroDivisionError``.
        scope_dir_vec: Direction vector the movement is projected onto.
        mm_per_pixel: Factor applied to each projected length.

    Returns:
        Mean of the scaled projected lengths.
    """
    projected = [
        np.dot(pair[1] - pair[0], scope_dir_vec) * mm_per_pixel
        for pair in match_list
    ]
    return sum(projected) / len(projected)


def show_formula_mat(formula):
    """Render a value as ``'<value> mm'`` text and show it in a window.

    Args:
        formula: Number formatted with one decimal place.

    Returns:
        The 50x256 3-channel ``uint8`` image the text was drawn on.
    """
    canvas = np.zeros((50, 256, 3), dtype=np.uint8)
    label = '{:7.1f} mm'.format(formula)
    cv2.putText(canvas, label, (10, 30), cv2.FONT_HERSHEY_DUPLEX, 1, (0, 200, 200),
                thickness=3, lineType=cv2.LINE_AA)
    cv2.imshow("formula", canvas)
    return canvas


def eval_linear_score(cur_center_points, thresh_val=0.9):
    """Check whether the points lie roughly on one straight line.

    Unit vectors are formed from the first point to every later point. The
    score is the product of the dot products of each consecutive pair of
    those unit vectors.

    Args:
        cur_center_points: Sequence of (x, y) points.
        thresh_val: The score must be strictly greater than this to pass.

    Returns:
        True when there are fewer than three points or the score exceeds
        ``thresh_val``; otherwise a message is printed and False is returned.
    """
    if len(cur_center_points) < 3:
        return True

    origin = np.array(cur_center_points[0])
    unit_vecs = []
    for raw_point in cur_center_points[1:]:
        offset = np.array(raw_point) - origin
        unit_vecs.append(offset / np.linalg.norm(offset))

    score = 1.0
    for earlier, later in zip(unit_vecs, unit_vecs[1:]):
        score *= np.dot(earlier, later)

    if thresh_val < score:
        return True
    print(f"{datetime.datetime.now()} linear_score is low")
    return False


def eval_interval_score(cur_center_points, thresh_val=0.9):
    """Check whether consecutive points are roughly evenly spaced.

    Step vectors between neighbouring points are compared pairwise. For each
    consecutive pair of steps, their difference is divided by the longest
    step length, and ``1 - norm(difference)`` is multiplied into the score.

    Args:
        cur_center_points: Sequence of (x, y) points.
        thresh_val: The score must be strictly greater than this to pass.

    Returns:
        True when there are fewer than three points or the score exceeds
        ``thresh_val``; otherwise a message is printed and False is returned.
    """
    if len(cur_center_points) < 3:
        return True

    points = [np.array(raw_point) for raw_point in cur_center_points]
    steps = [later - earlier for earlier, later in zip(points, points[1:])]
    # The longest step is the common normaliser for every comparison.
    longest_step = max([np.linalg.norm(step) for step in steps])

    score = 1.0
    for earlier_step, later_step in zip(steps, steps[1:]):
        change = (later_step - earlier_step) / longest_step
        score *= (1.0 - np.linalg.norm(change))

    if thresh_val < score:
        return True
    print(f"{datetime.datetime.now()} interval_score is low")
    return False


def eval_series_score(mpp, p_mpp, thresh_val=0.9):
    """Check that a value stayed close to its previous value.

    The score is ``1 - |mpp - p_mpp| / max(mpp, p_mpp)``.

    Args:
        mpp: Current value.
        p_mpp: Previous value, or None when there is no previous value.
        thresh_val: The score must be strictly greater than this to pass.

    Returns:
        True when ``p_mpp`` is None or the score exceeds ``thresh_val``;
        otherwise a message is printed and False is returned.
    """
    if p_mpp is None:
        return True

    relative_change = abs((mpp - p_mpp) / max(mpp, p_mpp))
    score = 1.0 - relative_change

    if thresh_val < score:
        return True
    print("series_score is low")
    return False