"""
|
|
# First update `train_config.py` to set paths to your dataset locations.
|
|
|
|
# You may want to change `--num-workers` according to your machine's memory.
|
|
# The default num-workers=8 may cause dataloader to exit unexpectedly when
|
|
# machine is out of memory.
|
|
|
|
# Stage 1
|
|
python train.py \
|
|
--model-variant mobilenetv3 \
|
|
--dataset videomatte \
|
|
--resolution-lr 512 \
|
|
--seq-length-lr 15 \
|
|
--learning-rate-backbone 0.0001 \
|
|
--learning-rate-aspp 0.0002 \
|
|
--learning-rate-decoder 0.0002 \
|
|
--learning-rate-refiner 0 \
|
|
--checkpoint-dir checkpoint/stage1 \
|
|
--log-dir log/stage1 \
|
|
--epoch-start 0 \
|
|
--epoch-end 20
|
|
|
|
# Stage 2
|
|
python train.py \
|
|
--model-variant mobilenetv3 \
|
|
--dataset videomatte \
|
|
--resolution-lr 512 \
|
|
--seq-length-lr 50 \
|
|
--learning-rate-backbone 0.00005 \
|
|
--learning-rate-aspp 0.0001 \
|
|
--learning-rate-decoder 0.0001 \
|
|
--learning-rate-refiner 0 \
|
|
--checkpoint checkpoint/stage1/epoch-19.pth \
|
|
--checkpoint-dir checkpoint/stage2 \
|
|
--log-dir log/stage2 \
|
|
--epoch-start 20 \
|
|
--epoch-end 22
|
|
|
|
# Stage 3
|
|
python train.py \
|
|
--model-variant mobilenetv3 \
|
|
--dataset videomatte \
|
|
--train-hr \
|
|
--resolution-lr 512 \
|
|
--resolution-hr 2048 \
|
|
--seq-length-lr 40 \
|
|
--seq-length-hr 6 \
|
|
--learning-rate-backbone 0.00001 \
|
|
--learning-rate-aspp 0.00001 \
|
|
--learning-rate-decoder 0.00001 \
|
|
--learning-rate-refiner 0.0002 \
|
|
--checkpoint checkpoint/stage2/epoch-21.pth \
|
|
--checkpoint-dir checkpoint/stage3 \
|
|
--log-dir log/stage3 \
|
|
--epoch-start 22 \
|
|
--epoch-end 23
|
|
|
|
# Stage 4
|
|
python train.py \
|
|
--model-variant mobilenetv3 \
|
|
--dataset imagematte \
|
|
--train-hr \
|
|
--resolution-lr 512 \
|
|
--resolution-hr 2048 \
|
|
--seq-length-lr 40 \
|
|
--seq-length-hr 6 \
|
|
--learning-rate-backbone 0.00001 \
|
|
--learning-rate-aspp 0.00001 \
|
|
--learning-rate-decoder 0.00005 \
|
|
--learning-rate-refiner 0.0002 \
|
|
--checkpoint checkpoint/stage3/epoch-22.pth \
|
|
--checkpoint-dir checkpoint/stage4 \
|
|
--log-dir log/stage4 \
|
|
--epoch-start 23 \
|
|
--epoch-end 28
|
|
"""
|
|
|
|
|
|
import argparse
import torch
import random
import os
from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms.functional import center_crop
from tqdm import tqdm

from dataset.videomatte import (
    VideoMatteDataset,
    VideoMatteTrainAugmentation,
    VideoMatteValidAugmentation,
)
from dataset.imagematte import (
    ImageMatteDataset,
    ImageMatteAugmentation
)
from dataset.coco import (
    CocoPanopticDataset,
    CocoPanopticTrainAugmentation,
)
from dataset.spd import (
    SuperviselyPersonDataset
)
from dataset.youtubevis import (
    YouTubeVISDataset,
    YouTubeVISAugmentation
)
from dataset.augmentation import (
    TrainFrameSampler,
    ValidFrameSampler
)
from model import MattingNetwork
from train_config import DATA_PATHS
from train_loss import matting_loss, segmentation_loss


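# Each spawned worker process builds one Trainer. The constructor runs the whole
# pipeline for that process: argument parsing, distributed setup, dataset and model
# initialization, the training loop, and cleanup.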
class Trainer:
    def __init__(self, rank, world_size):
        self.parse_args()
        self.init_distributed(rank, world_size)
        self.init_datasets()
        self.init_model()
        self.init_writer()
        self.train()
        self.cleanup()

    def parse_args(self):
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        # Matting dataset
        parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte'])
        # Learning rate
        parser.add_argument('--learning-rate-backbone', type=float, required=True)
        parser.add_argument('--learning-rate-aspp', type=float, required=True)
        parser.add_argument('--learning-rate-decoder', type=float, required=True)
        parser.add_argument('--learning-rate-refiner', type=float, required=True)
        # Training setting
        parser.add_argument('--train-hr', action='store_true')
        parser.add_argument('--resolution-lr', type=int, default=512)
        parser.add_argument('--resolution-hr', type=int, default=2048)
        parser.add_argument('--seq-length-lr', type=int, required=True)
        parser.add_argument('--seq-length-hr', type=int, default=6)
        parser.add_argument('--downsample-ratio', type=float, default=0.25)
        parser.add_argument('--batch-size-per-gpu', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=8)
        parser.add_argument('--epoch-start', type=int, default=0)
        parser.add_argument('--epoch-end', type=int, default=16)
        # Tensorboard logging
        parser.add_argument('--log-dir', type=str, required=True)
        parser.add_argument('--log-train-loss-interval', type=int, default=20)
        parser.add_argument('--log-train-images-interval', type=int, default=500)
        # Checkpoint loading and saving
        parser.add_argument('--checkpoint', type=str)
        parser.add_argument('--checkpoint-dir', type=str, required=True)
        parser.add_argument('--checkpoint-save-interval', type=int, default=500)
        # Distributed
        parser.add_argument('--distributed-addr', type=str, default='localhost')
        parser.add_argument('--distributed-port', type=str, default='12355')
        # Debugging
        parser.add_argument('--disable-progress-bar', action='store_true')
        parser.add_argument('--disable-validation', action='store_true')
        parser.add_argument('--disable-mixed-precision', action='store_true')
        self.args = parser.parse_args()

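    # One process is spawned per GPU: `rank` selects this process's device and
    # `world_size` is the total process count. Rendezvous uses the address/port from
    # --distributed-addr/--distributed-port via environment variables, with the NCCL backend.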
    def init_distributed(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.log('Initializing distributed')
        os.environ['MASTER_ADDR'] = self.args.distributed_addr
        os.environ['MASTER_PORT'] = self.args.distributed_port
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

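    # Matting data: the low-resolution training dataset is always built; the
    # high-resolution dataset is added only when --train-hr is set. Separate
    # segmentation datasets (COCO-Panoptic + Supervisely images, YouTubeVIS video)
    # are built further below.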
    def init_datasets(self):
        self.log('Initializing matting datasets')
        size_hr = (self.args.resolution_hr, self.args.resolution_hr)
        size_lr = (self.args.resolution_lr, self.args.resolution_lr)

        # Matting datasets:
        if self.args.dataset == 'videomatte':
            self.dataset_lr_train = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=VideoMatteTrainAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = VideoMatteDataset(
                    videomatte_dir=DATA_PATHS['videomatte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=VideoMatteTrainAugmentation(size_hr))
            self.dataset_valid = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
        else:
            self.dataset_lr_train = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=ImageMatteAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = ImageMatteDataset(
                    imagematte_dir=DATA_PATHS['imagematte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=ImageMatteAugmentation(size_hr))
            self.dataset_valid = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))

        # Matting dataloaders:
        self.datasampler_lr_train = DistributedSampler(
            dataset=self.dataset_lr_train,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_lr_train = DataLoader(
            dataset=self.dataset_lr_train,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_lr_train,
            pin_memory=True)
        if self.args.train_hr:
            self.datasampler_hr_train = DistributedSampler(
                dataset=self.dataset_hr_train,
                rank=self.rank,
                num_replicas=self.world_size,
                shuffle=True)
            self.dataloader_hr_train = DataLoader(
                dataset=self.dataset_hr_train,
                batch_size=self.args.batch_size_per_gpu,
                num_workers=self.args.num_workers,
                sampler=self.datasampler_hr_train,
                pin_memory=True)
        self.dataloader_valid = DataLoader(
            dataset=self.dataset_valid,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # Segmentation datasets
        self.log('Initializing image segmentation datasets')
        self.dataset_seg_image = ConcatDataset([
            CocoPanopticDataset(
                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
                anndir=DATA_PATHS['coco_panoptic']['anndir'],
                annfile=DATA_PATHS['coco_panoptic']['annfile'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
            SuperviselyPersonDataset(
                imgdir=DATA_PATHS['spd']['imgdir'],
                segdir=DATA_PATHS['spd']['segdir'],
                transform=CocoPanopticTrainAugmentation(size_lr))
        ])
        self.datasampler_seg_image = DistributedSampler(
            dataset=self.dataset_seg_image,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
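        # The image segmentation loader uses batch_size_per_gpu * seq_length_lr samples
        # per batch: still images carry no time axis, so this matches the number of
        # frames in one low-resolution matting clip batch.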
        self.dataloader_seg_image = DataLoader(
            dataset=self.dataset_seg_image,
            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_image,
            pin_memory=True)

        self.log('Initializing video segmentation datasets')
        self.dataset_seg_video = YouTubeVISDataset(
            videodir=DATA_PATHS['youtubevis']['videodir'],
            annfile=DATA_PATHS['youtubevis']['annfile'],
            size=self.args.resolution_lr,
            seq_length=self.args.seq_length_lr,
            seq_sampler=TrainFrameSampler(speed=[1]),
            transform=YouTubeVISAugmentation(size_lr))
        self.datasampler_seg_video = DistributedSampler(
            dataset=self.dataset_seg_video,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_seg_video = DataLoader(
            dataset=self.dataset_seg_video,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_video,
            pin_memory=True)

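    # Model setup: optionally restore weights from --checkpoint, convert BatchNorm
    # layers to SyncBatchNorm for multi-GPU training, wrap the model in DDP, and give
    # each sub-module (backbone, ASPP, decoder/projections, refiner) its own learning rate.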
    def init_model(self):
        self.log('Initializing model')
        self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank)

        if self.args.checkpoint:
            self.log(f'Restoring from checkpoint: {self.args.checkpoint}')
            self.log(self.model.load_state_dict(
                torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}')))

        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True)
        self.optimizer = Adam([
            {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
            {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
            {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
        ])
        self.scaler = GradScaler()

    def init_writer(self):
        if self.rank == 0:
            self.log('Initializing writer')
            self.writer = SummaryWriter(self.args.log_dir)

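    # Training loop: validate at the start of every epoch (unless disabled), then for
    # each low-resolution matting batch run an LR matting pass, an optional HR matting
    # pass, and a segmentation pass that alternates between video and image data.
    # A checkpoint is saved every --checkpoint-save-interval steps.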
    def train(self):
        for epoch in range(self.args.epoch_start, self.args.epoch_end):
            self.epoch = epoch
            self.step = epoch * len(self.dataloader_lr_train)

            if not self.args.disable_validation:
                self.validate()

            self.log(f'Training epoch: {epoch}')
            for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                # Low resolution pass
                self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr')

                # High resolution pass
                if self.args.train_hr:
                    true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample()
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')

                # Segmentation pass
                if self.step % 2 == 0:
                    true_img, true_seg = self.load_next_seg_video_sample()
                    self.train_seg(true_img, true_seg, log_label='seg_video')
                else:
                    true_img, true_seg = self.load_next_seg_image_sample()
                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')

                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()

                self.step += 1

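    # Matting pass: composite the training source as fgr * pha + bgr * (1 - pha),
    # run the forward pass under autocast (unless mixed precision is disabled), and
    # step the optimizer through the GradScaler.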
    def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
        true_fgr = true_fgr.to(self.rank, non_blocking=True)
        true_pha = true_pha.to(self.rank, non_blocking=True)
        true_bgr = true_bgr.to(self.rank, non_blocking=True)
        true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
            loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)

        self.scaler.scale(loss['total']).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and self.step % self.args.log_train_loss_interval == 0:
            for loss_name, loss_value in loss.items():
                self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step)

        if self.rank == 0 and self.step % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)

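    # Segmentation pass: still images arrive with a singleton time dimension (added via
    # unsqueeze(1) in train()), so the same recurrent model handles both image and video
    # batches. Logging conditions use self.step rounded down to an even number, so
    # consecutive video/image passes log at the same interval boundaries.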
    def train_seg(self, true_img, true_seg, log_label):
        true_img = true_img.to(self.rank, non_blocking=True)
        true_seg = true_seg.to(self.rank, non_blocking=True)

        true_img, true_seg = self.random_crop(true_img, true_seg)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
            loss = segmentation_loss(pred_seg, true_seg)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)

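    # The helpers below draw one batch from a persistent iterator over the secondary
    # dataloaders. When the iterator does not exist yet or is exhausted, it is recreated
    # and the sampler epoch is advanced so each pass reshuffles differently.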
    def load_next_mat_hr_sample(self):
        try:
            sample = next(self.dataiterator_mat_hr)
        except (AttributeError, StopIteration):
            # Iterator missing (first call) or exhausted: reshuffle and restart.
            self.datasampler_hr_train.set_epoch(self.datasampler_hr_train.epoch + 1)
            self.dataiterator_mat_hr = iter(self.dataloader_hr_train)
            sample = next(self.dataiterator_mat_hr)
        return sample

    def load_next_seg_video_sample(self):
        try:
            sample = next(self.dataiterator_seg_video)
        except (AttributeError, StopIteration):
            # Iterator missing (first call) or exhausted: reshuffle and restart.
            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
            sample = next(self.dataiterator_seg_video)
        return sample

    def load_next_seg_image_sample(self):
        try:
            sample = next(self.dataiterator_seg_image)
        except (AttributeError, StopIteration):
            # Iterator missing (first call) or exhausted: reshuffle and restart.
            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
            sample = next(self.dataiterator_seg_image)
        return sample

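    # Validation runs on rank 0 only, at the start of each epoch, and reports the
    # average matting loss over the validation loader; all ranks then sync on a barrier.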
    def validate(self):
        if self.rank == 0:
            self.log(f'Validating at the start of epoch: {self.epoch}')
            self.model_ddp.eval()
            total_loss, total_count = 0, 0
            with torch.no_grad():
                with autocast(enabled=not self.args.disable_mixed_precision):
                    for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                        true_fgr = true_fgr.to(self.rank, non_blocking=True)
                        true_pha = true_pha.to(self.rank, non_blocking=True)
                        true_bgr = true_bgr.to(self.rank, non_blocking=True)
                        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                        batch_size = true_src.size(0)
                        pred_fgr, pred_pha = self.model(true_src)[:2]
                        total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                        total_count += batch_size
            avg_loss = total_loss / total_count
            self.log(f'Validation set average loss: {avg_loss}')
            self.writer.add_scalar('valid_loss', avg_loss, self.step)
            self.model_ddp.train()
        dist.barrier()

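    # Joint augmentation: pick a random crop size between half and full resolution,
    # resize each clip to a square of the larger side, then center-crop to (h, w).
    # Applying the same crop to every tensor keeps frames, alphas, and masks aligned.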
    def random_crop(self, *imgs):
        h, w = imgs[0].shape[-2:]
        w = random.choice(range(w // 2, w))
        h = random.choice(range(h // 2, h))
        results = []
        for img in imgs:
            B, T = img.shape[:2]
            img = img.flatten(0, 1)
            img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False)
            img = center_crop(img, (h, w))
            img = img.reshape(B, T, *img.shape[1:])
            results.append(img)
        return results

    def save(self):
        if self.rank == 0:
            os.makedirs(self.args.checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth'))
            self.log('Model saved')
        dist.barrier()

    def cleanup(self):
        dist.destroy_process_group()

    def log(self, msg):
        print(f'[GPU{self.rank}] {msg}')


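# Entry point: torch.multiprocessing.spawn launches one process per visible GPU.
# Each process calls Trainer(rank, world_size); spawn supplies the rank as the first
# argument and forwards `args`, here (world_size,).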
if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(
        Trainer,
        nprocs=world_size,
        args=(world_size,),
        join=True)