|
|
@ -12,6 +12,7 @@ python inference.py \
|
|
|
|
--seq-chunk 1
|
|
|
|
--seq-chunk 1
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import av
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
@ -20,6 +21,8 @@ from typing import Optional, Tuple
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
|
|
|
|
from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
|
|
|
|
|
|
|
|
from inference_utils import AudioVideoWriter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_video(model,
|
|
|
|
def convert_video(model,
|
|
|
|
input_source: str,
|
|
|
|
input_source: str,
|
|
|
@ -33,6 +36,7 @@ def convert_video(model,
|
|
|
|
seq_chunk: int = 1,
|
|
|
|
seq_chunk: int = 1,
|
|
|
|
num_workers: int = 0,
|
|
|
|
num_workers: int = 0,
|
|
|
|
progress: bool = True,
|
|
|
|
progress: bool = True,
|
|
|
|
|
|
|
|
passthrough_audio: bool = True,
|
|
|
|
device: Optional[str] = None,
|
|
|
|
device: Optional[str] = None,
|
|
|
|
dtype: Optional[torch.dtype] = None):
|
|
|
|
dtype: Optional[torch.dtype] = None):
|
|
|
|
|
|
|
|
|
|
|
@ -51,6 +55,7 @@ def convert_video(model,
|
|
|
|
seq_chunk: Number of frames to process at once. Increase it for better parallelism.
|
|
|
|
seq_chunk: Number of frames to process at once. Increase it for better parallelism.
|
|
|
|
num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
|
|
|
|
num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
|
|
|
|
progress: Show progress bar.
|
|
|
|
progress: Show progress bar.
|
|
|
|
|
|
|
|
passthrough_audio: Should we passthrough any audio from the input video
|
|
|
|
device: Only need to manually provide if model is a TorchScript freezed model.
|
|
|
|
device: Only need to manually provide if model is a TorchScript freezed model.
|
|
|
|
dtype: Only need to manually provide if model is a TorchScript freezed model.
|
|
|
|
dtype: Only need to manually provide if model is a TorchScript freezed model.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -77,25 +82,51 @@ def convert_video(model,
|
|
|
|
source = ImageSequenceReader(input_source, transform)
|
|
|
|
source = ImageSequenceReader(input_source, transform)
|
|
|
|
reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)
|
|
|
|
reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audio_source = None
|
|
|
|
|
|
|
|
if os.path.isfile(input_source):
|
|
|
|
|
|
|
|
container = av.open(input_source)
|
|
|
|
|
|
|
|
if container.streams.get(audio=0):
|
|
|
|
|
|
|
|
audio_source = container.streams.get(audio=0)[0]
|
|
|
|
|
|
|
|
|
|
|
|
# Initialize writers
|
|
|
|
# Initialize writers
|
|
|
|
if output_type == 'video':
|
|
|
|
if output_type == 'video':
|
|
|
|
frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
|
|
|
|
frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
|
|
|
|
output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
|
|
|
|
output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
|
|
|
|
if output_composition is not None:
|
|
|
|
if passthrough_audio and audio_source:
|
|
|
|
writer_com = VideoWriter(
|
|
|
|
if output_composition is not None:
|
|
|
|
path=output_composition,
|
|
|
|
writer_com = AudioVideoWriter(
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
path=output_composition,
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
if output_alpha is not None:
|
|
|
|
audio_stream=audio_source,
|
|
|
|
writer_pha = VideoWriter(
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
path=output_alpha,
|
|
|
|
if output_alpha is not None:
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
writer_pha = AudioVideoWriter(
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
path=output_alpha,
|
|
|
|
if output_foreground is not None:
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
writer_fgr = VideoWriter(
|
|
|
|
audio_stream=audio_source,
|
|
|
|
path=output_foreground,
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
if output_foreground is not None:
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
writer_fgr = AudioVideoWriter(
|
|
|
|
|
|
|
|
path=output_foreground,
|
|
|
|
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
|
|
|
|
audio_stream=audio_source,
|
|
|
|
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
if output_composition is not None:
|
|
|
|
|
|
|
|
writer_com = VideoWriter(
|
|
|
|
|
|
|
|
path=output_composition,
|
|
|
|
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
|
|
|
|
if output_alpha is not None:
|
|
|
|
|
|
|
|
writer_pha = VideoWriter(
|
|
|
|
|
|
|
|
path=output_alpha,
|
|
|
|
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
|
|
|
|
if output_foreground is not None:
|
|
|
|
|
|
|
|
writer_fgr = VideoWriter(
|
|
|
|
|
|
|
|
path=output_foreground,
|
|
|
|
|
|
|
|
frame_rate=frame_rate,
|
|
|
|
|
|
|
|
bit_rate=int(output_video_mbps * 1000000))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if output_composition is not None:
|
|
|
|
if output_composition is not None:
|
|
|
|
writer_com = ImageSequenceWriter(output_composition, 'png')
|
|
|
|
writer_com = ImageSequenceWriter(output_composition, 'png')
|
|
|
@ -168,6 +199,7 @@ class Converter:
|
|
|
|
def convert(self, *args, **kwargs):
|
|
|
|
def convert(self, *args, **kwargs):
|
|
|
|
convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs)
|
|
|
|
convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
from model import MattingNetwork
|
|
|
|
from model import MattingNetwork
|
|
|
@ -204,4 +236,3 @@ if __name__ == '__main__':
|
|
|
|
progress=not args.disable_progress
|
|
|
|
progress=not args.disable_progress
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|