@@ -121,7 +121,8 @@ from dataset.augmentation import (
 from model import MattingNetwork
 from train_config import DATA_PATHS
 from train_loss import matting_loss, segmentation_loss
+import kornia
+from torchvision import transforms as T
 
 
 class Trainer:
     def __init__(self, rank, world_size):
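Note: the two imports added here (`kornia` and `torchvision.transforms`) exist solely for the background augmentations introduced in `train_mat` further down. A quick sanity sketch of the two calls this patch leans on (tensor sizes are hypothetical):

```python
# Both new dependencies as used later in the patch: kornia's box_blur for
# softening fake shadows, torchvision's RandomAffine for warping them.
# Sizes below are hypothetical.
import kornia
import torch
from torchvision import transforms as T

x = torch.rand(4, 1, 64, 64)                    # a small batch of masks
blurred = kornia.filters.box_blur(x, (21, 21))  # smooth with a 21x21 box kernel
warped = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2))(x)
assert blurred.shape == warped.shape == x.shape
```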
@@ -189,7 +190,7 @@ class Trainer:
         if self.args.dataset == 'videomatte':
             self.dataset_lr_train = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
+                # background_image_dir=DATA_PATHS['background_images']['train'],
                 background_video_dir=DATA_PATHS['background_videos']['train'],
                 size=self.args.resolution_lr,
                 seq_length=self.args.seq_length_lr,
@@ -198,7 +199,7 @@ class Trainer:
             if self.args.train_hr:
                 self.dataset_hr_train = VideoMatteDataset(
                     videomatte_dir=DATA_PATHS['videomatte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
+                    # background_image_dir=DATA_PATHS['background_images']['train'],
                     background_video_dir=DATA_PATHS['background_videos']['train'],
                     size=self.args.resolution_hr,
                     seq_length=self.args.seq_length_hr,
@@ -206,38 +207,38 @@ class Trainer:
                     transform=VideoMatteTrainAugmentation(size_hr))
             self.dataset_valid = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
+                # background_image_dir=DATA_PATHS['background_images']['valid'],
                 background_video_dir=DATA_PATHS['background_videos']['valid'],
                 size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                 seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                 seq_sampler=ValidFrameSampler(),
                 transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
-        else:
-            self.dataset_lr_train = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
-                background_video_dir=DATA_PATHS['background_videos']['train'],
-                size=self.args.resolution_lr,
-                seq_length=self.args.seq_length_lr,
-                seq_sampler=TrainFrameSampler(),
-                transform=ImageMatteAugmentation(size_lr))
-            if self.args.train_hr:
-                self.dataset_hr_train = ImageMatteDataset(
-                    imagematte_dir=DATA_PATHS['imagematte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
-                    background_video_dir=DATA_PATHS['background_videos']['train'],
-                    size=self.args.resolution_hr,
-                    seq_length=self.args.seq_length_hr,
-                    seq_sampler=TrainFrameSampler(),
-                    transform=ImageMatteAugmentation(size_hr))
-            self.dataset_valid = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
-                background_video_dir=DATA_PATHS['background_videos']['valid'],
-                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
-                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
-                seq_sampler=ValidFrameSampler(),
-                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
+        # else:
+        #     self.dataset_lr_train = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #         background_image_dir=DATA_PATHS['background_images']['train'],
+        #         background_video_dir=DATA_PATHS['background_videos']['train'],
+        #         size=self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_lr,
+        #         seq_sampler=TrainFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_lr))
+        #     if self.args.train_hr:
+        #         self.dataset_hr_train = ImageMatteDataset(
+        #             imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #             background_image_dir=DATA_PATHS['background_images']['train'],
+        #             background_video_dir=DATA_PATHS['background_videos']['train'],
+        #             size=self.args.resolution_hr,
+        #             seq_length=self.args.seq_length_hr,
+        #             seq_sampler=TrainFrameSampler(),
+        #             transform=ImageMatteAugmentation(size_hr))
+        #     self.dataset_valid = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['valid'],
+        #         background_image_dir=DATA_PATHS['background_images']['valid'],
+        #         background_video_dir=DATA_PATHS['background_videos']['valid'],
+        #         size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
+        #         seq_sampler=ValidFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
 
         # Matting dataloaders:
         self.datasampler_lr_train = DistributedSampler(
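Note: this hunk pins training to `--dataset videomatte` with video backgrounds only. `background_image_dir` is commented out of every `VideoMatteDataset`, and the entire `ImageMatteDataset` `else:` branch is disabled rather than deleted. A sketch of the `DATA_PATHS` mapping the hunk indexes into (only the keys come from the diff; every path is a placeholder, not the project's real config):

```python
# Hypothetical train_config.DATA_PATHS layout; keys are taken from the diff
# above, paths are placeholders.
DATA_PATHS = {
    'videomatte':        {'train': '/data/VideoMatte240K/train', 'valid': '/data/VideoMatte240K/valid'},
    'background_images': {'train': '/data/bg_images/train',      'valid': '/data/bg_images/valid'},
    'background_videos': {'train': '/data/bg_videos/train',      'valid': '/data/bg_videos/valid'},
}
```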
@@ -270,49 +271,49 @@ class Trainer:
             pin_memory=True)
 
         # Segmentation datasets
-        self.log('Initializing image segmentation datasets')
-        self.dataset_seg_image = ConcatDataset([
-            CocoPanopticDataset(
-                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
-                anndir=DATA_PATHS['coco_panoptic']['anndir'],
-                annfile=DATA_PATHS['coco_panoptic']['annfile'],
-                transform=CocoPanopticTrainAugmentation(size_lr)),
-            SuperviselyPersonDataset(
-                imgdir=DATA_PATHS['spd']['imgdir'],
-                segdir=DATA_PATHS['spd']['segdir'],
-                transform=CocoPanopticTrainAugmentation(size_lr))
-        ])
-        self.datasampler_seg_image = DistributedSampler(
-            dataset=self.dataset_seg_image,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_image = DataLoader(
-            dataset=self.dataset_seg_image,
-            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_image,
-            pin_memory=True)
-
-        self.log('Initializing video segmentation datasets')
-        self.dataset_seg_video = YouTubeVISDataset(
-            videodir=DATA_PATHS['youtubevis']['videodir'],
-            annfile=DATA_PATHS['youtubevis']['annfile'],
-            size=self.args.resolution_lr,
-            seq_length=self.args.seq_length_lr,
-            seq_sampler=TrainFrameSampler(speed=[1]),
-            transform=YouTubeVISAugmentation(size_lr))
-        self.datasampler_seg_video = DistributedSampler(
-            dataset=self.dataset_seg_video,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_video = DataLoader(
-            dataset=self.dataset_seg_video,
-            batch_size=self.args.batch_size_per_gpu,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_video,
-            pin_memory=True)
+        # self.log('Initializing image segmentation datasets')
+        # self.dataset_seg_image = ConcatDataset([
+        #     CocoPanopticDataset(
+        #         imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
+        #         anndir=DATA_PATHS['coco_panoptic']['anndir'],
+        #         annfile=DATA_PATHS['coco_panoptic']['annfile'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr)),
+        #     SuperviselyPersonDataset(
+        #         imgdir=DATA_PATHS['spd']['imgdir'],
+        #         segdir=DATA_PATHS['spd']['segdir'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr))
+        # ])
+        # self.datasampler_seg_image = DistributedSampler(
+        #     dataset=self.dataset_seg_image,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_image = DataLoader(
+        #     dataset=self.dataset_seg_image,
+        #     batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_image,
+        #     pin_memory=True)
+
+        # self.log('Initializing video segmentation datasets')
+        # self.dataset_seg_video = YouTubeVISDataset(
+        #     videodir=DATA_PATHS['youtubevis']['videodir'],
+        #     annfile=DATA_PATHS['youtubevis']['annfile'],
+        #     size=self.args.resolution_lr,
+        #     seq_length=self.args.seq_length_lr,
+        #     seq_sampler=TrainFrameSampler(speed=[1]),
+        #     transform=YouTubeVISAugmentation(size_lr))
+        # self.datasampler_seg_video = DistributedSampler(
+        #     dataset=self.dataset_seg_video,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_video = DataLoader(
+        #     dataset=self.dataset_seg_video,
+        #     batch_size=self.args.batch_size_per_gpu,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_video,
+        #     pin_memory=True)
 
     def init_model(self):
         self.log('Initializing model')
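Note: the segmentation datasets, samplers, and loaders are commented out wholesale, matching the disabled segmentation pass and `train_seg` below; nothing replaces them. One detail worth keeping in mind if the branch is ever revived: the image loader multiplies its batch size by `seq_length_lr` because image datasets yield single frames, which the (now commented) training loop turned into length-1 clips with `unsqueeze(1)`. A sketch with hypothetical sizes:

```python
# Image loaders draw B*T single frames so one step sees as many frames as the
# video loaders' B clips of T frames; unsqueeze(1) restores the time axis.
# All sizes below are hypothetical.
import torch
B, T, C, H, W = 2, 15, 3, 256, 256
frames = torch.rand(B * T, C, H, W)  # one image-segmentation batch
clips = frames.unsqueeze(1)          # (B*T, 1, C, H, W)
assert clips.shape == (B * T, 1, C, H, W)
```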
@@ -359,12 +360,12 @@ class Trainer:
                     self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')
 
                 # Segmentation pass
-                if self.step % 2 == 0:
-                    true_img, true_seg = self.load_next_seg_video_sample()
-                    self.train_seg(true_img, true_seg, log_label='seg_video')
-                else:
-                    true_img, true_seg = self.load_next_seg_image_sample()
-                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
+                # if self.step % 2 == 0:
+                #     true_img, true_seg = self.load_next_seg_video_sample()
+                #     self.train_seg(true_img, true_seg, log_label='seg_video')
+                # else:
+                #     true_img, true_seg = self.load_next_seg_image_sample()
+                #     self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
 
                 if self.step % self.args.checkpoint_save_interval == 0:
                     self.save()
@@ -376,10 +377,47 @@ class Trainer:
         true_pha = true_pha.to(self.rank, non_blocking=True)
         true_bgr = true_bgr.to(self.rank, non_blocking=True)
         true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
-        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
+        true_src = true_bgr.clone()
+
+        # Augment bgr with shadow
+        aug_shadow_idx = torch.rand(len(true_src)) < 0.3
+        if aug_shadow_idx.any():
+            aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()).flatten(start_dim=0, end_dim=1)
+            aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
+            aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
+            expected_shape = torch.tensor(true_src[aug_shadow_idx].shape)
+            expected_shape[2] = -1
+            true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow.reshape(expected_shape.tolist())).clamp_(0, 1)
+            del aug_shadow
+        del aug_shadow_idx
+
+        # Composite foreground onto source
+        true_src = true_fgr * true_pha + true_src * (1 - true_pha)
+
+        # Augment with noise
+        aug_noise_idx = torch.rand(len(true_src)) < 0.4
+        if aug_noise_idx.any():
+            true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+            true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+        del aug_noise_idx
+
+        # Augment background with jitter
+        aug_jitter_idx = torch.rand(len(true_src)) < 0.8
+        if aug_jitter_idx.any():
+            true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx].flatten(start_dim=0, end_dim=1)).reshape(true_bgr[aug_jitter_idx].shape)
+        del aug_jitter_idx
+
+        # Augment background with affine
+        aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
+        if aug_affine_idx.any():
+            true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx].flatten(start_dim=0, end_dim=1)).reshape(true_bgr[aug_affine_idx].shape)
+        del aug_affine_idx
+
+        fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)
 
         with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
+            pred_fgr, pred_pha = self.model_ddp(fg_bg_input, downsample_ratio=downsample_ratio)[:2]
             loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)
 
         self.scaler.scale(loss['total']).backward()
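Note: this is the core change. `true_src` is no longer a straight composite; it starts from the background plate, optionally receives a fake shadow (a warped, blurred copy of the alpha matte subtracted from the plate), gets the foreground composited on, and is then noise-augmented, while `true_bgr` additionally receives color jitter and a slight affine misalignment. The network is then fed the source and the background concatenated along the channel axis. That concat doubles the input channels, so `MattingNetwork` must accept 6-channel input; that change is not part of this diff. A minimal shape sketch (sizes hypothetical):

```python
# Sketch of the new forward input. torch.cat on dim=-3 stacks channels for
# (B, T, C, H, W) clips, so the model now sees 6 input channels; the model
# itself must be configured for that elsewhere. Sizes are hypothetical.
import torch
B, T, H, W = 2, 15, 256, 256
true_src = torch.rand(B, T, 3, H, W)  # augmented, composited source clip
true_bgr = torch.rand(B, T, 3, H, W)  # clean background plate
fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)
assert fg_bg_input.shape == (B, T, 6, H, W)
```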
@@ -398,28 +436,29 @@ class Trainer:
             self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
             self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)
 
-    def train_seg(self, true_img, true_seg, log_label):
-        true_img = true_img.to(self.rank, non_blocking=True)
-        true_seg = true_seg.to(self.rank, non_blocking=True)
+    # does not get called
+    # def train_seg(self, true_img, true_seg, log_label):
+    #     true_img = true_img.to(self.rank, non_blocking=True)
+    #     true_seg = true_seg.to(self.rank, non_blocking=True)
 
-        true_img, true_seg = self.random_crop(true_img, true_seg)
+    #     true_img, true_seg = self.random_crop(true_img, true_seg)
 
-        with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
-            loss = segmentation_loss(pred_seg, true_seg)
+    #     with autocast(enabled=not self.args.disable_mixed_precision):
+    #         pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
+    #         loss = segmentation_loss(pred_seg, true_seg)
 
-        self.scaler.scale(loss).backward()
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-        self.optimizer.zero_grad()
+    #     self.scaler.scale(loss).backward()
+    #     self.scaler.step(self.optimizer)
+    #     self.scaler.update()
+    #     self.optimizer.zero_grad()
 
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
-            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
+    #         self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
 
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
-            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
+    #         self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
 
     def load_next_mat_hr_sample(self):
         try:
@@ -430,23 +469,23 @@ class Trainer:
             sample = next(self.dataiterator_mat_hr)
         return sample
 
-    def load_next_seg_video_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_video)
-        except:
-            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
-            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
-            sample = next(self.dataiterator_seg_video)
-        return sample
-
-    def load_next_seg_image_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_image)
-        except:
-            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
-            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
-            sample = next(self.dataiterator_seg_image)
-        return sample
+    # def load_next_seg_video_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_video)
+    #     except:
+    #         self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
+    #         self.dataiterator_seg_video = iter(self.dataloader_seg_video)
+    #         sample = next(self.dataiterator_seg_video)
+    #     return sample
+
+    # def load_next_seg_image_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_image)
+    #     except:
+    #         self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
+    #         self.dataiterator_seg_image = iter(self.dataloader_seg_image)
+    #         sample = next(self.dataiterator_seg_image)
+    #     return sample
 
     def validate(self):
         if self.rank == 0:
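Note: the surviving `load_next_mat_hr_sample` (context above) and the commented-out loaders all detect iterator exhaustion with a bare `except:`, which also swallows `KeyboardInterrupt` and real errors. A hedged sketch of the same restart-on-exhaustion pattern with the catch narrowed; a suggested shape, not the project's code (all names mirror the commented-out loader above):

```python
# Same "advance the sampler epoch and restart the iterator" pattern, but
# catching only StopIteration so unrelated errors still surface.
def load_next_seg_video_sample(self):
    try:
        sample = next(self.dataiterator_seg_video)
    except StopIteration:
        self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
        self.dataiterator_seg_video = iter(self.dataloader_seg_video)
        sample = next(self.dataiterator_seg_video)
    return sample
```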
@@ -461,7 +500,9 @@ class Trainer:
                         true_bgr = true_bgr.to(self.rank, non_blocking=True)
                         true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                         batch_size = true_src.size(0)
-                        pred_fgr, pred_pha = self.model(true_src)[:2]
+
+                        fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)
+                        pred_fgr, pred_pha = self.model(fg_bg_input)[:2]
                         total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                         total_count += batch_size
             avg_loss = total_loss / total_count
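Note: validation now feeds the model the same `[src, bgr]` channel concat as training, so the validation loss stays comparable to the training objective. The average that closes the loop is a batch-size-weighted mean; a worked example with made-up numbers:

```python
# Batch-size-weighted mean: per-batch losses 0.5 and 0.2 over batches of
# 4 and 1 samples (numbers are made up).
total_loss = 0.5 * 4 + 0.2 * 1        # 2.2
total_count = 4 + 1
avg_loss = total_loss / total_count   # 0.44; an unweighted mean would give 0.35
```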