mirror of https://github.com/menyifang/DCT-Net
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.0 KiB
Python
41 lines
1.0 KiB
Python
from io import BytesIO
|
|
|
|
import lmdb
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class MultiResolutionDataset(Dataset):
    """Read-only image dataset stored in an LMDB database.

    The database is expected to contain a ``b'length'`` entry holding the
    number of samples as a UTF-8 decimal string, and one entry per sample
    keyed ``f'{resolution}-{index:05}'`` (UTF-8 encoded) holding encoded
    image bytes that PIL can open.

    Args:
        path: Filesystem path of the LMDB environment.
        transform: Callable applied to each decoded ``PIL.Image``; its return
            value is what ``__getitem__`` returns.
        resolution: Resolution prefix used when building per-sample keys.
    """

    def __init__(self, path, transform, resolution=256):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        # Defensive check kept from the original implementation.
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
            if raw_length is None:
                # Without this check a missing entry surfaced as an
                # unhelpful AttributeError on ``None.decode``.
                raise IOError(f'lmdb dataset at {path!r} has no "length" entry')
            self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        """Return the number of samples recorded in the ``length`` entry."""
        return self.length

    def _normalize_index(self, index):
        """Map *index* (negative allowed) into ``[0, length)`` or raise IndexError."""
        normalized = index + self.length if index < 0 else index
        if not 0 <= normalized < self.length:
            # IndexError is what Python's sequence protocol (e.g. a plain
            # ``for`` loop over the dataset) expects at the end of the data.
            raise IndexError(
                f'dataset index {index} out of range (length {self.length})'
            )
        return normalized

    def _make_key(self, index):
        """Build the LMDB key for a non-negative sample *index*."""
        return f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')

    def __getitem__(self, index):
        """Load, decode and transform the sample at *index*.

        Raises:
            IndexError: If *index* is outside ``[-len(self), len(self))``.
            KeyError: If the computed key is absent from the database.
        """
        key = self._make_key(self._normalize_index(index))

        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        if img_bytes is None:
            # Previously this fell through to BytesIO(None) and an opaque
            # PIL decode error.
            raise KeyError(f'key {key!r} not found in lmdb dataset')

        img = Image.open(BytesIO(img_bytes))
        return self.transform(img)