You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
RobustVideoMatting/inference_utils.py

89 lines
2.6 KiB
Python

import av
import os
import pims
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from PIL import Image
class VideoReader(Dataset):
    """Random-access dataset over the frames of a video file.

    Frames are decoded through ``pims.PyAVVideoReader`` and handed back as
    PIL images, optionally passed through ``transform`` first.

    Args:
        path: Path of the video file to open.
        transform: Optional callable applied to each PIL frame.
    """

    def __init__(self, path, transform=None):
        self.transform = transform
        self.video = pims.PyAVVideoReader(path)
        # Cache the rate once so the property does not hit the reader again.
        self.rate = self.video.frame_rate

    @property
    def frame_rate(self):
        """Frame rate reported by the underlying pims reader at open time."""
        return self.rate

    def __len__(self):
        """Number of frames reported by the underlying reader."""
        return len(self.video)

    def __getitem__(self, idx):
        """Decode frame ``idx`` into a PIL image, applying ``transform`` if set."""
        # np.asarray strips the pims frame wrapper down to a plain ndarray.
        image = Image.fromarray(np.asarray(self.video[idx]))
        return image if self.transform is None else self.transform(image)


class VideoWriter:
    """Encode batches of image tensors into an H.264 video file via PyAV.

    Frames are converted to 8-bit RGB and encoded with the ``yuv420p`` pixel
    format. Call :meth:`close` (or use the writer in a ``with`` statement) to
    flush the encoder and finalize the container.

    Args:
        path: Output file path handed to ``av.open``.
        frame_rate: Frames per second; formatted to 4 decimal places before
            being passed to ``add_stream``.
        bit_rate: Target bit rate assigned to the stream.
    """

    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.container = av.open(path, mode='w')
        self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}')
        self.stream.pix_fmt = 'yuv420p'
        self.stream.bit_rate = bit_rate

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @staticmethod
    def _to_rgb24(frames):
        """Convert a ``[T, C, H, W]`` tensor into a ``[T, H, W, 3]`` uint8 array.

        Single-channel input is replicated to three channels. The ``* 255``
        scaling implies float input in ``[0, 1]``. Values are clamped to that
        range first so that out-of-range inputs saturate. Without the clamp,
        PyTorch's float->uint8 cast of out-of-range values is undefined and
        can wrap around. In-range values are truncated exactly as before.
        """
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
        return frames.clamp(0, 1).mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()

    def write(self, frames):
        """Encode a batch of frames and mux the resulting packets.

        Args:
            frames: Tensor shaped ``[T, C, H, W]`` with ``C`` equal to 1 or 3,
                values expected in ``[0, 1]``.
        """
        # Stream dimensions are taken from the incoming batch on every call.
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        for frame in self._to_rgb24(frames):
            video_frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
            self.container.mux(self.stream.encode(video_frame))

    def close(self):
        """Flush frames still buffered in the encoder, then close the container."""
        # encode() with no frame drains the encoder's remaining packets.
        self.container.mux(self.stream.encode())
        self.container.close()


class ImageSequenceReader(Dataset):
    """Dataset over the image files found in a single directory.

    Every entry returned by ``os.listdir(path)`` is treated as a frame, in
    lexicographic filename order.

    Args:
        path: Directory containing the frames.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, path, transform=None):
        self.transform = transform
        self.path = path
        self.files = sorted(os.listdir(path))

    def __len__(self):
        """Number of directory entries captured at construction time."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` and apply ``transform`` if one was given."""
        file_path = os.path.join(self.path, self.files[idx])
        with Image.open(file_path) as image:
            # Read the pixel data now, while the file is still open.
            image.load()
            if self.transform is None:
                return image
            return self.transform(image)


class ImageSequenceWriter:
    """Write every frame of a tensor batch to its own numbered image file.

    Files are named with a running counter zero-padded to four digits
    (``0000.jpg``, ``0001.jpg``, ...). The counter keeps increasing across
    ``write`` calls. NOTE: counters of 10000 and above are not padded
    further, so those names no longer sort lexicographically in numeric
    order.

    Args:
        path: Output directory; created (with parents) if missing.
        extension: File extension appended to each name; PIL infers the
            output format from it.
    """

    def __init__(self, path, extension='jpg'):
        os.makedirs(path, exist_ok=True)
        self.path = path
        self.extension = extension
        self.counter = 0

    def write(self, frames):
        """Save each frame of a ``[T, C, H, W]`` batch as a separate image."""
        for frame in frames:
            filename = f'{self.counter:04d}.{self.extension}'
            to_pil_image(frame).save(os.path.join(self.path, filename))
            self.counter += 1

    def close(self):
        """No-op: this writer keeps no open resources between calls."""