import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
class WasteNetDataset(Dataset):
    """Dataset yielding three images concatenated into one tensor plus a label.

    Each row of ``datasets_table`` is indexed as
    ``[path1, path2, path3, label]``: the first three entries are paths to
    the 1st, 2nd and 3rd image, and the fourth says whether this triple is
    accepted as a dataset sample.
    """

    def __init__(self, datasets_table, img_width, img_height, transform=None, return_path=False):
        """Store the configuration; no image is opened here.

        Args:
            datasets_table: Indexable table of rows
                ``[path1, path2, path3, label]``; must support ``len()``
                and integer indexing.
            img_width: Width every image is resized to.
            img_height: Height every image is resized to.
            transform: Optional callable applied to each resized RGB PIL
                image. When falsy, ``transforms.ToTensor()`` is used.
            return_path: If true, ``__getitem__`` also returns the three
                image paths after the label.
        """
        super().__init__()
        self.transform = transform
        self.img_width = img_width
        self.img_height = img_height
        # Table layout: column 1 = path to the 1st image, column 2 = path to
        # the 2nd image, column 3 = path to the 3rd image, column 4 = whether
        # these three images are accepted as a dataset sample.
        self.datasets_table = datasets_table
        self.return_path = return_path

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.datasets_table)

    def _load_image(self, img_path):
        """Open one image and return it as an RGB image of the target size."""
        # The context manager closes the underlying file deterministically
        # instead of relying on Pillow's lazy loader / garbage collection.
        with Image.open(img_path) as raw:
            # Convert BEFORE resizing: Pillow forces NEAREST resampling for
            # mode "1" and "P" images, so resizing first would silently
            # ignore the requested BILINEAR filter for palette inputs.
            rgb = raw.convert("RGB")
        return rgb.resize((self.img_width, self.img_height), Image.BILINEAR)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Returns:
            ``(imgs, label)``, or ``(imgs, label, path1, path2, path3)``
            when ``return_path`` is set. ``imgs`` is the three per-image
            tensors concatenated along dim 0 (with the default ``ToTensor``
            this is expected to be ``(9, img_height, img_width)``); ``label``
            is ``1`` if the row's fourth entry is truthy, else ``0``.
        """
        target_row = self.datasets_table[idx]  # [path1, path2, path3, label]
        img_paths = target_row[:3]

        # Build the fallback converter once per sample, not once per image.
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        tensors = [to_tensor(self._load_image(path)) for path in img_paths]
        imgs = torch.cat(tensors, dim=0)

        label = 1 if target_row[3] else 0
        if self.return_path:
            # Parenthesised so the starred return also parses on Python < 3.8.
            return (imgs, label, *img_paths)
        return imgs, label