RobustVideoMatting/dataset/youtubevis.py

import torch
import os
import json
import numpy as np
import random
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F


class YouTubeVISDataset(Dataset):
    """Samples short clips of frames and person masks from YouTubeVIS."""

    def __init__(self, videodir, annfile, size, seq_length, seq_sampler, transform=None):
        self.videodir = videodir
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

        with open(annfile) as f:
            data = json.load(f)

        # Collect per-frame person masks (YouTubeVIS category 26), keyed by video id.
        self.masks = {}
        for ann in data['annotations']:
            if ann['category_id'] == 26:  # person
                video_id = ann['video_id']
                if video_id not in self.masks:
                    self.masks[video_id] = [[] for _ in range(len(ann['segmentations']))]
                for frame, mask in zip(self.masks[video_id], ann['segmentations']):
                    if mask is not None:
                        frame.append(mask)

        # Keep only videos that contain at least one person annotation.
        self.videos = {}
        for video in data['videos']:
            video_id = video['id']
            if video_id in self.masks:
                self.videos[video_id] = video

        # Index every (video, starting frame) pair, e.g. [(vid, 0), (vid, 1), ...].
        self.index = []
        for video_id in self.videos.keys():
            for frame in range(len(self.videos[video_id]['file_names'])):
                self.index.append((video_id, frame))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        video_id, frame_id = self.index[idx]
        video = self.videos[video_id]
        frame_count = len(video['file_names'])
        H, W = video['height'], video['width']

        imgs, segs = [], []
        for t in self.seq_sampler(self.seq_length):
            # Wrap around the end of the video so every start frame yields a full clip.
            frame = (frame_id + t) % frame_count
            filename = video['file_names'][frame]
            masks = self.masks[video_id][frame]

            with Image.open(os.path.join(self.videodir, filename)) as img:
                imgs.append(self._downsample_if_needed(img.convert('RGB'), Image.BILINEAR))

            # Union all person instances into a single binary segmentation.
            seg = np.zeros((H, W), dtype=np.uint8)
            for mask in masks:
                seg |= self._decode_rle(mask)
            segs.append(self._downsample_if_needed(Image.fromarray(seg), Image.NEAREST))

        if self.transform is not None:
            imgs, segs = self.transform(imgs, segs)

        return imgs, segs

    def _decode_rle(self, rle):
        # Decode an uncompressed COCO-style RLE: counts alternate between runs
        # of background pixels to skip and foreground pixels to draw, laid out
        # in column-major order.
        H, W = rle['size']
        msk = np.zeros(H * W, dtype=np.uint8)
        encoding = rle['counts']
        skip = 0
        for i in range(0, len(encoding) - 1, 2):
            skip += encoding[i]
            draw = encoding[i + 1]
            msk[skip : skip + draw] = 255
            skip += draw
        return msk.reshape(W, H).transpose()
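
    # Worked example (illustrative values, not from the dataset): for
    # rle = {'size': [2, 3], 'counts': [1, 2, 3]} the loop skips one pixel and
    # draws two, giving the flat column-major mask [0, 255, 255, 0, 0, 0],
    # which reshapes and transposes to the 2x3 array [[0, 255, 0], [255, 0, 0]].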

    def _downsample_if_needed(self, img, resample):
        # Scale so the short side does not exceed self.size, e.g. with
        # size=512 a 1920x1080 frame becomes 910x512; smaller frames pass
        # through unchanged.
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h), resample)
        return img


class YouTubeVISAugmentation:
    def __init__(self, size):
        self.size = size
        self.jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.15)

    def __call__(self, imgs, segs):
        # To tensor
        imgs = torch.stack([F.to_tensor(img) for img in imgs])
        segs = torch.stack([F.to_tensor(seg) for seg in segs])

        # Random resized crop, with identical parameters for frames and masks
        params = transforms.RandomResizedCrop.get_params(imgs, scale=(0.8, 1), ratio=(0.9, 1.1))
        imgs = F.resized_crop(imgs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        segs = F.resized_crop(segs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Color jitter
        imgs = self.jitter(imgs)

        # Grayscale
        if random.random() < 0.05:
            imgs = F.rgb_to_grayscale(imgs, num_output_channels=3)

        # Horizontal flip
        if random.random() < 0.5:
            imgs = F.hflip(imgs)
            segs = F.hflip(segs)

        return imgs, segs
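

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file. The dataset paths
    # and the sequential seq_sampler below are illustrative assumptions;
    # RVM's training script wires these up itself.
    dataset = YouTubeVISDataset(
        videodir='YouTubeVIS/train/JPEGImages',     # assumed data layout
        annfile='YouTubeVIS/train/instances.json',  # assumed data layout
        size=512,
        seq_length=15,
        seq_sampler=lambda n: range(n),  # n consecutive frame offsets
        transform=YouTubeVISAugmentation(512),
    )
    imgs, segs = dataset[0]
    # After the transform: tensors of shape (seq_length, 3, H, W) and
    # (seq_length, 1, H, W).
    print(imgs.shape, segs.shape)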