Newer
Older
SC-SfMLearner_for_NLab / datasets / make_datasets.py
@planck planck on 5 Dec 2020 5 KB train機能の実装
import numpy as np
import cv2
import os
import os.path as osp
import argparse
import tkinter as tk
import yaml
from glob import glob

# Folder that holds this script; the default output directory is placed inside it.
file_dir = osp.dirname(__file__)

parser = argparse.ArgumentParser()

# Command-line options as (flag, add_argument keyword arguments) pairs.
# They are registered in the order listed here.
_argument_table = (
    # Where the generated dataset is written.
    ("--out_dir", dict(default=osp.join(file_dir, "data_for_SC_SfMLearner"),
                       type=str,
                       help="データセットの出力先")),
    # Every N-th frame of each video is saved as an image.
    ("--save_frequency", dict(default=10,
                              type=int,
                              help="動画からどれくらいの周期でフレームを画像として保存するかの指定")),
    # When given, the validation split is not generated.
    ("--no_make_val", dict(action="store_true",
                           help="評価用データを作成するか否か")),
    # Height of the saved (resized) images; must be a multiple of 32.
    ("--save_height", dict(default=512,
                           type=int,
                           help="画像データセットに変換する際のリサイズ後の画像の高さ.32の倍数でないといけない")),
    # Width of the saved (resized) images; must be a multiple of 32.
    ("--save_width", dict(default=288,
                          type=int,
                          help="画像データセットに変換する際のリサイズ後の画像の幅.32の倍数でないといけない")),
)
for _flag, _flag_kwargs in _argument_table:
    parser.add_argument(_flag, **_flag_kwargs)

# Parsed once at start-up; read as a module-level global by the rest of the script.
options = parser.parse_args()

# Flattened 3x3 camera intrinsic matrix (row-major); filled in by the dialog below.
K = []
root = tk.Tk()
root.geometry("250x250")
root.title("monodepth2 dataset GUI")

entry_boxs = {}
label1 = tk.Label(text="Please input your camera's intrinsics")
label1.place(x=30, y=20)

# Lay out a 3x3 grid of entry boxes keyed "entry0" .. "entry8" in row-major
# order, so that "entryN" corresponds to element N of the flattened matrix.
entry_num = 0
init_x, init_y = 55, 60
offset_x, offset_y = 50, 30
for grid_row in range(3):
    for grid_col in range(3):
        cur_key = "entry{}".format(entry_num)
        entry_boxs[cur_key] = tk.Entry(width=7)
        entry_boxs[cur_key].place(x=(init_x + grid_col * offset_x),
                                  y=(init_y + grid_row * offset_y))
        # Off-diagonal cells that are zero for a standard pinhole matrix.
        if entry_num in [1, 3, 6, 7]:
            entry_boxs[cur_key].insert(tk.END, "0")
        entry_num += 1
# Placeholders showing which cell holds which parameter; the user is expected
# to overwrite them with numbers.
entry_boxs["entry0"].insert(tk.END, "f_x")
entry_boxs["entry2"].insert(tk.END, "c_x")
entry_boxs["entry4"].insert(tk.END, "f_y")
entry_boxs["entry5"].insert(tk.END, "c_y")
entry_boxs["entry8"].insert(tk.END, "1")


def end_tk_process():
    """Read the nine entry boxes into the global ``K`` and close the window.

    If any cell does not contain a number (for example an untouched
    placeholder such as ``f_x``), the dialog stays open and the label at the
    top shows the problem so the user can correct the input.
    """
    global K
    try:
        # Read by explicit key so the row-major order does not depend on
        # dict iteration order.
        values = [float(entry_boxs["entry{}".format(i)].get()) for i in range(9)]
    except ValueError:
        label1.config(text="Every cell must be a number", fg="red")
        return
    K = values
    root.destroy()


ok_button = tk.Button(text="finish", command=end_tk_process)
ok_button.place(x=100, y=180)

# Checkbox: ticked when the entered intrinsics already refer to the resized
# images, unticked when they refer to the original video resolution.
resized_calb_bin = tk.BooleanVar()
resized_calb_bin.set(False)
resized_calb_box = tk.Checkbutton(root, variable=resized_calb_bin, text="Is resized image's intrinsic")
resized_calb_box.place(x=45, y=150)

root.mainloop()

Is_resized_intrinsic = resized_calb_bin.get()

# The window can be closed without pressing "finish", which leaves K empty.
# Stop here with a clear message instead of failing later on K[0].
if len(K) != 9:
    raise SystemExit("Camera intrinsics were not entered "
                     "(the window was closed without pressing 'finish').")

def make_monodepth2_dataset(mode="train"):
    """Convert the videos in ``./<mode>_videos`` into an image-sequence dataset.

    Every ``options.save_frequency``-th frame of each mp4 file is resized to
    ``options.save_width`` x ``options.save_height`` and written as a JPEG to
    ``<out_dir>/<mode>/sequence_<n>/``. The list of sequence folders is written
    to ``<out_dir>/<mode>.txt`` and the image size together with the camera
    intrinsics to ``<out_dir>/environment.yaml``.

    Reads the module-level globals ``options``, ``K`` (flattened 3x3 intrinsic
    matrix) and ``Is_resized_intrinsic``. When ``Is_resized_intrinsic`` is
    false, the focal lengths and principal point are rescaled from the
    resolution of the first video to the saved image size.

    Args:
        mode: Either ``"train"`` or ``"val"``; selects the input folder and
            the output split.

    Raises:
        AssertionError: If ``mode`` or the options are invalid, if the input
            folder is empty, or if it contains a file that is not an mp4.
        RuntimeError: If a video cannot be opened or a frame cannot be written.
    """
    assert mode in ["train", "val"], "function make_monodepth2_dataset's mode must be 'train' or 'val' "
    assert options.save_height % 32 == 0, "'height' must be a multiple of 32"
    assert options.save_width % 32 == 0, "'width' must be a multiple of 32"
    # A frequency of 0 would otherwise fail with ZeroDivisionError in the modulo below.
    assert options.save_frequency > 0, "'save_frequency' must be a positive integer"

    # Sorted so that the sequence numbering is reproducible between runs.
    video_paths = sorted(glob("./{}_videos/*".format(mode)))

    assert len(video_paths) != 0, "ファイル'{}_videos'に動画ファイルが入っていません".format(mode)

    for path in video_paths:
        assert path.lower().endswith(".mp4"), "動画はmp4ファイルのみに対応しています"

    split_dir = osp.join(options.out_dir, mode)
    os.makedirs(split_dir, exist_ok=True)

    intrinsic = [[K[0], K[1], K[2]],
                 [K[3], K[4], K[5]],
                 [K[6], K[7], K[8]]]
    if not Is_resized_intrinsic:
        # The intrinsics refer to the original resolution, taken from the
        # first video; scale them to the size of the saved images.
        probe = cv2.VideoCapture(video_paths[0])
        try:
            if not probe.isOpened():
                raise RuntimeError("無効なmp4ファイルが発見されました.")
            ful_res_w = probe.get(cv2.CAP_PROP_FRAME_WIDTH)
            ful_res_h = probe.get(cv2.CAP_PROP_FRAME_HEIGHT)
        finally:
            probe.release()
        if ful_res_w <= 0 or ful_res_h <= 0:
            raise RuntimeError("無効なmp4ファイルが発見されました.")
        scale_x = options.save_width / ful_res_w
        scale_y = options.save_height / ful_res_h
        intrinsic[0][0] *= scale_x
        intrinsic[0][2] *= scale_x
        intrinsic[1][1] *= scale_y
        intrinsic[1][2] *= scale_y

    # Image numbering continues across sequences rather than restarting at 0.
    image_save_num = 0
    dataset_indication_list = []
    for sequence_num, path in enumerate(video_paths):
        cap = cv2.VideoCapture(path)
        try:
            if not cap.isOpened():
                raise RuntimeError("無効なmp4ファイルが発見されました.")

            save_dir = "sequence_{}".format(sequence_num)
            dataset_indication_list.append(save_dir + "\n")
            sequence_dir = osp.join(split_dir, save_dir)
            os.makedirs(sequence_dir, exist_ok=True)

            frame_index = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_index % options.save_frequency == 0:
                    frame = cv2.resize(frame, (options.save_width, options.save_height),
                                       interpolation=cv2.INTER_LINEAR)
                    image_path = osp.join(sequence_dir, "{:08}.jpg".format(image_save_num))
                    # imwrite reports failure through its return value only.
                    if not cv2.imwrite(image_path, frame):
                        raise RuntimeError("failed to write image: {}".format(image_path))
                    image_save_num += 1
                frame_index += 1
        finally:
            # Release the decoder even when an error interrupts the loop.
            cap.release()

    with open(osp.join(options.out_dir, "{}.txt".format(mode)), "w") as f:
        f.writelines(dataset_indication_list)

    with open(osp.join(options.out_dir, "environment.yaml"), "w") as f:
        dataset_info = {"height": options.save_height,
                        "width": options.save_width}
        camera_info = {"intrinsic": intrinsic}
        environment = {"dataset_info": dataset_info,
                       "camera_info": camera_info}
        f.write(yaml.dump(environment))


# The training split is always built; the validation split is built as well
# unless --no_make_val was given on the command line.
modes_to_build = ["train"] if options.no_make_val else ["train", "val"]
for dataset_mode in modes_to_build:
    make_monodepth2_dataset(mode=dataset_mode)