added code for training with bgr

pull/254/head
Kartikaeya 1 year ago
parent 53d74c6826
commit be683689db

@@ -9,14 +9,14 @@ from .augmentation import MotionAugmentation
class VideoMatteDataset(Dataset):
def __init__(self,
videomatte_dir,
background_image_dir,
# background_image_dir,
background_video_dir,
size,
seq_length,
seq_sampler,
transform=None):
self.background_image_dir = background_image_dir
self.background_image_files = os.listdir(background_image_dir)
# self.background_image_dir = background_image_dir
# self.background_image_files = os.listdir(background_image_dir)
self.background_video_dir = background_video_dir
self.background_video_clips = sorted(os.listdir(background_video_dir))
self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
@@ -38,10 +38,10 @@ class VideoMatteDataset(Dataset):
return len(self.videomatte_idx)
def __getitem__(self, idx):
if random.random() < 0.5:
bgrs = self._get_random_image_background()
else:
bgrs = self._get_random_video_background()
# if random.random() < 0.5:
# bgrs = self._get_random_image_background()
# else:
bgrs = self._get_random_video_background()
fgrs, phas = self._get_videomatte(idx)
@@ -50,11 +50,11 @@ class VideoMatteDataset(Dataset):
return fgrs, phas, bgrs
def _get_random_image_background(self):
with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
bgr = self._downsample_if_needed(bgr.convert('RGB'))
bgrs = [bgr] * self.seq_length
return bgrs
# def _get_random_image_background(self):
# with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
# bgr = self._downsample_if_needed(bgr.convert('RGB'))
# bgrs = [bgr] * self.seq_length
# return bgrs
def _get_random_video_background(self):
clip_idx = random.choice(range(len(self.background_video_clips)))

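Note: with the image-background branch commented out, every training sample is now composited onto frames from a randomly chosen background video clip. A minimal sketch of the resulting sampling path (method names follow the class above; illustrative, not the full implementation):

    # Sketch only: simplified __getitem__ after this change.
    def __getitem__(self, idx):
        bgrs = self._get_random_video_background()   # image backgrounds are no longer sampled
        fgrs, phas = self._get_videomatte(idx)
        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)
        return fgrs, phas, bgrs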
@@ -120,6 +120,10 @@ def convert_video(model,
rec = [None] * 4
for src in reader:
if src.shape[-1] %2 == 1:
src = src[:, :, :, :-1]
if src.shape[-2] %2 == 1:
src = src[:, :, :-1, :]
if downsample_ratio is None:
downsample_ratio = auto_downsample_ratio(*src.shape[2:])

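Note: the two added checks crop the source to even height and width before inference, presumably so the encoder's repeated stride-2 downsampling and the decoder's upsampling land on matching spatial sizes. A runnable illustration (frame size is hypothetical):

    import torch

    src = torch.rand(1, 3, 721, 1281)    # [T, C, H, W] with odd height and width
    if src.shape[-1] % 2 == 1:           # odd width -> drop the last column
        src = src[:, :, :, :-1]
    if src.shape[-2] % 2 == 1:           # odd height -> drop the last row
        src = src[:, :, :-1, :]
    print(src.shape)                     # torch.Size([1, 3, 720, 1280])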
@@ -5,7 +5,7 @@ import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import torch
class VideoReader(Dataset):
def __init__(self, path, transform=None):
@@ -55,18 +55,23 @@ class VideoWriter:
class ImageSequenceReader(Dataset):
def __init__(self, path, transform=None):
self.path = path
self.files = sorted(os.listdir(path))
self.files_fgr = sorted(os.listdir(path + "fgr/"))
self.files_bgr = sorted(os.listdir(path + "bgr/"))
self.transform = transform
def __len__(self):
return len(self.files)
return len(self.files_fgr)
def __getitem__(self, idx):
with Image.open(os.path.join(self.path, self.files[idx])) as img:
img.load()
with Image.open(os.path.join(self.path + "fgr/", self.files_fgr[idx])) as fgr_img:
fgr_img.load()
with Image.open(os.path.join(self.path + "bgr/", self.files_bgr[idx])) as bgr_img:
bgr_img.load()
if self.transform is not None:
return self.transform(img)
return img
return torch.cat([self.transform(fgr_img), self.transform(bgr_img)], dim = 0)
return fgr_img
class ImageSequenceWriter:

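Note: ImageSequenceReader now expects paired frame folders and, when a transform is set, returns a 6-channel tensor (source stacked on top of the known background); without a transform only fgr_img is returned. Because the paths are built by plain string concatenation, the path argument must end with a trailing slash. A usage sketch (directory names are hypothetical; assumes the class is importable from inference_utils):

    from torchvision import transforms
    from inference_utils import ImageSequenceReader

    # my_sequence/
    #   fgr/0000.png, 0001.png, ...   source frames
    #   bgr/0000.png, 0001.png, ...   matching background plates
    reader = ImageSequenceReader('my_sequence/', transform=transforms.ToTensor())
    sample = reader[0]                   # torch.cat([fgr, bgr], dim=0) -> shape [6, H, W]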
@@ -1,7 +1,7 @@
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
# from torch.nn import functional as F
from typing import Tuple, Optional
class RecurrentDecoder(nn.Module):
@@ -9,10 +9,10 @@ class RecurrentDecoder(nn.Module):
super().__init__()
self.avgpool = AvgPool()
self.decode4 = BottleneckBlock(feature_channels[3])
self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 6, decoder_channels[0])
self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 6, decoder_channels[1])
self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 6, decoder_channels[2])
self.decode0 = OutputBlock(decoder_channels[2], 6, decoder_channels[3])
def forward(self,
s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,

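Note: the third argument of UpsamplingBlock/OutputBlock is the channel count of the downsampled source image that gets concatenated back in at each decoder scale, so it doubles from 3 to 6 once the background is stacked onto the input. A rough shape check (channel counts are hypothetical; the real block also crops and carries ConvGRU state):

    import torch

    up   = torch.rand(1, 128, 64, 64)    # upsampled decoder features
    skip = torch.rand(1, 40, 64, 64)     # encoder skip features
    src  = torch.rand(1, 6, 64, 64)      # downsampled input: 3 source + 3 background channels
    x = torch.cat([up, skip, src], dim=1)
    print(x.shape[1])                    # 174 -> the block's conv must be built with src_channels = 6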
@@ -3,6 +3,21 @@ from torch import nn
from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
from torchvision.transforms.functional import normalize
def load_matched_state_dict(model, state_dict, print_stats=True):
"""
Only loads weights that matched in key and shape. Ignore other weights.
"""
num_matched, num_total = 0, 0
curr_state_dict = model.state_dict()
for key in curr_state_dict.keys():
num_total += 1
if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
curr_state_dict[key] = state_dict[key]
num_matched += 1
model.load_state_dict(curr_state_dict)
if print_stats:
print(f'Loaded state_dict: {num_matched}/{num_total} matched')
class MobileNetV3LargeEncoder(MobileNetV3):
def __init__(self, pretrained: bool = False):
super().__init__(
@@ -27,14 +42,24 @@ class MobileNetV3LargeEncoder(MobileNetV3):
)
if pretrained:
self.load_state_dict(torch.hub.load_state_dict_from_url(
'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
pretrained_state_dict = torch.hub.load_state_dict_from_url(
'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')
# print("pretrained_state_dict keys \n \n ", pretrained_state_dict.keys())
# print("\n\ncurrent model state dict keys \n\n", self.state_dict().keys())
load_matched_state_dict(self, pretrained_state_dict)
# self.load_state_dict(torch.hub.load_state_dict_from_url(
# 'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
del self.avgpool
del self.classifier
def forward_single_frame(self, x):
x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# print(x.shape)
x = torch.cat((normalize(x[:, :3, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), normalize(x[:, 3:, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])), dim = -3)
x = self.features[0](x)
x = self.features[1](x)

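Note: the encoder now normalizes the source half and the background half of the 6-channel input separately with the same ImageNet statistics, and the ImageNet checkpoint is loaded through load_matched_state_dict, which copies only tensors whose key and shape both match; any tensor whose shape no longer fits (for example a stem convolution widened to take 6 channels) is skipped and trained from scratch. A minimal demonstration of the helper defined above, using a toy layer and checkpoint:

    import torch
    from torch import nn

    model = nn.Linear(6, 4)                                        # current layer expects 6 inputs
    ckpt  = {'weight': torch.zeros(4, 3), 'bias': torch.zeros(4)}  # checkpoint saved for 3 inputs
    load_matched_state_dict(model, ckpt)                           # prints: Loaded state_dict: 1/2 matched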
@@ -1,6 +1,7 @@
import torch
from torch import Tensor
from torch import nn
from torchsummary import summary
from torch.nn import functional as F
from typing import Optional, List
@@ -58,8 +59,8 @@ class MattingNetwork(nn.Module):
if not segmentation_pass:
fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
if downsample_ratio != 1:
fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
fgr = fgr_residual + src
fgr_residual, pha = self.refiner(src[:, :, :3, ...], src_sm[:, :, :3, ...], fgr_residual, pha, hid)
fgr = fgr_residual + src[:, :, :3, ...]
fgr = fgr.clamp(0., 1.)
pha = pha.clamp(0., 1.)
return [fgr, pha, *rec]

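Note: the model's source tensor now carries 6 channels (source RGB plus background RGB), but the foreground residual and the refiner operate on the RGB source only, hence the [:, :, :3, ...] slices. Shape illustration (sizes are hypothetical):

    import torch

    src = torch.rand(2, 4, 6, 256, 256)   # [B, T, 6, H, W]: source + background plate
    src_rgb = src[:, :, :3, ...]          # [2, 4, 3, 256, 256], used for fgr = fgr_residual + src
    print(src_rgb.shape)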
@@ -1,5 +1,7 @@
easing_functions==1.0.4
tensorboard==2.5.0
torch==1.9.0
torchvision==0.10.0
tqdm==4.61.1
tensorboard
torch
torchvision
tqdm==4.61.1
opencv-python==4.6.0.66
torchsummary

@@ -121,7 +121,8 @@ from dataset.augmentation import (
from model import MattingNetwork
from train_config import DATA_PATHS
from train_loss import matting_loss, segmentation_loss
import kornia
from torchvision import transforms as T
class Trainer:
def __init__(self, rank, world_size):
@@ -189,7 +190,7 @@ class Trainer:
if self.args.dataset == 'videomatte':
self.dataset_lr_train = VideoMatteDataset(
videomatte_dir=DATA_PATHS['videomatte']['train'],
background_image_dir=DATA_PATHS['background_images']['train'],
# background_image_dir=DATA_PATHS['background_images']['train'],
background_video_dir=DATA_PATHS['background_videos']['train'],
size=self.args.resolution_lr,
seq_length=self.args.seq_length_lr,
@@ -198,7 +199,7 @@ class Trainer:
if self.args.train_hr:
self.dataset_hr_train = VideoMatteDataset(
videomatte_dir=DATA_PATHS['videomatte']['train'],
background_image_dir=DATA_PATHS['background_images']['train'],
# background_image_dir=DATA_PATHS['background_images']['train'],
background_video_dir=DATA_PATHS['background_videos']['train'],
size=self.args.resolution_hr,
seq_length=self.args.seq_length_hr,
@@ -206,38 +207,38 @@ class Trainer:
transform=VideoMatteTrainAugmentation(size_hr))
self.dataset_valid = VideoMatteDataset(
videomatte_dir=DATA_PATHS['videomatte']['valid'],
background_image_dir=DATA_PATHS['background_images']['valid'],
# background_image_dir=DATA_PATHS['background_images']['valid'],
background_video_dir=DATA_PATHS['background_videos']['valid'],
size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
seq_sampler=ValidFrameSampler(),
transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
else:
self.dataset_lr_train = ImageMatteDataset(
imagematte_dir=DATA_PATHS['imagematte']['train'],
background_image_dir=DATA_PATHS['background_images']['train'],
background_video_dir=DATA_PATHS['background_videos']['train'],
size=self.args.resolution_lr,
seq_length=self.args.seq_length_lr,
seq_sampler=TrainFrameSampler(),
transform=ImageMatteAugmentation(size_lr))
if self.args.train_hr:
self.dataset_hr_train = ImageMatteDataset(
imagematte_dir=DATA_PATHS['imagematte']['train'],
background_image_dir=DATA_PATHS['background_images']['train'],
background_video_dir=DATA_PATHS['background_videos']['train'],
size=self.args.resolution_hr,
seq_length=self.args.seq_length_hr,
seq_sampler=TrainFrameSampler(),
transform=ImageMatteAugmentation(size_hr))
self.dataset_valid = ImageMatteDataset(
imagematte_dir=DATA_PATHS['imagematte']['valid'],
background_image_dir=DATA_PATHS['background_images']['valid'],
background_video_dir=DATA_PATHS['background_videos']['valid'],
size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
seq_sampler=ValidFrameSampler(),
transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
# else:
# self.dataset_lr_train = ImageMatteDataset(
# imagematte_dir=DATA_PATHS['imagematte']['train'],
# background_image_dir=DATA_PATHS['background_images']['train'],
# background_video_dir=DATA_PATHS['background_videos']['train'],
# size=self.args.resolution_lr,
# seq_length=self.args.seq_length_lr,
# seq_sampler=TrainFrameSampler(),
# transform=ImageMatteAugmentation(size_lr))
# if self.args.train_hr:
# self.dataset_hr_train = ImageMatteDataset(
# imagematte_dir=DATA_PATHS['imagematte']['train'],
# background_image_dir=DATA_PATHS['background_images']['train'],
# background_video_dir=DATA_PATHS['background_videos']['train'],
# size=self.args.resolution_hr,
# seq_length=self.args.seq_length_hr,
# seq_sampler=TrainFrameSampler(),
# transform=ImageMatteAugmentation(size_hr))
# self.dataset_valid = ImageMatteDataset(
# imagematte_dir=DATA_PATHS['imagematte']['valid'],
# background_image_dir=DATA_PATHS['background_images']['valid'],
# background_video_dir=DATA_PATHS['background_videos']['valid'],
# size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
# seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
# seq_sampler=ValidFrameSampler(),
# transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
# Matting dataloaders:
self.datasampler_lr_train = DistributedSampler(
@@ -270,49 +271,49 @@ class Trainer:
pin_memory=True)
# Segementation datasets
self.log('Initializing image segmentation datasets')
self.dataset_seg_image = ConcatDataset([
CocoPanopticDataset(
imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
anndir=DATA_PATHS['coco_panoptic']['anndir'],
annfile=DATA_PATHS['coco_panoptic']['annfile'],
transform=CocoPanopticTrainAugmentation(size_lr)),
SuperviselyPersonDataset(
imgdir=DATA_PATHS['spd']['imgdir'],
segdir=DATA_PATHS['spd']['segdir'],
transform=CocoPanopticTrainAugmentation(size_lr))
])
self.datasampler_seg_image = DistributedSampler(
dataset=self.dataset_seg_image,
rank=self.rank,
num_replicas=self.world_size,
shuffle=True)
self.dataloader_seg_image = DataLoader(
dataset=self.dataset_seg_image,
batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
num_workers=self.args.num_workers,
sampler=self.datasampler_seg_image,
pin_memory=True)
# self.log('Initializing image segmentation datasets')
# self.dataset_seg_image = ConcatDataset([
# CocoPanopticDataset(
# imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
# anndir=DATA_PATHS['coco_panoptic']['anndir'],
# annfile=DATA_PATHS['coco_panoptic']['annfile'],
# transform=CocoPanopticTrainAugmentation(size_lr)),
# SuperviselyPersonDataset(
# imgdir=DATA_PATHS['spd']['imgdir'],
# segdir=DATA_PATHS['spd']['segdir'],
# transform=CocoPanopticTrainAugmentation(size_lr))
# ])
# self.datasampler_seg_image = DistributedSampler(
# dataset=self.dataset_seg_image,
# rank=self.rank,
# num_replicas=self.world_size,
# shuffle=True)
# self.dataloader_seg_image = DataLoader(
# dataset=self.dataset_seg_image,
# batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
# num_workers=self.args.num_workers,
# sampler=self.datasampler_seg_image,
# pin_memory=True)
self.log('Initializing video segmentation datasets')
self.dataset_seg_video = YouTubeVISDataset(
videodir=DATA_PATHS['youtubevis']['videodir'],
annfile=DATA_PATHS['youtubevis']['annfile'],
size=self.args.resolution_lr,
seq_length=self.args.seq_length_lr,
seq_sampler=TrainFrameSampler(speed=[1]),
transform=YouTubeVISAugmentation(size_lr))
self.datasampler_seg_video = DistributedSampler(
dataset=self.dataset_seg_video,
rank=self.rank,
num_replicas=self.world_size,
shuffle=True)
self.dataloader_seg_video = DataLoader(
dataset=self.dataset_seg_video,
batch_size=self.args.batch_size_per_gpu,
num_workers=self.args.num_workers,
sampler=self.datasampler_seg_video,
pin_memory=True)
# self.log('Initializing video segmentation datasets')
# self.dataset_seg_video = YouTubeVISDataset(
# videodir=DATA_PATHS['youtubevis']['videodir'],
# annfile=DATA_PATHS['youtubevis']['annfile'],
# size=self.args.resolution_lr,
# seq_length=self.args.seq_length_lr,
# seq_sampler=TrainFrameSampler(speed=[1]),
# transform=YouTubeVISAugmentation(size_lr))
# self.datasampler_seg_video = DistributedSampler(
# dataset=self.dataset_seg_video,
# rank=self.rank,
# num_replicas=self.world_size,
# shuffle=True)
# self.dataloader_seg_video = DataLoader(
# dataset=self.dataset_seg_video,
# batch_size=self.args.batch_size_per_gpu,
# num_workers=self.args.num_workers,
# sampler=self.datasampler_seg_video,
# pin_memory=True)
def init_model(self):
self.log('Initializing model')
@@ -359,12 +360,12 @@ class Trainer:
self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')
# Segmentation pass
if self.step % 2 == 0:
true_img, true_seg = self.load_next_seg_video_sample()
self.train_seg(true_img, true_seg, log_label='seg_video')
else:
true_img, true_seg = self.load_next_seg_image_sample()
self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
# if self.step % 2 == 0:
# true_img, true_seg = self.load_next_seg_video_sample()
# self.train_seg(true_img, true_seg, log_label='seg_video')
# else:
# true_img, true_seg = self.load_next_seg_image_sample()
# self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
if self.step % self.args.checkpoint_save_interval == 0:
self.save()
@@ -376,10 +377,47 @@ class Trainer:
true_pha = true_pha.to(self.rank, non_blocking=True)
true_bgr = true_bgr.to(self.rank, non_blocking=True)
true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
true_src = true_bgr.clone()
# Augment bgr with shadow
aug_shadow_idx = torch.rand(len(true_src)) < 0.3
if aug_shadow_idx.any():
aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()).flatten(start_dim = 0, end_dim = 1)
aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
expected_shape = torch.tensor(true_src[aug_shadow_idx].shape)
expected_shape[2] = -1
true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow.reshape(expected_shape.tolist())).clamp_(0, 1)
del aug_shadow
del aug_shadow_idx
# Composite foreground onto source
true_src = true_fgr * true_pha + true_src * (1 - true_pha)
# Augment with noise
aug_noise_idx = torch.rand(len(true_src)) < 0.4
if aug_noise_idx.any():
true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
del aug_noise_idx
# Augment background with jitter
aug_jitter_idx = torch.rand(len(true_src)) < 0.8
if aug_jitter_idx.any():
true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx].flatten(start_dim = 0, end_dim = 1)).reshape(true_bgr[aug_jitter_idx].shape)
del aug_jitter_idx
# Augment background with affine
aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
if aug_affine_idx.any():
true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx].flatten(start_dim = 0, end_dim = 1)).reshape(true_bgr[aug_affine_idx].shape)
del aug_affine_idx
fg_bg_input = torch.cat((true_src, true_bgr), dim = -3)
with autocast(enabled=not self.args.disable_mixed_precision):
pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
pred_fgr, pred_pha = self.model_ddp(fg_bg_input, downsample_ratio=downsample_ratio)[:2]
loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)
self.scaler.scale(loss['total']).backward()
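Note: the matting pass now composites the foreground onto the same clip that serves as the known background, perturbs src and bgr independently (shadow, noise, colour jitter, affine) so the pair is not pixel-identical, and finally stacks them into a 6-channel network input. A shape check of that last step (sizes are illustrative):

    import torch

    true_src = torch.rand(2, 4, 3, 256, 256)   # composited source clip  [B, T, 3, H, W]
    true_bgr = torch.rand(2, 4, 3, 256, 256)   # augmented background    [B, T, 3, H, W]
    fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)
    print(fg_bg_input.shape)                   # torch.Size([2, 4, 6, 256, 256])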
@@ -397,29 +435,30 @@ class Trainer:
self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)
def train_seg(self, true_img, true_seg, log_label):
true_img = true_img.to(self.rank, non_blocking=True)
true_seg = true_seg.to(self.rank, non_blocking=True)
# does not get called
# def train_seg(self, true_img, true_seg, log_label):
# true_img = true_img.to(self.rank, non_blocking=True)
# true_seg = true_seg.to(self.rank, non_blocking=True)
true_img, true_seg = self.random_crop(true_img, true_seg)
# true_img, true_seg = self.random_crop(true_img, true_seg)
with autocast(enabled=not self.args.disable_mixed_precision):
pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
loss = segmentation_loss(pred_seg, true_seg)
# with autocast(enabled=not self.args.disable_mixed_precision):
# pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
# loss = segmentation_loss(pred_seg, true_seg)
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
# self.scaler.scale(loss).backward()
# self.scaler.step(self.optimizer)
# self.scaler.update()
# self.optimizer.zero_grad()
if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
# if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
# self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
# if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
# self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
# self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
# self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
def load_next_mat_hr_sample(self):
try:
@@ -430,23 +469,23 @@ class Trainer:
sample = next(self.dataiterator_mat_hr)
return sample
def load_next_seg_video_sample(self):
try:
sample = next(self.dataiterator_seg_video)
except:
self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
self.dataiterator_seg_video = iter(self.dataloader_seg_video)
sample = next(self.dataiterator_seg_video)
return sample
# def load_next_seg_video_sample(self):
# try:
# sample = next(self.dataiterator_seg_video)
# except:
# self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
# self.dataiterator_seg_video = iter(self.dataloader_seg_video)
# sample = next(self.dataiterator_seg_video)
# return sample
def load_next_seg_image_sample(self):
try:
sample = next(self.dataiterator_seg_image)
except:
self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
self.dataiterator_seg_image = iter(self.dataloader_seg_image)
sample = next(self.dataiterator_seg_image)
return sample
# def load_next_seg_image_sample(self):
# try:
# sample = next(self.dataiterator_seg_image)
# except:
# self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
# self.dataiterator_seg_image = iter(self.dataloader_seg_image)
# sample = next(self.dataiterator_seg_image)
# return sample
def validate(self):
if self.rank == 0:
@@ -461,7 +500,9 @@ class Trainer:
true_bgr = true_bgr.to(self.rank, non_blocking=True)
true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
batch_size = true_src.size(0)
pred_fgr, pred_pha = self.model(true_src)[:2]
fg_bg_input = torch.cat((true_src, true_bgr), dim = -3)
pred_fgr, pred_pha = self.model(fg_bg_input)[:2]
total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
total_count += batch_size
avg_loss = total_loss / total_count

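Note: validation feeds the model the same 6-channel layout as training, so any later inference call has to do likewise. A sketch mirroring the validate() call above (sizes are illustrative; assumes the encoder in this branch accepts 6-channel input, as the validate() call implies):

    import torch
    from model import MattingNetwork

    model = MattingNetwork('mobilenetv3').eval()
    frame = torch.rand(1, 1, 3, 288, 512)       # [B, T, 3, H, W] source frames
    plate = torch.rand(1, 1, 3, 288, 512)       # [B, T, 3, H, W] background plate
    fgr, pha, *rec = model(torch.cat([frame, plate], dim=-3))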
@@ -37,32 +37,32 @@ DATA_PATHS = {
'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
},
'imagematte': {
'train': '../matting-data/ImageMatte/train',
'valid': '../matting-data/ImageMatte/valid',
},
'background_images': {
'train': '../matting-data/Backgrounds/train',
'valid': '../matting-data/Backgrounds/valid',
},
# 'imagematte': {
# 'train': '../matting-data/ImageMatte/train',
# 'valid': '../matting-data/ImageMatte/valid',
# },
# 'background_images': {
# 'train': '../matting-data/Backgrounds/train',
# 'valid': '../matting-data/Backgrounds/valid',
# },
'background_videos': {
'train': '../matting-data/BackgroundVideos/train',
'valid': '../matting-data/BackgroundVideos/valid',
},
'coco_panoptic': {
'imgdir': '../matting-data/coco/train2017/',
'anndir': '../matting-data/coco/panoptic_train2017/',
'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
},
'spd': {
'imgdir': '../matting-data/SuperviselyPersonDataset/img',
'segdir': '../matting-data/SuperviselyPersonDataset/seg',
},
'youtubevis': {
'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
'annfile': '../matting-data/YouTubeVIS/train/instances.json',
}
# 'coco_panoptic': {
# 'imgdir': '../matting-data/coco/train2017/',
# 'anndir': '../matting-data/coco/panoptic_train2017/',
# 'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
# },
# 'spd': {
# 'imgdir': '../matting-data/SuperviselyPersonDataset/img',
# 'segdir': '../matting-data/SuperviselyPersonDataset/seg',
# },
# 'youtubevis': {
# 'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
# 'annfile': '../matting-data/YouTubeVIS/train/instances.json',
# }
}
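Note: after these edits only two dataset roots are read at train time. The surviving configuration, consolidated (values copied from the hunk above):

    DATA_PATHS = {
        'videomatte': {
            'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
            'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
        },
        'background_videos': {
            'train': '../matting-data/BackgroundVideos/train',
            'valid': '../matting-data/BackgroundVideos/valid',
        },
    }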
