# NOTE: web-viewer residue (topic-selection hint and the file listing header
# "511 lines / 19 KiB / Python") preceded this point; it is not part of the module.

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from contextlib import contextmanager
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from codegeex.megatron import get_args
from codegeex.megatron import get_num_microbatches
from codegeex.megatron import get_timers
from codegeex.megatron import mpu
from codegeex.megatron import p2p_communication
from codegeex.megatron.utils import unwrap_model
from codegeex.megatron.model import DistributedDataParallel as LocalDDP
from codegeex.megatron.model import Float16Module
def get_forward_backward_func():
    """Pick the forward/backward schedule matching the parallel configuration.

    Returns one of the schedule functions of this module:

    * ``forward_backward_no_pipelining`` when the pipeline-model-parallel
      world size is not greater than 1;
    * ``forward_backward_pipelining_with_interleaving`` when pipelining is on
      and ``args.virtual_pipeline_model_parallel_size`` is set;
    * ``forward_backward_pipelining_without_interleaving`` otherwise.
    """
    args = get_args()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # Pipelining is active; interleaving is requested through the
        # virtual pipeline size argument.
        if args.virtual_pipeline_model_parallel_size is None:
            return forward_backward_pipelining_without_interleaving
        return forward_backward_pipelining_with_interleaving

    return forward_backward_no_pipelining
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Run the forward pass of one microbatch for the passed-in model.

    ``input_tensor`` is handed to the unwrapped model through
    ``set_input_tensor`` before ``forward_step_func(data_iterator, model)`` is
    called. If first stage, input tensor is obtained from data_iterator,
    otherwise the passed-in input_tensor is used.

    On the last pipeline stage the loss function returned by
    ``forward_step_func`` is applied to the output, the loss is divided by the
    number of microbatches, and the reduced loss is appended to
    ``losses_reduced`` (mutated in place).

    Returns the output tensor (the scaled loss on the last stage).
    """
    timers = get_timers()
    args = get_args()

    timers("forward-compute").start()

    # Strip the DDP / fp16 wrappers to reach the object exposing
    # set_input_tensor; with deepspeed enabled it sits one `.module` deeper.
    bare_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    tensor_sink = bare_model.module if args.deepspeed else bare_model
    tensor_sink.set_input_tensor(input_tensor)

    result, loss_func = forward_step_func(data_iterator, model)

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = loss_func(result)
        # Scale so that accumulating over all microbatches averages the loss.
        result = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)

    timers("forward-compute").stop()
    return result
def backward_step(
    optimizer, input_tensor, output_tensor, output_tensor_grad, model=None
):
    """Run the backward pass through the passed-in output tensor.

    If last stage, ``output_tensor_grad`` is None, otherwise it is the
    gradient of the loss with respect to this stage's output tensor.

    With ``args.deepspeed`` set, ``model`` is required and the backward pass
    is delegated to ``model.backward(output_tensor)``; otherwise
    ``torch.autograd.backward`` is used, after ``optimizer.scale_loss`` has
    been applied when no output gradient was supplied.

    Returns the gradient of the loss with respect to ``input_tensor``
    (None when ``input_tensor`` is None, i.e. on the first stage).
    """
    args = get_args()
    if args.deepspeed:
        assert model is not None

    timers = get_timers()
    timers("backward-compute").start()

    # Non-leaf tensors drop their .grad unless asked to retain it.
    has_input = input_tensor is not None
    if has_input:
        input_tensor.retain_grad()

    if args.deepspeed:
        model.backward(output_tensor)
    else:
        # No incoming gradient means output_tensor is the loss itself, so it
        # goes through the optimizer's loss scaling first.
        if output_tensor_grad is None:
            output_tensor = optimizer.scale_loss(output_tensor)
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Gradient to hand back to the previous stage, if there is one.
    input_tensor_grad = input_tensor.grad if has_input else None

    timers("backward-compute").stop()
    return input_tensor_grad
@contextmanager
def dummy_handler():
    """No-op context manager.

    Used as a stand-in where a real context manager (such as a DDP
    ``no_sync``) is not applicable. Yields ``None``; exceptions raised inside
    the ``with`` body propagate unchanged.
    """
    # The former `try: ... finally: pass` wrapper had no effect; a bare yield
    # gives identical enter/exit/exception behaviour.
    yield
def forward_backward_no_pipelining(
    forward_step_func, data_iterator, model, optimizer, timers, forward_only
):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    ``model`` must be a single-element sequence. All microbatches but the
    last run inside the model's ``no_sync`` context when the model is a torch
    DDP instance (a no-op context otherwise); the last microbatch runs outside
    of it so gradients get synchronized. With ``args.deepspeed`` set, the
    gradient accumulation boundary is switched off before the loop and on
    again before the last microbatch.

    Returns the list of reduced losses filled in by ``forward_step``.
    """
    assert len(model) == 1
    model = model[0]
    args = get_args()

    context_handler = model.no_sync if isinstance(model, torchDDP) else dummy_handler

    if args.deepspeed:
        model.set_gradient_accumulation_boundary(False)

    losses_reduced = []
    # No neighbouring stages: nothing is received from upstream/downstream.
    input_tensor, output_tensor_grad = None, None

    def run_microbatch():
        """Forward (and, unless forward_only, backward) for one microbatch."""
        output_tensor = forward_step(
            forward_step_func, data_iterator, model, input_tensor, losses_reduced
        )
        if not forward_only:
            backward_step(
                optimizer, input_tensor, output_tensor, output_tensor_grad, model
            )

    with context_handler():
        for _ in range(get_num_microbatches() - 1):
            run_microbatch()

    if args.deepspeed:
        model.set_gradient_accumulation_boundary(True)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    run_microbatch()

    return losses_reduced
