You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
898 B
Python
28 lines
898 B
Python
import os
|
|
from torch.utils.data import Dataset
|
|
from PIL import Image
|
|
|
|
|
|
class SuperviselyPersonDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    The entries of ``imgdir`` and ``segdir`` are each sorted by filename and
    paired by position: the i-th image is matched with the i-th mask. Both
    directories must therefore hold the same number of entries, and their
    sorted orders must correspond (only the counts are checked here).

    Args:
        imgdir: Directory containing the input images.
        segdir: Directory containing the segmentation masks.
        transform: Optional callable invoked as ``transform(img, seg)`` that
            must return a new ``(img, seg)`` pair. It receives both items
            together, presumably so paired augmentations stay in sync.

    Raises:
        ValueError: If the two directories contain different numbers of
            entries.
    """

    def __init__(self, imgdir, segdir, transform=None):
        self.img_dir = imgdir
        self.img_files = sorted(os.listdir(imgdir))
        self.seg_dir = segdir
        self.seg_files = sorted(os.listdir(segdir))
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently yield misaligned pairs.
        if len(self.img_files) != len(self.seg_files):
            raise ValueError(
                f"image/mask count mismatch: {len(self.img_files)} entries in "
                f"{imgdir!r} vs {len(self.seg_files)} entries in {segdir!r}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th pair as ``(image, mask)``.

        The image is converted to ``'RGB'`` and the mask to ``'L'`` mode.
        ``Image.convert`` returns a new in-memory copy, so both source files
        are closed before the optional joint transform runs.
        """
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        seg_path = os.path.join(self.seg_dir, self.seg_files[idx])
        with Image.open(img_path) as img_file, Image.open(seg_path) as seg_file:
            img = img_file.convert('RGB')
            seg = seg_file.convert('L')

        if self.transform is not None:
            img, seg = self.transform(img, seg)

        return img, seg