import os
import sys
import glob
import random

import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

available_datasets = ['horse2zebra', 'facades', "ord2bli_dset"]
base_url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'

class CycleGANDataset(Dataset):
    """
    Class used to read images in
    """
    def __init__(self, data_root=f"{__file__}/../data", dataset_name='facades', transform=None, unaligned=True, mode='train'):

        # Check whether the specified dataset is available
        if dataset_name not in available_datasets:
            print(f"{dataset_name}という名前のデータセットが存在しないようです")
            sys.exit(1)
        # Check whether the dataset is downloaded
        base_dir = os.path.abspath(data_root)
        dataset_dir = os.path.join(base_dir, dataset_name)
        if not os.path.exists(dataset_dir):
            if not os.path.exists(base_dir):
                os.makedirs(base_dir, True)
            print(f"{base_url}からデータセットをダウンロードし{base_dir}に解凍してください")
            sys.exit(1)

        self.transform = transform
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(dataset_dir, f"{mode}/A/*.*")))
        self.files_B = sorted(glob.glob(os.path.join(dataset_dir, f"{mode}/B/*.*")))

    def __getitem__(self, index):

        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))

        if self.unaligned:
            item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B)-1)]))
        else:
            item_B = self.transform(Image.open(self.files_A[index % len(self.files_B)]))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))
