Support writing the audio stream back into video

When output_type == video, we can support adding back the audio stream
from the original video input.
pull/83/head
Alex Hughes 3 years ago
parent a3424e4205
commit 36c574a1b7
No known key found for this signature in database
GPG Key ID: B76F431FA62AF5AB

@ -12,6 +12,7 @@ python inference.py \
--seq-chunk 1
"""
import av
import torch
import os
from torch.utils.data import DataLoader
@ -20,6 +21,8 @@ from typing import Optional, Tuple
from tqdm.auto import tqdm
from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
from inference_utils import AudioVideoWriter
def convert_video(model,
input_source: str,
@ -33,6 +36,7 @@ def convert_video(model,
seq_chunk: int = 1,
num_workers: int = 0,
progress: bool = True,
passthrough_audio: bool = True,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None):
@ -51,10 +55,11 @@ def convert_video(model,
seq_chunk: Number of frames to process at once. Increase it for better parallelism.
num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
progress: Show progress bar.
passthrough_audio: Should we passthrough any audio from the input video
device: Only need to manually provide if model is a TorchScript freezed model.
dtype: Only need to manually provide if model is a TorchScript freezed model.
"""
assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
@ -76,26 +81,52 @@ def convert_video(model,
else:
source = ImageSequenceReader(input_source, transform)
reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)
audio_source = None
if os.path.isfile(input_source):
container = av.open(input_source)
if container.streams.get(audio=0):
audio_source = container.streams.get(audio=0)[0]
# Initialize writers
if output_type == 'video':
frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
if output_composition is not None:
writer_com = VideoWriter(
path=output_composition,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_alpha is not None:
writer_pha = VideoWriter(
path=output_alpha,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_foreground is not None:
writer_fgr = VideoWriter(
path=output_foreground,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if passthrough_audio and audio_source:
if output_composition is not None:
writer_com = AudioVideoWriter(
path=output_composition,
frame_rate=frame_rate,
audio_stream=audio_source,
bit_rate=int(output_video_mbps * 1000000))
if output_alpha is not None:
writer_pha = AudioVideoWriter(
path=output_alpha,
frame_rate=frame_rate,
audio_stream=audio_source,
bit_rate=int(output_video_mbps * 1000000))
if output_foreground is not None:
writer_fgr = AudioVideoWriter(
path=output_foreground,
frame_rate=frame_rate,
audio_stream=audio_source,
bit_rate=int(output_video_mbps * 1000000))
else:
if output_composition is not None:
writer_com = VideoWriter(
path=output_composition,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_alpha is not None:
writer_pha = VideoWriter(
path=output_alpha,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_foreground is not None:
writer_fgr = VideoWriter(
path=output_foreground,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
else:
if output_composition is not None:
writer_com = ImageSequenceWriter(output_composition, 'png')
@ -113,7 +144,7 @@ def convert_video(model,
if (output_composition is not None) and (output_type == 'video'):
bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
try:
with torch.no_grad():
bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
@ -137,7 +168,7 @@ def convert_video(model,
fgr = fgr * pha.gt(0)
com = torch.cat([fgr, pha], dim=-3)
writer_com.write(com[0])
bar.update(src.size(1))
finally:
@ -167,11 +198,12 @@ class Converter:
def convert(self, *args, **kwargs):
convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs)
if __name__ == '__main__':
import argparse
from model import MattingNetwork
parser = argparse.ArgumentParser()
parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
parser.add_argument('--checkpoint', type=str, required=True)
@ -188,7 +220,7 @@ if __name__ == '__main__':
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--disable-progress', action='store_true')
args = parser.parse_args()
converter = Converter(args.variant, args.checkpoint, args.device)
converter.convert(
input_source=args.input_source,
@ -203,5 +235,4 @@ if __name__ == '__main__':
num_workers=args.num_workers,
progress=not args.disable_progress
)

@ -12,14 +12,14 @@ class VideoReader(Dataset):
self.video = pims.PyAVVideoReader(path)
self.rate = self.video.frame_rate
self.transform = transform
@property
def frame_rate(self):
return self.rate
def __len__(self):
return len(self.video)
def __getitem__(self, idx):
frame = self.video[idx]
frame = Image.fromarray(np.asarray(frame))
@ -57,10 +57,10 @@ class ImageSequenceReader(Dataset):
self.path = path
self.files = sorted(os.listdir(path))
self.transform = transform
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
with Image.open(os.path.join(self.path, self.files[idx])) as img:
img.load()
@ -75,14 +75,40 @@ class ImageSequenceWriter:
self.extension = extension
self.counter = 0
os.makedirs(path, exist_ok=True)
def write(self, frames):
# frames: [T, C, H, W]
for t in range(frames.shape[0]):
to_pil_image(frames[t]).save(os.path.join(
self.path, str(self.counter).zfill(4) + '.' + self.extension))
self.counter += 1
def close(self):
pass
class AudioVideoWriter(VideoWriter):
def __init__(self, path, frame_rate, audio_stream=None, bit_rate=1000000):
super(AudioVideoWriter, self).__init__(
path=path,
frame_rate=frame_rate,
bit_rate=bit_rate
)
self.source_audio_stream = audio_stream
self.output_audio_stream = self.container.add_stream(
codec_name=self.source_audio_stream.codec_context.codec.name,
rate=self.source_audio_stream.rate,
)
def remux_audio(self):
input_audio_container = self.source_audio_stream.container
for packet in input_audio_container.demux(self.source_audio_stream):
if packet.dts is None:
continue
packet.stream = self.output_audio_stream
self.container.mux(packet)
def close(self):
self.remux_audio()
self.container.mux(self.output_audio_stream.encode())
super(AudioVideoWriter, self).close()

Loading…
Cancel
Save