# CollectAnimalExperimentLabels2021 / main.py
# Last commit: @sato, 1 Mar 2022 (4 KB) — "Dockerfileの追加" (Add Dockerfile)
import csv
import os.path as osp
import pickle
from glob import glob

import cv2
import numpy as np
import pydicom
from tqdm import tqdm

# Directory containing the DICOM slice files to label.
dicom_dirs = "./DICOM"
# Set to True to reuse the pickle caches written by a previous run
# (tmp1.pkl: set of SeriesDescription values, tmp2.pkl: sorted file list).
load_tmp1 = False
load_tmp2 = False

# Choose the first ./log{N}.csv that does not exist yet, so label logs from
# earlier sessions are never overwritten.
save_num = 0
csv_path = f"./log{save_num}.csv"
while osp.exists(csv_path):
    save_num += 1
    csv_path = f"./log{save_num}.csv"


# Build ``files``: paths of the slices in the user-chosen series, sorted by
# SliceLocation. Both stages are cached (tmp1.pkl: set of series names,
# tmp2.pkl: final sorted path list) and can be reloaded via load_tmp1/2.
# NOTE(review): pickle.load can execute arbitrary code; only load cache files
# written by this script.
if load_tmp2:
    with open("./tmp2.pkl", "rb") as f:
        files = pickle.load(f)
else:
    files = glob(osp.join(dicom_dirs, "*"))
    if not files:
        raise SystemExit(f"DICOMファイルが見つかりません: {dicom_dirs}")
    # Read each header once, skipping pixel data (not needed to group,
    # filter or sort), and reuse it below. dcmread replaces the deprecated
    # read_file alias, which pydicom 3.0 removed.
    headers = {
        x: pydicom.dcmread(x, stop_before_pixels=True) for x in tqdm(files)
    }
    if load_tmp1:
        with open("./tmp1.pkl", "rb") as f:
            series_set = pickle.load(f)
    else:
        series_set = {ds.SeriesDescription for ds in headers.values()}
        with open("./tmp1.pkl", "wb") as f:
            pickle.dump(series_set, f)
    if not series_set:
        raise SystemExit("選択可能なシリーズがありません")

    # Sorted so menu numbers are stable between runs (iteration order of a
    # set of str depends on hash randomization).
    show_dict = {(i + 1): name for i, name in enumerate(sorted(series_set))}
    print("対象シリーズの選択")
    for k, v in show_dict.items():
        print(f"{k}: {v}")
    # Re-prompt instead of crashing on a non-number or out-of-range choice.
    while True:
        try:
            target = show_dict[int(input())]
            break
        except (ValueError, KeyError):
            print(f"1〜{len(show_dict)} の番号を入力してください")

    files = [x for x in files if headers[x].SeriesDescription == target]
    print("start_sort")
    # list.sort is stable: slices with equal SliceLocation keep glob order.
    files.sort(key=lambda x: float(headers[x].SliceLocation))
    with open("./tmp2.pkl", "wb") as f:
        pickle.dump(files, f)
    print("end_sort")

def _first_value(value):
    """Return ``value[0]`` for a multi-valued DICOM element, else ``value``.

    WindowCenter / WindowWidth may store either one number or several;
    indexing a single-valued element raises TypeError.
    """
    try:
        return value[0]
    except TypeError:
        return value


file_size = len(files)
if file_size == 0:
    raise SystemExit("表示するDICOMファイルがありません")
# Header of the first slice: source of the initial window settings and of the
# SeriesDescription recorded with each click. Pixel data is not needed here.
sample_dfile = pydicom.dcmread(files[0], stop_before_pixels=True)
dcm_wc = float(_first_value(sample_dfile.WindowCenter))
dcm_ww = float(_first_value(sample_dfile.WindowWidth))

print(dcm_wc, dcm_ww)


def row2uint8(CT_row, dcm_wc, dcm_ww):
    """Window a CT slice and convert it to a 3-channel uint8 image.

    Values are clamped to ``[wc - ww/2, wc + ww/2]``, centred on their mean,
    scaled so the largest deviation maps to about +/-256, and clipped to
    0..255. Pixels below the windowed mean therefore come out black.

    Args:
        CT_row: slice pixel values; assumed 2-D single-channel (TODO confirm
            for multi-frame data). Not modified.
        dcm_wc: window centre.
        dcm_ww: window width.

    Returns:
        (H, W, 3) uint8 array with the gray value copied to every channel.
    """
    window_max = dcm_wc + dcm_ww / 2
    window_min = dcm_wc - dcm_ww / 2

    # np.float was removed in NumPy 1.24; float64 is what it aliased.
    # np.clip returns a new array, so the caller's data is left untouched.
    CT_img = np.clip(np.asarray(CT_row, dtype=np.float64), window_min, window_max)

    CT_img -= np.mean(CT_img)
    # +1e-5 prevents division by zero on a uniform slice.
    CT_img = CT_img / (np.max(np.abs(CT_img)) + 1e-5) * 256.0
    # Effectively a no-op (mean is already ~0); kept so output stays
    # bit-identical to images used for earlier labels.
    CT_img -= np.mean(CT_img)
    gray = np.clip(CT_img, 0, 255).astype(np.uint8)

    # Same result as cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR): the gray plane
    # replicated into B, G and R.
    return np.dstack((gray, gray, gray))


