import os
import random

from PIL import Image
from torch.utils.data import Dataset

from .augmentation import MotionAugmentation

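# Expected directory layout, inferred from the os.listdir calls below
# (the top-level names are placeholders; actual dataset paths may differ):
#
#   videomatte_dir/
#       fgr/<clip>/<frame images>           # RGB foregrounds
#       pha/<clip>/<frame images>           # single-channel alpha mattes
#   background_image_dir/<image files>
#   background_video_dir/<clip>/<frame images>
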
class VideoMatteDataset(Dataset):
    """Composites VideoMatte foreground/alpha clips onto random image or video backgrounds."""

    def __init__(self,
                 videomatte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform=None):
        # Static image backgrounds.
        self.background_image_dir = background_image_dir
        self.background_image_files = os.listdir(background_image_dir)

        # Video backgrounds: one subdirectory of frames per clip.
        self.background_video_dir = background_video_dir
        self.background_video_clips = sorted(os.listdir(background_video_dir))
        self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
                                        for clip in self.background_video_clips]

        # Foreground/alpha clips: matching frame lists under 'fgr' and 'pha'.
        self.videomatte_dir = videomatte_dir
        self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr')))
        self.videomatte_frames = [sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip)))
                                  for clip in self.videomatte_clips]
        # One dataset item per (clip, start frame) pair; stepping by seq_length
        # splits each clip into non-overlapping sequences.
        self.videomatte_idx = [(clip_idx, frame_idx)
                               for clip_idx in range(len(self.videomatte_clips))
                               for frame_idx in range(0, len(self.videomatte_frames[clip_idx]), seq_length)]
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

    def __len__(self):
        return len(self.videomatte_idx)

    def __getitem__(self, idx):
        # Composite onto a static image or a video background with equal probability.
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()

        fgrs, phas = self._get_videomatte(idx)

        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    def _get_random_image_background(self):
        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        # Repeat the single image to form a static background sequence.
        bgrs = [bgr] * self.seq_length
        return bgrs

    def _get_random_video_background(self):
        clip_idx = random.choice(range(len(self.background_video_clips)))
        frame_count = len(self.background_video_frames[clip_idx])
        # Random start frame; max(1, ...) keeps the range non-empty for short clips.
        frame_idx = random.choice(range(max(1, frame_count - self.seq_length)))
        clip = self.background_video_clips[clip_idx]
        bgrs = []
        for i in self.seq_sampler(self.seq_length):
            frame_idx_t = frame_idx + i
            # Wrap with modulo so sampled offsets never index past the clip.
            frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs

    def _get_videomatte(self, idx):
        clip_idx, frame_idx = self.videomatte_idx[idx]
        clip = self.videomatte_clips[clip_idx]
        frame_count = len(self.videomatte_frames[clip_idx])
        fgrs, phas = [], []
        for i in self.seq_sampler(self.seq_length):
            frame = self.videomatte_frames[clip_idx][(frame_idx + i) % frame_count]
            # Foreground is RGB; the alpha matte is a single-channel ('L') image.
            with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as fgr, \
                 Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as pha:
                fgr = self._downsample_if_needed(fgr.convert('RGB'))
                pha = self._downsample_if_needed(pha.convert('L'))
            fgrs.append(fgr)
            phas.append(pha)
        return fgrs, phas

    def _downsample_if_needed(self, img):
        # Shrink so the shorter side equals self.size, preserving aspect ratio;
        # images already at or below the target size are left untouched.
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h))
        return img


class VideoMatteTrainAugmentation(MotionAugmentation):
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0.3,
            prob_bgr_affine=0.3,
            prob_noise=0.1,
            prob_color_jitter=0.3,
            prob_grayscale=0.02,
            prob_sharpness=0.1,
            prob_blur=0.02,
            prob_hflip=0.5,
            prob_pause=0.03,
        )


class VideoMatteValidAugmentation(MotionAugmentation):
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0,
            prob_bgr_affine=0,
            prob_noise=0,
            prob_color_jitter=0,
            prob_grayscale=0,
            prob_sharpness=0,
            prob_blur=0,
            prob_hflip=0,
            prob_pause=0,
        )
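

# --- Usage sketch (illustrative; not part of the training pipeline) ---
# `consecutive_sampler` is a hypothetical stand-in for the project's frame
# sampler: the only contract the dataset relies on is that
# seq_sampler(seq_length) yields seq_length frame offsets. Paths are placeholders.
if __name__ == '__main__':
    def consecutive_sampler(seq_length):
        return range(seq_length)  # consecutive frames; a real sampler may skip or repeat

    dataset = VideoMatteDataset(
        videomatte_dir='VideoMatte240K',            # hypothetical paths
        background_image_dir='Backgrounds/image',
        background_video_dir='Backgrounds/video',
        size=512,
        seq_length=15,
        seq_sampler=consecutive_sampler,
        transform=None,                             # or VideoMatteTrainAugmentation(512)
    )
    fgrs, phas, bgrs = dataset[0]                   # three lists of seq_length PIL images
    print(len(dataset), len(fgrs), len(phas), len(bgrs))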