added code for training with bgr

pull/254/head
Kartikaeya 1 year ago
parent 53d74c6826
commit be683689db

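At a glance, this commit retools the matting pipeline around a known background: the source frame and its captured background (bgr) are concatenated into a 6-channel network input, the image-background compositing path and the segmentation training passes are commented out, and the inference readers are adapted to load paired fgr/bgr frames. A minimal sketch of the new input contract (shapes are illustrative; the actual call appears in the Trainer changes below):

import torch

# Source clip and the corresponding background clip, both RGB in [0, 1],
# laid out as [B, T, C, H, W] like the tensors in the Trainer code below.
true_src = torch.rand(2, 4, 3, 288, 512)
true_bgr = torch.rand(2, 4, 3, 288, 512)

# The network is now fed a 6-channel image: source and background stacked
# along the channel axis, mirroring torch.cat((true_src, true_bgr), dim=-3).
fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)
assert fg_bg_input.shape == (2, 4, 6, 288, 512)

# pred_fgr, pred_pha = model_ddp(fg_bg_input, downsample_ratio=downsample_ratio)[:2]
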
@@ -9,14 +9,14 @@ from .augmentation import MotionAugmentation
 class VideoMatteDataset(Dataset):
     def __init__(self,
                  videomatte_dir,
-                 background_image_dir,
+                 # background_image_dir,
                  background_video_dir,
                  size,
                  seq_length,
                  seq_sampler,
                  transform=None):
-        self.background_image_dir = background_image_dir
-        self.background_image_files = os.listdir(background_image_dir)
+        # self.background_image_dir = background_image_dir
+        # self.background_image_files = os.listdir(background_image_dir)
         self.background_video_dir = background_video_dir
         self.background_video_clips = sorted(os.listdir(background_video_dir))
         self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
@@ -38,9 +38,9 @@ class VideoMatteDataset(Dataset):
         return len(self.videomatte_idx)
 
     def __getitem__(self, idx):
-        if random.random() < 0.5:
-            bgrs = self._get_random_image_background()
-        else:
-            bgrs = self._get_random_video_background()
+        # if random.random() < 0.5:
+        #     bgrs = self._get_random_image_background()
+        # else:
+        bgrs = self._get_random_video_background()
 
         fgrs, phas = self._get_videomatte(idx)
@@ -50,11 +50,11 @@
         return fgrs, phas, bgrs
 
-    def _get_random_image_background(self):
-        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
-            bgr = self._downsample_if_needed(bgr.convert('RGB'))
-        bgrs = [bgr] * self.seq_length
-        return bgrs
+    # def _get_random_image_background(self):
+    #     with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
+    #         bgr = self._downsample_if_needed(bgr.convert('RGB'))
+    #     bgrs = [bgr] * self.seq_length
+    #     return bgrs
 
     def _get_random_video_background(self):
         clip_idx = random.choice(range(len(self.background_video_clips)))

@@ -120,6 +120,10 @@ def convert_video(model,
     rec = [None] * 4
     for src in reader:
+        if src.shape[-1] %2 == 1:
+            src = src[:, :, :, :-1]
+        if src.shape[-2] %2 == 1:
+            src = src[:, :, :-1, :]
 
         if downsample_ratio is None:
             downsample_ratio = auto_downsample_ratio(*src.shape[2:])

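The lines added to convert_video crop a frame with odd height or width down to even dimensions before inference, presumably so the network's repeated 2x downsampling and matching upsampling reproduce the input size. A standalone sketch of the same guard (the [T, C, H, W] layout here is an assumption):

import torch

src = torch.rand(8, 3, 721, 1281)   # a batch of frames with odd height and width
if src.shape[-1] % 2 == 1:           # drop the last column if the width is odd
    src = src[:, :, :, :-1]
if src.shape[-2] % 2 == 1:           # drop the last row if the height is odd
    src = src[:, :, :-1, :]
print(src.shape)                     # torch.Size([8, 3, 720, 1280])
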
@@ -5,7 +5,7 @@ import numpy as np
 from torch.utils.data import Dataset
 from torchvision.transforms.functional import to_pil_image
 from PIL import Image
+import torch
 
 class VideoReader(Dataset):
     def __init__(self, path, transform=None):
@@ -55,18 +55,23 @@ class VideoWriter:
 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
-        self.files = sorted(os.listdir(path))
+        self.files_fgr = sorted(os.listdir(path + "fgr/"))
+        self.files_bgr = sorted(os.listdir(path + "bgr/"))
         self.transform = transform
 
     def __len__(self):
-        return len(self.files)
+        return len(self.files_fgr)
 
     def __getitem__(self, idx):
-        with Image.open(os.path.join(self.path, self.files[idx])) as img:
-            img.load()
+        with Image.open(os.path.join(self.path + "fgr/", self.files_fgr[idx])) as fgr_img:
+            fgr_img.load()
+        with Image.open(os.path.join(self.path + "bgr/", self.files_bgr[idx])) as bgr_img:
+            bgr_img.load()
         if self.transform is not None:
-            return self.transform(img)
-        return img
+            return torch.cat([self.transform(fgr_img), self.transform(bgr_img)], dim = 0)
+        return fgr_img
 
 class ImageSequenceWriter:

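ImageSequenceReader now expects two aligned sub-folders under the given path, one with source frames and one with the matching background frames, and returns them stacked into a single tensor. A hedged usage sketch (folder and file names are illustrative, the module path is assumed; note the trailing slash, since the modified class builds the sub-folder paths by string concatenation):

# Assumed on-disk layout:
#   my_clip/
#     fgr/0000.png, 0001.png, ...   source frames
#     bgr/0000.png, 0001.png, ...   background frames with matching names
from torchvision.transforms import ToTensor
from inference_utils import ImageSequenceReader   # the class modified above

reader = ImageSequenceReader('my_clip/', transform=ToTensor())
item = reader[0]
print(item.shape)   # [6, H, W] for RGB frames: channels 0-2 source, 3-5 background
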
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from torch import nn
-from torch.nn import functional as F
+# from torch.nn import functional as F
 from typing import Tuple, Optional
 
 class RecurrentDecoder(nn.Module):
@@ -9,10 +9,10 @@ class RecurrentDecoder(nn.Module):
         super().__init__()
         self.avgpool = AvgPool()
         self.decode4 = BottleneckBlock(feature_channels[3])
-        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
-        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
-        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
-        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
+        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 6, decoder_channels[0])
+        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 6, decoder_channels[1])
+        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 6, decoder_channels[2])
+        self.decode0 = OutputBlock(decoder_channels[2], 6, decoder_channels[3])
 
     def forward(self,
                 s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,

@@ -3,6 +3,21 @@ from torch import nn
 from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
 from torchvision.transforms.functional import normalize
 
+def load_matched_state_dict(model, state_dict, print_stats=True):
+    """
+    Only loads weights that matched in key and shape. Ignore other weights.
+    """
+    num_matched, num_total = 0, 0
+    curr_state_dict = model.state_dict()
+    for key in curr_state_dict.keys():
+        num_total += 1
+        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
+            curr_state_dict[key] = state_dict[key]
+            num_matched += 1
+    model.load_state_dict(curr_state_dict)
+    if print_stats:
+        print(f'Loaded state_dict: {num_matched}/{num_total} matched')
+
 class MobileNetV3LargeEncoder(MobileNetV3):
     def __init__(self, pretrained: bool = False):
         super().__init__(
@@ -27,14 +42,24 @@ class MobileNetV3LargeEncoder(MobileNetV3):
         )
 
         if pretrained:
-            self.load_state_dict(torch.hub.load_state_dict_from_url(
-                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
+            pretrained_state_dict = torch.hub.load_state_dict_from_url(
+                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')
+            # print("pretrained_state_dict keys \n \n ", pretrained_state_dict.keys())
+            # print("\n\ncurrent model state dict keys \n\n", self.state_dict().keys())
+            load_matched_state_dict(self, pretrained_state_dict)
+            # self.load_state_dict(torch.hub.load_state_dict_from_url(
+            #     'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
 
         del self.avgpool
         del self.classifier
 
     def forward_single_frame(self, x):
-        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        # print(x.shape)
+        x = torch.cat((normalize(x[:, :3, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), normalize(x[:, 3:, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])), dim = -3)
         x = self.features[0](x)
         x = self.features[1](x)

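load_matched_state_dict exists because the modified encoder no longer matches the ImageNet checkpoint exactly: presumably the first convolution now takes the 6-channel input, so its weight tensor changes shape and is skipped while everything else is copied. A self-contained toy demonstration of the same behaviour (the two Sequential models below are stand-ins, not the real encoder):

from torch import nn

def load_matched_state_dict(model, state_dict, print_stats=True):
    """Copies only the weights whose key and shape match; leaves the rest untouched."""
    num_matched, num_total = 0, 0
    curr_state_dict = model.state_dict()
    for key in curr_state_dict.keys():
        num_total += 1
        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
            curr_state_dict[key] = state_dict[key]
            num_matched += 1
    model.load_state_dict(curr_state_dict)
    if print_stats:
        print(f'Loaded state_dict: {num_matched}/{num_total} matched')

# A "pretrained" 3-channel model and a target model whose first conv takes 6 channels.
pretrained = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 16, 3))
target     = nn.Sequential(nn.Conv2d(6, 16, 3), nn.Conv2d(16, 16, 3))
load_matched_state_dict(target, pretrained.state_dict())
# Prints: Loaded state_dict: 3/4 matched  (only the first conv's weight differs in shape)
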
@@ -1,6 +1,7 @@
 import torch
 from torch import Tensor
 from torch import nn
+from torchsummary import summary
 from torch.nn import functional as F
 from typing import Optional, List
@@ -58,8 +59,8 @@ class MattingNetwork(nn.Module):
         if not segmentation_pass:
             fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
             if downsample_ratio != 1:
-                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
-            fgr = fgr_residual + src
+                fgr_residual, pha = self.refiner(src[:, :, :3, ...], src_sm[:, :, :3, ...], fgr_residual, pha, hid)
+            fgr = fgr_residual + src[:, :, :3, ...]
             fgr = fgr.clamp(0., 1.)
             pha = pha.clamp(0., 1.)
             return [fgr, pha, *rec]

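In MattingNetwork.forward, src is now the 6-channel [source | background] tensor, so the refiner and the foreground reconstruction explicitly slice out the first three channels: the predicted foreground is still the residual added to the source frame alone. A small sketch of that slicing (shapes are illustrative):

import torch

src = torch.rand(1, 1, 6, 288, 512)            # [source RGB | background RGB]
fgr_residual = torch.rand(1, 1, 3, 288, 512)   # from project_mat(hid).split([3, 1], dim=-3)

# Mirrors `fgr = fgr_residual + src[:, :, :3, ...]` in the forward pass above.
fgr = (fgr_residual + src[:, :, :3, ...]).clamp(0., 1.)
print(fgr.shape)   # torch.Size([1, 1, 3, 288, 512])
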
@@ -1,5 +1,7 @@
 easing_functions==1.0.4
-tensorboard==2.5.0
-torch==1.9.0
-torchvision==0.10.0
+tensorboard
+torch
+torchvision
 tqdm==4.61.1
+opencv-python==4.6.0.66
+torchsummary

@@ -121,7 +121,8 @@ from dataset.augmentation import (
 from model import MattingNetwork
 from train_config import DATA_PATHS
 from train_loss import matting_loss, segmentation_loss
+import kornia
+from torchvision import transforms as T
 
 class Trainer:
     def __init__(self, rank, world_size):
@@ -189,7 +190,7 @@ class Trainer:
         if self.args.dataset == 'videomatte':
             self.dataset_lr_train = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
+                # background_image_dir=DATA_PATHS['background_images']['train'],
                 background_video_dir=DATA_PATHS['background_videos']['train'],
                 size=self.args.resolution_lr,
                 seq_length=self.args.seq_length_lr,
@@ -198,7 +199,7 @@ class Trainer:
             if self.args.train_hr:
                 self.dataset_hr_train = VideoMatteDataset(
                     videomatte_dir=DATA_PATHS['videomatte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
+                    # background_image_dir=DATA_PATHS['background_images']['train'],
                     background_video_dir=DATA_PATHS['background_videos']['train'],
                     size=self.args.resolution_hr,
                     seq_length=self.args.seq_length_hr,
@@ -206,38 +207,38 @@ class Trainer:
                     transform=VideoMatteTrainAugmentation(size_hr))
             self.dataset_valid = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
+                # background_image_dir=DATA_PATHS['background_images']['valid'],
                 background_video_dir=DATA_PATHS['background_videos']['valid'],
                 size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                 seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                 seq_sampler=ValidFrameSampler(),
                 transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
-        else:
-            self.dataset_lr_train = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
-                background_video_dir=DATA_PATHS['background_videos']['train'],
-                size=self.args.resolution_lr,
-                seq_length=self.args.seq_length_lr,
-                seq_sampler=TrainFrameSampler(),
-                transform=ImageMatteAugmentation(size_lr))
-            if self.args.train_hr:
-                self.dataset_hr_train = ImageMatteDataset(
-                    imagematte_dir=DATA_PATHS['imagematte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
-                    background_video_dir=DATA_PATHS['background_videos']['train'],
-                    size=self.args.resolution_hr,
-                    seq_length=self.args.seq_length_hr,
-                    seq_sampler=TrainFrameSampler(),
-                    transform=ImageMatteAugmentation(size_hr))
-            self.dataset_valid = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
-                background_video_dir=DATA_PATHS['background_videos']['valid'],
-                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
-                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
-                seq_sampler=ValidFrameSampler(),
-                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
+        # else:
+        #     self.dataset_lr_train = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #         background_image_dir=DATA_PATHS['background_images']['train'],
+        #         background_video_dir=DATA_PATHS['background_videos']['train'],
+        #         size=self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_lr,
+        #         seq_sampler=TrainFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_lr))
+        #     if self.args.train_hr:
+        #         self.dataset_hr_train = ImageMatteDataset(
+        #             imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #             background_image_dir=DATA_PATHS['background_images']['train'],
+        #             background_video_dir=DATA_PATHS['background_videos']['train'],
+        #             size=self.args.resolution_hr,
+        #             seq_length=self.args.seq_length_hr,
+        #             seq_sampler=TrainFrameSampler(),
+        #             transform=ImageMatteAugmentation(size_hr))
+        #     self.dataset_valid = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['valid'],
+        #         background_image_dir=DATA_PATHS['background_images']['valid'],
+        #         background_video_dir=DATA_PATHS['background_videos']['valid'],
+        #         size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
+        #         seq_sampler=ValidFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
 
         # Matting dataloaders:
         self.datasampler_lr_train = DistributedSampler(
@@ -270,49 +271,49 @@ class Trainer:
             pin_memory=True)
 
         # Segementation datasets
-        self.log('Initializing image segmentation datasets')
-        self.dataset_seg_image = ConcatDataset([
-            CocoPanopticDataset(
-                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
-                anndir=DATA_PATHS['coco_panoptic']['anndir'],
-                annfile=DATA_PATHS['coco_panoptic']['annfile'],
-                transform=CocoPanopticTrainAugmentation(size_lr)),
-            SuperviselyPersonDataset(
-                imgdir=DATA_PATHS['spd']['imgdir'],
-                segdir=DATA_PATHS['spd']['segdir'],
-                transform=CocoPanopticTrainAugmentation(size_lr))
-        ])
-        self.datasampler_seg_image = DistributedSampler(
-            dataset=self.dataset_seg_image,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_image = DataLoader(
-            dataset=self.dataset_seg_image,
-            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_image,
-            pin_memory=True)
-
-        self.log('Initializing video segmentation datasets')
-        self.dataset_seg_video = YouTubeVISDataset(
-            videodir=DATA_PATHS['youtubevis']['videodir'],
-            annfile=DATA_PATHS['youtubevis']['annfile'],
-            size=self.args.resolution_lr,
-            seq_length=self.args.seq_length_lr,
-            seq_sampler=TrainFrameSampler(speed=[1]),
-            transform=YouTubeVISAugmentation(size_lr))
-        self.datasampler_seg_video = DistributedSampler(
-            dataset=self.dataset_seg_video,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_video = DataLoader(
-            dataset=self.dataset_seg_video,
-            batch_size=self.args.batch_size_per_gpu,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_video,
-            pin_memory=True)
+        # self.log('Initializing image segmentation datasets')
+        # self.dataset_seg_image = ConcatDataset([
+        #     CocoPanopticDataset(
+        #         imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
+        #         anndir=DATA_PATHS['coco_panoptic']['anndir'],
+        #         annfile=DATA_PATHS['coco_panoptic']['annfile'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr)),
+        #     SuperviselyPersonDataset(
+        #         imgdir=DATA_PATHS['spd']['imgdir'],
+        #         segdir=DATA_PATHS['spd']['segdir'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr))
+        # ])
+        # self.datasampler_seg_image = DistributedSampler(
+        #     dataset=self.dataset_seg_image,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_image = DataLoader(
+        #     dataset=self.dataset_seg_image,
+        #     batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_image,
+        #     pin_memory=True)
+
+        # self.log('Initializing video segmentation datasets')
+        # self.dataset_seg_video = YouTubeVISDataset(
+        #     videodir=DATA_PATHS['youtubevis']['videodir'],
+        #     annfile=DATA_PATHS['youtubevis']['annfile'],
+        #     size=self.args.resolution_lr,
+        #     seq_length=self.args.seq_length_lr,
+        #     seq_sampler=TrainFrameSampler(speed=[1]),
+        #     transform=YouTubeVISAugmentation(size_lr))
+        # self.datasampler_seg_video = DistributedSampler(
+        #     dataset=self.dataset_seg_video,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_video = DataLoader(
+        #     dataset=self.dataset_seg_video,
+        #     batch_size=self.args.batch_size_per_gpu,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_video,
+        #     pin_memory=True)
 
     def init_model(self):
         self.log('Initializing model')
@@ -359,12 +360,12 @@ class Trainer:
                     self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')
 
                 # Segmentation pass
-                if self.step % 2 == 0:
-                    true_img, true_seg = self.load_next_seg_video_sample()
-                    self.train_seg(true_img, true_seg, log_label='seg_video')
-                else:
-                    true_img, true_seg = self.load_next_seg_image_sample()
-                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
+                # if self.step % 2 == 0:
+                #     true_img, true_seg = self.load_next_seg_video_sample()
+                #     self.train_seg(true_img, true_seg, log_label='seg_video')
+                # else:
+                #     true_img, true_seg = self.load_next_seg_image_sample()
+                #     self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
 
                 if self.step % self.args.checkpoint_save_interval == 0:
                     self.save()
@@ -376,10 +377,47 @@ class Trainer:
         true_pha = true_pha.to(self.rank, non_blocking=True)
         true_bgr = true_bgr.to(self.rank, non_blocking=True)
         true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
-        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
+        true_src = true_bgr.clone()
+
+        # Augment bgr with shadow
+        aug_shadow_idx = torch.rand(len(true_src)) < 0.3
+        if aug_shadow_idx.any():
+            aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()).flatten(start_dim = 0, end_dim = 1)
+            aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
+            aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
+            expected_shape = torch.tensor(true_src[aug_shadow_idx].shape)
+            expected_shape[2] = -1
+            true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow.reshape(expected_shape.tolist())).clamp_(0, 1)
+            del aug_shadow
+        del aug_shadow_idx
+
+        # Composite foreground onto source
+        true_src = true_fgr * true_pha + true_src * (1 - true_pha)
+
+        # Augment with noise
+        aug_noise_idx = torch.rand(len(true_src)) < 0.4
+        if aug_noise_idx.any():
+            true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+            true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+        del aug_noise_idx
+
+        # Augment background with jitter
+        aug_jitter_idx = torch.rand(len(true_src)) < 0.8
+        if aug_jitter_idx.any():
+            true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx].flatten(start_dim = 0, end_dim = 1)).reshape(true_bgr[aug_jitter_idx].shape)
+        del aug_jitter_idx
+
+        # Augment background with affine
+        aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
+        if aug_affine_idx.any():
+            true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx].flatten(start_dim = 0, end_dim = 1)).reshape(true_bgr[aug_affine_idx].shape)
+        del aug_affine_idx
+
+        fg_bg_input = torch.cat((true_src, true_bgr), dim = -3)
+
         with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
+            pred_fgr, pred_pha = self.model_ddp(fg_bg_input, downsample_ratio=downsample_ratio)[:2]
             loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)
 
         self.scaler.scale(loss['total']).backward()
@@ -398,28 +436,29 @@ class Trainer:
             self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
             self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)
 
-    def train_seg(self, true_img, true_seg, log_label):
-        true_img = true_img.to(self.rank, non_blocking=True)
-        true_seg = true_seg.to(self.rank, non_blocking=True)
-
-        true_img, true_seg = self.random_crop(true_img, true_seg)
-
-        with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
-            loss = segmentation_loss(pred_seg, true_seg)
-
-        self.scaler.scale(loss).backward()
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-        self.optimizer.zero_grad()
-
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
-            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
-
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
-            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    # does not get called
+    # def train_seg(self, true_img, true_seg, log_label):
+    #     true_img = true_img.to(self.rank, non_blocking=True)
+    #     true_seg = true_seg.to(self.rank, non_blocking=True)
+    #     true_img, true_seg = self.random_crop(true_img, true_seg)
+    #     with autocast(enabled=not self.args.disable_mixed_precision):
+    #         pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
+    #         loss = segmentation_loss(pred_seg, true_seg)
+    #     self.scaler.scale(loss).backward()
+    #     self.scaler.step(self.optimizer)
+    #     self.scaler.update()
+    #     self.optimizer.zero_grad()
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
+    #         self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
+    #         self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
 
     def load_next_mat_hr_sample(self):
         try:
@@ -430,23 +469,23 @@ class Trainer:
             sample = next(self.dataiterator_mat_hr)
         return sample
 
-    def load_next_seg_video_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_video)
-        except:
-            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
-            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
-            sample = next(self.dataiterator_seg_video)
-        return sample
-
-    def load_next_seg_image_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_image)
-        except:
-            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
-            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
-            sample = next(self.dataiterator_seg_image)
-        return sample
+    # def load_next_seg_video_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_video)
+    #     except:
+    #         self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
+    #         self.dataiterator_seg_video = iter(self.dataloader_seg_video)
+    #         sample = next(self.dataiterator_seg_video)
+    #     return sample
+
+    # def load_next_seg_image_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_image)
+    #     except:
+    #         self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
+    #         self.dataiterator_seg_image = iter(self.dataloader_seg_image)
+    #         sample = next(self.dataiterator_seg_image)
+    #     return sample
 
     def validate(self):
         if self.rank == 0:
@@ -461,7 +500,9 @@ class Trainer:
                         true_bgr = true_bgr.to(self.rank, non_blocking=True)
                         true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                         batch_size = true_src.size(0)
-                        pred_fgr, pred_pha = self.model(true_src)[:2]
+
+                        fg_bg_input = torch.cat((true_src, true_bgr), dim = -3)
+                        pred_fgr, pred_pha = self.model(fg_bg_input)[:2]
                         total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                         total_count += batch_size
             avg_loss = total_loss / total_count

@@ -37,32 +37,32 @@ DATA_PATHS = {
         'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
         'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
     },
-    'imagematte': {
-        'train': '../matting-data/ImageMatte/train',
-        'valid': '../matting-data/ImageMatte/valid',
-    },
-    'background_images': {
-        'train': '../matting-data/Backgrounds/train',
-        'valid': '../matting-data/Backgrounds/valid',
-    },
+    # 'imagematte': {
+    #     'train': '../matting-data/ImageMatte/train',
+    #     'valid': '../matting-data/ImageMatte/valid',
+    # },
+    # 'background_images': {
+    #     'train': '../matting-data/Backgrounds/train',
+    #     'valid': '../matting-data/Backgrounds/valid',
+    # },
     'background_videos': {
         'train': '../matting-data/BackgroundVideos/train',
        'valid': '../matting-data/BackgroundVideos/valid',
     },
-    'coco_panoptic': {
-        'imgdir': '../matting-data/coco/train2017/',
-        'anndir': '../matting-data/coco/panoptic_train2017/',
-        'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
-    },
-    'spd': {
-        'imgdir': '../matting-data/SuperviselyPersonDataset/img',
-        'segdir': '../matting-data/SuperviselyPersonDataset/seg',
-    },
-    'youtubevis': {
-        'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
-        'annfile': '../matting-data/YouTubeVIS/train/instances.json',
-    }
+    # 'coco_panoptic': {
+    #     'imgdir': '../matting-data/coco/train2017/',
+    #     'anndir': '../matting-data/coco/panoptic_train2017/',
+    #     'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
+    # },
+    # 'spd': {
+    #     'imgdir': '../matting-data/SuperviselyPersonDataset/img',
+    #     'segdir': '../matting-data/SuperviselyPersonDataset/seg',
+    # },
+    # 'youtubevis': {
+    #     'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
+    #     'annfile': '../matting-data/YouTubeVIS/train/instances.json',
+    # }
 }