def getImgByIndex(files, i, dcm_wc, dcm_ww):
    """Load slice ``files[i]`` and return it windowed as a BGR uint8 image.

    Stored pixel values are converted with the DICOM modality rescale
    ``value * RescaleSlope + RescaleIntercept`` (1 and 0 when the elements
    are absent) and then windowed by row2uint8.
    """
    CT_data = pydicom.dcmread(files[i])
    slope = float(getattr(CT_data, "RescaleSlope", 1))
    intercept = float(getattr(CT_data, "RescaleIntercept", 0))
    CT_row = CT_data.pixel_array * slope + intercept
    return row2uint8(CT_row, dcm_wc, dcm_ww)


# Latest cursor position over window "a"; written by mouseCallbackFunc and
# used to draw the crosshair.
mx = my = 0
# Click records collected during the session, one row per confirmed click:
# [file path, SeriesDescription, label, x, y] (see mouseCallbackFunc).
results = list()


def mouseCallbackFunc(event, x, y, flags, param):
    """OpenCV mouse callback for window "a": track the cursor, record clicks.

    Every event updates the global cursor position (mx, my) used for the
    crosshair. On a left-button press the user is asked on stdin whether to
    record the point. If confirmed, a row
    ``[file path, SeriesDescription, label, x, y]`` is appended to ``results``.

    Args:
        event: OpenCV mouse event code (``cv2.EVENT_*``).
        x, y: cursor position in window "a", which shows the slice at 1x, so
            these are image pixel coordinates.
        flags, param: unused; required by the OpenCV callback signature.
    """
    # Only mx/my are rebound. files, img_index and sample_dfile are just
    # read, and results is mutated in place, so none of them need ``global``.
    global mx, my
    mx, my = x, y
    if event != cv2.EVENT_LBUTTONDOWN:
        return
    print("クリック場所を記録しますか? (y/n)")
    # input() blocks inside the callback; presumably the viewer stops
    # refreshing until the question is answered.
    if input().strip().lower() != "y":
        return
    contents = [files[img_index], sample_dfile.SeriesDescription]
    print("クリック箇所の名称を入力してください")
    contents.append(input())
    contents.extend((x, y))
    results.append(contents)
    print("クリック箇所を記録しました")


# Render the first slice and open the viewer windows: "a" at native size
# (receives mouse clicks) and "b" at 2x.
# Define img_index before registering the callback, which reads it; otherwise
# an early click would raise NameError inside mouseCallbackFunc.
img_index = 0
CT_img = getImgByIndex(files, img_index, dcm_wc, dcm_ww)
h, w, _ = CT_img.shape
show_img = CT_img.copy()
zoomed_img = cv2.resize(CT_img, (w * 2, h * 2), interpolation=cv2.INTER_CUBIC)
cv2.namedWindow("a")
cv2.setMouseCallback("a", mouseCallbackFunc)

cv2.imshow("a", show_img)
cv2.imshow("b", zoomed_img)
cv2.waitKey(1)

def _prompt_window():
    """Read a new window centre and width from stdin.

    Returns:
        ``(wc, ww)`` as floats, or None if either entry is not a number or
        the width is not positive.
    """
    try:
        print("wc: ")
        wc = float(input())
        print("ww: ")
        ww = float(input())
    except ValueError:
        return None
    if not ww > 0:  # also rejects NaN
        return None
    return wc, ww


# Viewer loop. Keys: l = next slice, j = previous slice, w = change window
# centre/width, q = quit. Confirmed clicks accumulate in ``results`` (via the
# mouse callback) and are written to csv_path when the loop ends.
img_index = 0
try:
    while True:
        k = cv2.waitKey(10)
        if k == ord("l") and img_index != (file_size - 1):
            img_index += 1
            print(f"img_index: {img_index}")
            CT_img = getImgByIndex(files, img_index, dcm_wc, dcm_ww)
        elif k == ord("j") and img_index != 0:
            img_index -= 1
            print(f"img_index: {img_index}")
            CT_img = getImgByIndex(files, img_index, dcm_wc, dcm_ww)
        elif k == ord("w"):
            print(f"wwとwcを変更します. (現在の値はwc:{dcm_wc} ww:{dcm_ww})")
            window = _prompt_window()
            if window is None:
                print("無効な値です. wcとwwは変更されていません")
            else:
                dcm_wc, dcm_ww = window
                print("変更を反映しました")
                CT_img = getImgByIndex(files, img_index, dcm_wc, dcm_ww)
        elif k == ord("q"):
            break

        # Redraw both views with a crosshair at the cursor (coordinates are
        # doubled for the 2x view).
        line_color = (100, 200, 100)
        show_img = CT_img.copy()
        zoomed_img = cv2.resize(CT_img, (w * 2, h * 2), interpolation=cv2.INTER_CUBIC)
        cv2.line(show_img, (mx, 0), (mx, show_img.shape[0]), line_color, 1)
        cv2.line(show_img, (0, my), (show_img.shape[1], my), line_color, 1)
        cv2.line(zoomed_img, (mx * 2, 0), (mx * 2, zoomed_img.shape[0]), line_color, 1)
        cv2.line(zoomed_img, (0, my * 2), (zoomed_img.shape[1], my * 2), line_color, 1)
        cv2.imshow("a", show_img)
        cv2.imshow("b", zoomed_img)
finally:
    # Save the labels even if the loop raised (unreadable slice, Ctrl+C, ...)
    # so a session's clicks are never lost. newline="" is what the csv module
    # requires to avoid blank rows on Windows.
    with open(csv_path, "w", newline="") as f:
        print(results)
        csv_writer = csv.writer(f)
        csv_writer.writerows(results)
cv2.destroyAllWindows()