Newer
Older
WasteNet / src / datasets.py
@sato sato on 1 Mar 2022 1 KB 最初のコミット
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms


class WasteNetDataset(Dataset):
    """Dataset yielding three images stacked along dim 0 plus a binary label.

    Each row of ``datasets_table`` is ``[path1, path2, path3, label]``: the
    file paths of three images and a flag saying whether these three images
    are accepted as a sample.
    """

    def __init__(self, datasets_table, img_width, img_height, transform=None, return_path=False):
        """Store the table and preprocessing settings.

        Args:
            datasets_table: Indexable, sized sequence of rows
                ``[path1, path2, path3, label]``.
            img_width: Width in pixels every image is resized to.
            img_height: Height in pixels every image is resized to.
            transform: Optional callable applied to each resized RGB PIL
                image. When ``None``, ``transforms.ToTensor()`` is used.
            return_path: If True, ``__getitem__`` also returns the three
                image paths.
        """
        self.transform = transform
        self.img_width = img_width
        self.img_height = img_height
        # Columns: 1st image path, 2nd image path, 3rd image path,
        #          and whether these three images are accepted as a sample.
        self.datasets_table = datasets_table
        self.return_path = return_path

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.datasets_table)

    def _load_image(self, img_path):
        """Open *img_path*, resize it, convert it to RGB, and close the file."""
        # The context manager releases the file handle. ``resize``/``convert``
        # return new, fully loaded images, so the result outlives the file.
        with Image.open(img_path) as img:
            return img.resize((self.img_width, self.img_height), Image.BILINEAR).convert("RGB")

    def __getitem__(self, idx):
        """Return the stacked images and label for row *idx*.

        Returns:
            ``(imgs, label)``, or ``(imgs, label, path1, path2, path3)`` when
            ``return_path`` is True. ``imgs`` concatenates the three per-image
            outputs along dim 0 (with the default ``ToTensor`` each is
            (3, H, W), giving (9, H, W); for a custom ``transform`` this
            assumes it returns tensors — TODO confirm). ``label`` is 1 if the
            row's 4th column is truthy, else 0.
        """
        target_row = self.datasets_table[idx]  # [path1, path2, path3, label]
        img_paths = target_row[:3]

        # Pick the converter once per call instead of constructing a new
        # ToTensor for every image.
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        tensors = [to_tensor(self._load_image(path)) for path in img_paths]
        imgs = torch.cat(tensors, dim=0)

        # NOTE(review): truthiness test — a string label such as "False" read
        # from a CSV would map to 1; confirm the table stores bools/ints.
        label = 1 if target_row[3] else 0

        if self.return_path:
            return imgs, label, *img_paths
        return imgs, label