def forward_backward_pipelining_with_interleaving(
    forward_step_func, data_iterator, model, optimizer, timers, forward_only
):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    ``model`` and ``data_iterator`` are indexed by model chunk id, one entry
    per chunk. The schedule has three phases: warmup forward passes, a steady
    state that runs one forward and one backward per iteration, and cooldown
    backward passes (skipped when ``forward_only``).

    Returns ``losses_reduced``: the list ``forward_step`` appends to on the
    last pipeline stage (it stays empty when nothing is appended).

    NOTE(review): ``backward_step`` is called here without its ``model``
    argument, while ``backward_step`` asserts ``model is not None`` when
    ``args.deepspeed`` is set -- confirm this schedule is not meant to be used
    with deepspeed.
    NOTE(review): with ``forward_only`` nothing pops from ``input_tensors`` /
    ``output_tensors``, so they keep every microbatch's tensors until return.
    """
    # Per-chunk FIFO queues of tensors awaiting their backward pass.
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    # Only defined when a backward pass will run; every use below is either
    # guarded by `not forward_only` or sits in a loop that runs zero times in
    # forward-only mode.
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]
    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    # Each microbatch is processed once per model chunk.
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (
                pipeline_parallel_size - pipeline_parallel_rank - 1
            ) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number.

        Iterations are grouped in runs of
        ``pipeline_parallel_size * num_model_chunks``; within a group each
        consecutive block of ``pipeline_parallel_size`` iterations maps to the
        next chunk. For backward passes the chunk order is reversed."""
        microbatch_id_in_group = microbatch_id % (
            pipeline_parallel_size * num_model_chunks
        )
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step()).

        Appends the produced output tensor to the chunk's queue and returns
        it."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
        if mpu.is_pipeline_first_stage():
            # Equal queue lengths mean no input was queued for this forward
            # pass; the first stage uses a None placeholder instead.
            if len(input_tensors[model_chunk_id]) == len(
                output_tensors[model_chunk_id]
            ):
                input_tensors[model_chunk_id].append(None)
        # The most recently queued input belongs to this forward pass.
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
        )
        output_tensors[model_chunk_id].append(output_tensor)
        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step()).

        Pops the oldest queued input/output/grad tensors of the chunk and
        returns the gradient with respect to the input tensor."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
        if mpu.is_pipeline_last_stage():
            # The last stage has no incoming gradient; queue a None
            # placeholder when nothing is pending.
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            optimizer, input_tensor, output_tensor, output_tensor_grad
        )
        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    # Prime chunk 0 with the first tensor received from the previous stage.
    input_tensors[0].append(p2p_communication.recv_forward(timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)
        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        # Nothing left to receive after the very last forward pass.
        if k == (num_microbatches - 1):
            recv_prev = False
        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None
        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if (
            k == (num_warmup_microbatches - 1)
            and not forward_only
            and not all_warmup_microbatches
        ):
            # Last warmup iteration before 1F1B: also receive the first
            # output gradient needed by the steady state.
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                timers=timers,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor, recv_prev, timers
            )
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)
        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)
        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.
        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None
        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None
        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True
            )
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k + 1, forward=True
            )
        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False
            )
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k + 1, forward=False
            )
        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False
        # Communicate tensors.
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            timers=timers,
        )
        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            # No gradient was received during warmup in this mode, so fetch
            # the first one now.
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(timers)
            )
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            # Nothing left to receive after the very last backward pass.
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers
                )
            )
    return losses_reduced
def forward_backward_pipelining_without_interleaving(
    forward_step_func, data_iterator, model, optimizer, timers, forward_only
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    ``model`` must be a single-element sequence. The schedule consists of
    warmup forward passes, a steady state alternating one forward and one
    backward pass, and cooldown backward passes. With ``forward_only`` set,
    only forward passes and forward communication are performed and no
    tensors are kept for a backward pass.

    Returns ``losses_reduced``: the list ``forward_step`` appends to on the
    last pipeline stage (it stays empty when nothing is appended).
    """
    # The `timers` argument is superseded by the global timers object; this
    # mirrors the original behaviour.
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        mpu.get_pipeline_model_parallel_world_size()
        - mpu.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # FIFO queues of tensors awaiting their backward pass. They are only
    # filled when a backward pass will consume them: in forward-only mode
    # nothing ever pops them, so storing tensors there would just keep every
    # microbatch's activations alive until the function returns.
    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(
            forward_step_func, data_iterator, model, input_tensor, losses_reduced
        )
        p2p_communication.send_forward(output_tensor, timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = i == (num_microbatches_remaining - 1)

        output_tensor = forward_step(
            forward_step_func, data_iterator, model, input_tensor, losses_reduced
        )

        if forward_only:
            # Forward-only: pass the activation on and fetch the next input;
            # nothing is retained for a backward pass.
            p2p_communication.send_forward(output_tensor, timers)
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
            continue

        output_tensor_grad = p2p_communication.send_forward_recv_backward(
            output_tensor, timers
        )

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

        input_tensor_grad = backward_step(
            optimizer, input_tensor, output_tensor, output_tensor_grad, model
        )

        if last_iteration:
            input_tensor = None
            p2p_communication.send_backward(input_tensor_grad, timers)
        else:
            input_tensor = p2p_communication.send_backward_recv_forward(
                input_tensor_grad, timers
            )

    # Run cooldown backward passes.
    if not forward_only:
        for _ in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = backward_step(
                optimizer, input_tensor, output_tensor, output_tensor_grad, model
            )

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced