# Path: DeepTIAS/Features/DeepLearning/Reference/Tang's/img_dataset.py
# Last change: "Refactor" by @ke96 on 2 Nov 2020
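#メモ(学習サイズが256×256の場合)
#画像の縦横のサイズが286pixel以上になるように変換したあと、
#256×256で切り抜きしている
# Note (when the training size is 256x256): each image is first resized so
# that both its height and width are at least 286 pixels, and then a
# 256x256 region is cropped out of it.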

import os

import numpy
from PIL import Image
import six

import numpy as np

from io import BytesIO
import os
import pickle
import json
import numpy as np
import glob

import skimage.io as io

from chainer.dataset import dataset_mixin
"""***各種設定***"""

#画像サイズの変更(学習させたい画像サイズ以上に設定)
min_size = 286

#学習させる画像サイズに設定
w_crop_width = 256
h_crop_width = 256
# download `BASE` dataset from http://cmp.felk.cvut.cz/~tylecr1/facade/
class ImgDataset(dataset_mixin.DatasetMixin):
    """Paired image dataset that yields aligned random crops.

    Pairs are matched by file name: every ``*.jpg`` in ``dataDstDir`` is
    paired with the file of the same name in ``dataSrcDir``. Both images of
    a pair are resized to the same size, chosen so that the source image's
    shorter side becomes ``min_size`` pixels. They are then converted to
    float32 arrays in (channels, height, width) layout, scaled with
    ``x / 128 - 1``, and held in memory. ``get_example`` returns the same
    random ``h_crop_width`` x ``w_crop_width`` window from both images.
    """

    def __init__(self, dataSrcDir, dataDstDir, data_range=(0,0.9)):
        """Load the selected fraction of image pairs into memory.

        Args:
            dataSrcDir: directory holding the source (input) images.
            dataDstDir: directory holding the destination (target) images;
                its ``*.jpg`` files determine which pairs are loaded.
            data_range: ``(start, end)`` fractions of the sorted file list
                to load, e.g. ``(0, 0.9)`` for the first 90%.
        """
        print("load dataset start")
        print("    from: %s, %s" % (dataSrcDir, dataDstDir))
        print("    range: [{}, {})".format(data_range[0], data_range[1]))
        self.dataSrcDir = dataSrcDir
        self.dataDstDir = dataDstDir
        self.dataset = []
        # glob.glob returns files in arbitrary, filesystem-dependent order.
        # Sort so that complementary ranges such as (0, 0.9) / (0.9, 1)
        # always select the same, non-overlapping files.
        self.picfiles = sorted(
            os.path.basename(path)
            for path in glob.glob(os.path.join(dataDstDir, "*.jpg"))
        )
        data_range_start = int(data_range[0] * len(self.picfiles))
        data_range_end = int(data_range[1] * len(self.picfiles))
        for fn in self.picfiles[data_range_start:data_range_end]:
            self.dataset.append(self._load_pair(fn))
        print("load dataset done")

    def _load_pair(self, fn):
        """Load, resize and normalize the source/destination images named fn."""
        with Image.open(os.path.join(self.dataSrcDir, fn)) as src:
            w, h = src.size
            size = self._resized_size(w, h)
            img_src = src.resize(size, Image.BILINEAR)
        # The destination is resized to the source-derived size so that both
        # arrays share the same height/width and crops stay aligned.
        with Image.open(os.path.join(self.dataDstDir, fn)) as dst:
            img_dst = dst.resize(size, Image.BILINEAR)
        return self._to_chw(img_src), self._to_chw(img_dst)

    @staticmethod
    def _resized_size(w, h):
        """Return (width, height) scaled so the shorter side is min_size."""
        # float() keeps this a true division under Python 2 as well.
        r = float(min_size) / min(w, h)
        # int() truncation of r * side can land one pixel short of min_size
        # after floating-point round-off; clamp so no side drops below it.
        return max(int(r * w), min_size), max(int(r * h), min_size)

    @staticmethod
    def _to_chw(img):
        """Convert a PIL image to a float32 (C, H, W) array via x / 128 - 1.

        For 8-bit images this maps pixel values 0..255 into [-1, 1).
        """
        arr = np.asarray(img).astype("f")
        if arr.ndim == 2:
            # Single-channel images (e.g. mode "L") have no channel axis;
            # add one so the CHW transpose below works.
            arr = arr[:, :, np.newaxis]
        return arr.transpose(2, 0, 1) / 128.0 - 1.0

    def __len__(self):
        """Return the number of loaded image pairs."""
        return len(self.dataset)

    def get_example(self, i):
        """Return the same random crop from the i-th (source, destination) pair.

        Returns:
            Tuple ``(src_crop, dst_crop)``; each is shaped
            ``(channels, h_crop_width, w_crop_width)``.

        Raises:
            ValueError: if the stored images are smaller than the crop size.
        """
        img_src, img_dst = self.dataset[i]
        _, h, w = img_src.shape
        if w < w_crop_width or h < h_crop_width:
            raise ValueError(
                "image %d is %dx%d, smaller than the %dx%d crop"
                % (i, w, h, w_crop_width, h_crop_width))
        # randint's upper bound is exclusive: +1 makes the right/bottom-most
        # window reachable and lets an image exactly the crop size work.
        x_l = np.random.randint(0, w - w_crop_width + 1)
        x_r = x_l + w_crop_width
        y_l = np.random.randint(0, h - h_crop_width + 1)
        y_r = y_l + h_crop_width
        # Same window for input and output so the pair stays aligned.
        return img_src[:, y_l:y_r, x_l:x_r], img_dst[:, y_l:y_r, x_l:x_r]