# RobustVideoMatting/train.py
"""
# First update `train_config.py` to set paths to your dataset locations.
# You may want to change `--num-workers` according to your machine's memory.
# The default num-workers=8 may cause dataloader to exit unexpectedly when
# machine is out of memory.
# Stage 1
python train.py \
--model-variant mobilenetv3 \
--dataset videomatte \
--resolution-lr 512 \
--seq-length-lr 15 \
--learning-rate-backbone 0.0001 \
--learning-rate-aspp 0.0002 \
--learning-rate-decoder 0.0002 \
--learning-rate-refiner 0 \
--checkpoint-dir checkpoint/stage1 \
--log-dir log/stage1 \
--epoch-start 0 \
--epoch-end 20
# Stage 2
python train.py \
--model-variant mobilenetv3 \
--dataset videomatte \
--resolution-lr 512 \
--seq-length-lr 50 \
--learning-rate-backbone 0.00005 \
--learning-rate-aspp 0.0001 \
--learning-rate-decoder 0.0001 \
--learning-rate-refiner 0 \
--checkpoint checkpoint/stage1/epoch-19.pth \
--checkpoint-dir checkpoint/stage2 \
--log-dir log/stage2 \
--epoch-start 20 \
--epoch-end 22
# Stage 3
python train.py \
--model-variant mobilenetv3 \
--dataset videomatte \
--train-hr \
--resolution-lr 512 \
--resolution-hr 2048 \
--seq-length-lr 40 \
--seq-length-hr 6 \
--learning-rate-backbone 0.00001 \
--learning-rate-aspp 0.00001 \
--learning-rate-decoder 0.00001 \
--learning-rate-refiner 0.0002 \
--checkpoint checkpoint/stage2/epoch-21.pth \
--checkpoint-dir checkpoint/stage3 \
--log-dir log/stage3 \
--epoch-start 22 \
--epoch-end 23
# Stage 4
python train.py \
--model-variant mobilenetv3 \
--dataset imagematte \
--train-hr \
--resolution-lr 512 \
--resolution-hr 2048 \
--seq-length-lr 40 \
--seq-length-hr 6 \
--learning-rate-backbone 0.00001 \
--learning-rate-aspp 0.00001 \
--learning-rate-decoder 0.00005 \
--learning-rate-refiner 0.0002 \
--checkpoint checkpoint/stage3/epoch-22.pth \
--checkpoint-dir checkpoint/stage4 \
--log-dir log/stage4 \
--epoch-start 23 \
--epoch-end 28
"""
import argparse
import torch
import random
import os
from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms.functional import center_crop
from tqdm import tqdm
from dataset.videomatte import (
VideoMatteDataset,
VideoMatteTrainAugmentation,
VideoMatteValidAugmentation,
)
from dataset.imagematte import (
ImageMatteDataset,
ImageMatteAugmentation
)
from dataset.coco import (
CocoPanopticDataset,
CocoPanopticTrainAugmentation,
)
from dataset.spd import (
SuperviselyPersonDataset
)
from dataset.youtubevis import (
YouTubeVISDataset,
YouTubeVISAugmentation
)
from dataset.augmentation import (
TrainFrameSampler,
ValidFrameSampler
)
from model import MattingNetwork
from train_config import DATA_PATHS
from train_loss import matting_loss, segmentation_loss
class Trainer:
    """Per-process trainer for the matting network.

    One instance is created in every worker process (see the ``mp.spawn``
    call at the bottom of this file, which passes the process index as
    ``rank``).  Constructing the object runs the whole job: argument
    parsing, process-group setup, dataset/model/writer creation, the
    training loop and finally process-group teardown.
    """

    def __init__(self, rank, world_size):
        """Run the complete training job for this process.

        Args:
            rank: Index of this process; also used as the CUDA device index.
            world_size: Total number of processes taking part in training.
        """
        self.parse_args()
        self.init_distributed(rank, world_size)
        self.init_datasets()
        self.init_model()
        self.init_writer()
        self.train()
        self.cleanup()

    def parse_args(self):
        """Parse the command line into ``self.args``."""
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        # Matting dataset
        parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte'])
        # Learning rate
        parser.add_argument('--learning-rate-backbone', type=float, required=True)
        parser.add_argument('--learning-rate-aspp', type=float, required=True)
        parser.add_argument('--learning-rate-decoder', type=float, required=True)
        parser.add_argument('--learning-rate-refiner', type=float, required=True)
        # Training setting
        parser.add_argument('--train-hr', action='store_true')
        parser.add_argument('--resolution-lr', type=int, default=512)
        parser.add_argument('--resolution-hr', type=int, default=2048)
        parser.add_argument('--seq-length-lr', type=int, required=True)
        parser.add_argument('--seq-length-hr', type=int, default=6)
        parser.add_argument('--downsample-ratio', type=float, default=0.25)
        parser.add_argument('--batch-size-per-gpu', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=8)
        parser.add_argument('--epoch-start', type=int, default=0)
        parser.add_argument('--epoch-end', type=int, default=16)
        # Tensorboard logging
        parser.add_argument('--log-dir', type=str, required=True)
        parser.add_argument('--log-train-loss-interval', type=int, default=20)
        parser.add_argument('--log-train-images-interval', type=int, default=500)
        # Checkpoint loading and saving
        parser.add_argument('--checkpoint', type=str)
        parser.add_argument('--checkpoint-dir', type=str, required=True)
        parser.add_argument('--checkpoint-save-interval', type=int, default=500)
        # Distributed
        parser.add_argument('--distributed-addr', type=str, default='localhost')
        parser.add_argument('--distributed-port', type=str, default='12355')
        # Debugging
        parser.add_argument('--disable-progress-bar', action='store_true')
        parser.add_argument('--disable-validation', action='store_true')
        parser.add_argument('--disable-mixed-precision', action='store_true')
        self.args = parser.parse_args()

    def init_distributed(self, rank, world_size):
        """Record rank/world size and join the NCCL process group."""
        self.rank = rank
        self.world_size = world_size
        self.log('Initializing distributed')
        os.environ['MASTER_ADDR'] = self.args.distributed_addr
        os.environ['MASTER_PORT'] = self.args.distributed_port
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def _make_train_loader(self, dataset, batch_size):
        """Build a shuffled, rank-sharded loader for a training dataset.

        Returns:
            A ``(sampler, dataloader)`` pair.  The sampler is returned as well
            because its ``set_epoch`` must be called to reshuffle.
        """
        sampler = DistributedSampler(
            dataset=dataset,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=self.args.num_workers,
            sampler=sampler,
            pin_memory=True)
        return sampler, dataloader

    def init_datasets(self):
        """Create the matting and segmentation datasets and their loaders.

        The high-resolution matting dataset/loader only exist when
        ``--train-hr`` is given; validation then also uses the
        high-resolution size and sequence length.
        """
        self.log('Initializing matting datasets')
        size_hr = (self.args.resolution_hr, self.args.resolution_hr)
        size_lr = (self.args.resolution_lr, self.args.resolution_lr)

        # Matting datasets:
        if self.args.dataset == 'videomatte':
            self.dataset_lr_train = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=VideoMatteTrainAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = VideoMatteDataset(
                    videomatte_dir=DATA_PATHS['videomatte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=VideoMatteTrainAugmentation(size_hr))
            self.dataset_valid = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
        else:
            self.dataset_lr_train = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=ImageMatteAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = ImageMatteDataset(
                    imagematte_dir=DATA_PATHS['imagematte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=ImageMatteAugmentation(size_hr))
            self.dataset_valid = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))

        # Matting dataloaders:
        self.datasampler_lr_train, self.dataloader_lr_train = self._make_train_loader(
            self.dataset_lr_train, self.args.batch_size_per_gpu)
        if self.args.train_hr:
            self.datasampler_hr_train, self.dataloader_hr_train = self._make_train_loader(
                self.dataset_hr_train, self.args.batch_size_per_gpu)
        # Validation is not sharded: validate() iterates it on rank 0 only.
        self.dataloader_valid = DataLoader(
            dataset=self.dataset_valid,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # Segmentation datasets
        self.log('Initializing image segmentation datasets')
        self.dataset_seg_image = ConcatDataset([
            CocoPanopticDataset(
                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
                anndir=DATA_PATHS['coco_panoptic']['anndir'],
                annfile=DATA_PATHS['coco_panoptic']['annfile'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
            SuperviselyPersonDataset(
                imgdir=DATA_PATHS['spd']['imgdir'],
                segdir=DATA_PATHS['spd']['segdir'],
                transform=CocoPanopticTrainAugmentation(size_lr))
        ])
        # Still images are batched seq_length_lr at a time; train() later
        # inserts a length-1 time axis so each image is a one-frame sequence.
        self.datasampler_seg_image, self.dataloader_seg_image = self._make_train_loader(
            self.dataset_seg_image,
            self.args.batch_size_per_gpu * self.args.seq_length_lr)

        self.log('Initializing video segmentation datasets')
        self.dataset_seg_video = YouTubeVISDataset(
            videodir=DATA_PATHS['youtubevis']['videodir'],
            annfile=DATA_PATHS['youtubevis']['annfile'],
            size=self.args.resolution_lr,
            seq_length=self.args.seq_length_lr,
            seq_sampler=TrainFrameSampler(speed=[1]),
            transform=YouTubeVISAugmentation(size_lr))
        self.datasampler_seg_video, self.dataloader_seg_video = self._make_train_loader(
            self.dataset_seg_video, self.args.batch_size_per_gpu)

    def init_model(self):
        """Build the network, optionally restore weights, wrap in DDP and
        create the optimizer (one learning rate per sub-module group) and
        the gradient scaler."""
        self.log('Initializing model')
        self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank)

        if self.args.checkpoint:
            self.log(f'Restoring from checkpoint: {self.args.checkpoint}')
            # NOTE: torch.load unpickles the file; only pass trusted checkpoints.
            # The result of load_state_dict is logged as-is.
            self.log(self.model.load_state_dict(
                torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}')))

        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True)
        self.optimizer = Adam([
            {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
            {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
            {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
            # Both projection heads share the decoder learning rate.
            {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
        ])
        self.scaler = GradScaler()

    def init_writer(self):
        """Create the TensorBoard writer (rank 0 only)."""
        if self.rank == 0:
            self.log('Initializing writer')
            self.writer = SummaryWriter(self.args.log_dir)

    def train(self):
        """Run epochs ``[epoch_start, epoch_end)``.

        Each iteration over the low-resolution matting loader performs a
        low-resolution matting step, an optional high-resolution matting
        step, and one segmentation step (video on even global steps, still
        images on odd ones).
        """
        for epoch in range(self.args.epoch_start, self.args.epoch_end):
            self.epoch = epoch
            self.step = epoch * len(self.dataloader_lr_train)

            # DistributedSampler derives its shuffle from the epoch number;
            # without this call every epoch would replay the same ordering.
            self.datasampler_lr_train.set_epoch(epoch)

            if not self.args.disable_validation:
                self.validate()

            self.log(f'Training epoch: {epoch}')
            for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                # Low resolution pass
                self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr')

                # High resolution pass
                if self.args.train_hr:
                    true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample()
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')

                # Segmentation pass
                if self.step % 2 == 0:
                    true_img, true_seg = self.load_next_seg_video_sample()
                    self.train_seg(true_img, true_seg, log_label='seg_video')
                else:
                    true_img, true_seg = self.load_next_seg_image_sample()
                    # Add a length-1 time axis so images look like sequences.
                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')

                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()

                self.step += 1

    def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
        """Run one matting optimisation step.

        Args:
            true_fgr: Ground-truth foreground batch.
            true_pha: Ground-truth alpha batch.
            true_bgr: Background batch used to composite the input.
            downsample_ratio: Forwarded to the model's forward pass.
            tag: Label ('lr' / 'hr') embedded in the TensorBoard tag names.
        """
        true_fgr = true_fgr.to(self.rank, non_blocking=True)
        true_pha = true_pha.to(self.rank, non_blocking=True)
        true_bgr = true_bgr.to(self.rank, non_blocking=True)
        true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
        # Composite the network input from foreground, alpha and background.
        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
            loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)

        self.scaler.scale(loss['total']).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and self.step % self.args.log_train_loss_interval == 0:
            for loss_name, loss_value in loss.items():
                self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step)

        if self.rank == 0 and self.step % self.args.log_train_images_interval == 0:
            # One grid row per sequence: flatten (B, T) and use T columns.
            self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)

    def train_seg(self, true_img, true_seg, log_label):
        """Run one segmentation optimisation step.

        Args:
            true_img: Input image batch.
            true_seg: Ground-truth segmentation batch.
            log_label: Prefix for the TensorBoard tags of this step.
        """
        true_img = true_img.to(self.rank, non_blocking=True)
        true_seg = true_seg.to(self.rank, non_blocking=True)

        true_img, true_seg = self.random_crop(true_img, true_seg)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
            loss = segmentation_loss(pred_seg, true_seg)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        # train() alternates video (even step) and image (odd step) batches,
        # so the step is rounded down to even before the interval test; that
        # way both kinds are logged around each interval boundary.
        even_step = self.step - self.step % 2
        if self.rank == 0 and even_step % self.args.log_train_loss_interval == 0:
            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)

        if self.rank == 0 and even_step % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)

    def _next_sample(self, iterator_attr, dataloader, sampler):
        """Return the next batch from an endlessly cycling auxiliary loader.

        The iterator is stored on ``self`` under ``iterator_attr``.  It is
        (re)created on first use and whenever it is exhausted; each time the
        sampler's epoch is advanced first so the new pass is reshuffled.
        Any other error raised while fetching a batch propagates to the
        caller instead of being swallowed.
        """
        iterator = getattr(self, iterator_attr, None)
        if iterator is not None:
            try:
                return next(iterator)
            except StopIteration:
                pass  # exhausted: fall through and start a new pass
        sampler.set_epoch(sampler.epoch + 1)
        iterator = iter(dataloader)
        setattr(self, iterator_attr, iterator)
        return next(iterator)

    def load_next_mat_hr_sample(self):
        """Return the next high-resolution matting batch."""
        return self._next_sample(
            'dataiterator_mat_hr', self.dataloader_hr_train, self.datasampler_hr_train)

    def load_next_seg_video_sample(self):
        """Return the next video segmentation batch."""
        return self._next_sample(
            'dataiterator_seg_video', self.dataloader_seg_video, self.datasampler_seg_video)

    def load_next_seg_image_sample(self):
        """Return the next image segmentation batch."""
        return self._next_sample(
            'dataiterator_seg_image', self.dataloader_seg_image, self.datasampler_seg_image)

    def validate(self):
        """Compute and log the average validation loss on rank 0.

        Other ranks skip straight to the barrier and wait.  The un-wrapped
        ``self.model`` is used for the forward pass (not the DDP wrapper),
        since only one rank takes part.
        """
        if self.rank == 0:
            self.log(f'Validating at the start of epoch: {self.epoch}')
            self.model_ddp.eval()
            total_loss, total_count = 0, 0
            with torch.no_grad():
                with autocast(enabled=not self.args.disable_mixed_precision):
                    for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                        true_fgr = true_fgr.to(self.rank, non_blocking=True)
                        true_pha = true_pha.to(self.rank, non_blocking=True)
                        true_bgr = true_bgr.to(self.rank, non_blocking=True)
                        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                        batch_size = true_src.size(0)
                        pred_fgr, pred_pha = self.model(true_src)[:2]
                        # Weight by batch size so a short last batch does not skew the mean.
                        total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                        total_count += batch_size
            if total_count:
                avg_loss = total_loss / total_count
                self.log(f'Validation set average loss: {avg_loss}')
                self.writer.add_scalar('valid_loss', avg_loss, self.step)
            else:
                # An empty validation set would otherwise raise ZeroDivisionError.
                self.log('Validation set is empty; skipping validation loss')
            self.model_ddp.train()
        dist.barrier()

    def random_crop(self, *imgs):
        """Apply one shared random resize-and-crop to every tensor.

        A target height and width are drawn from ``[size // 2, size)`` of the
        first tensor's last two dimensions.  Each tensor has its first two
        dimensions merged, is bilinearly resized to a square whose side is
        the larger of the two targets, centre-cropped to ``(h, w)`` and then
        reshaped back to its original leading two dimensions.

        Returns:
            A list with one cropped tensor per input, in the same order.
        """
        h, w = imgs[0].shape[-2:]
        # Width is drawn before height, matching the original draw order.
        w = random.randrange(w // 2, w)
        h = random.randrange(h // 2, h)
        side = max(h, w)
        results = []
        for img in imgs:
            B, T = img.shape[:2]
            img = img.flatten(0, 1)
            img = F.interpolate(img, (side, side), mode='bilinear', align_corners=False)
            img = center_crop(img, (h, w))
            img = img.reshape(B, T, *img.shape[1:])
            results.append(img)
        return results

    def save(self):
        """Write the model weights from rank 0, then synchronise all ranks.

        The file name only contains the epoch, so repeated saves within one
        epoch overwrite the same file.
        """
        if self.rank == 0:
            os.makedirs(self.args.checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth'))
            self.log('Model saved')
        dist.barrier()

    def cleanup(self):
        """Tear down the distributed process group."""
        dist.destroy_process_group()

    def log(self, msg):
        """Print ``msg`` prefixed with this process's rank."""
        print(f'[GPU{self.rank}] {msg}')
if __name__ == '__main__':
    # Launch one Trainer process per visible CUDA device. mp.spawn calls
    # Trainer(process_index, world_size), so the process index becomes the rank.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # Without this check zero workers would be spawned and the script
        # would exit without training or reporting anything.
        raise SystemExit('train.py requires at least one CUDA device, but none were found.')
    mp.spawn(
        Trainer,
        nprocs=world_size,
        args=(world_size,),
        join=True)