mirror of https://github.com/THUDM/CodeGeeX.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.6 KiB
Python
73 lines
2.6 KiB
Python
# coding=utf-8
|
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import torch
|
|
|
|
|
|
def ensure_divisibility(numerator, denominator):
|
|
"""Ensure that numerator is divisible by the denominator."""
|
|
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
|
numerator, denominator
|
|
)
|
|
|
|
|
|
def divide(numerator, denominator):
|
|
"""Ensure that numerator is divisible by the denominator and return
|
|
the division value."""
|
|
ensure_divisibility(numerator, denominator)
|
|
return numerator // denominator
|
|
|
|
|
|
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
|
|
"""Split a tensor along its last dimension.
|
|
Arguments:
|
|
tensor: input tensor.
|
|
num_partitions: number of partitions to split the tensor
|
|
contiguous_split_chunks: If True, make each chunk contiguous
|
|
in memory.
|
|
"""
|
|
# Get the size and dimension.
|
|
last_dim = tensor.dim() - 1
|
|
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
|
# Split.
|
|
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
|
# Note: torch.split does not create contiguous tensors by default.
|
|
if contiguous_split_chunks:
|
|
return tuple(chunk.contiguous() for chunk in tensor_list)
|
|
|
|
return tensor_list
|
|
|
|
|
|
class VocabUtility:
|
|
"""Split the vocabulary into `world_size` chunks amd return the
|
|
first and last index of the vocabulary belonging to the `rank`
|
|
partition: Note that indecies in [fist, last)"""
|
|
|
|
@staticmethod
|
|
def vocab_range_from_per_partition_vocab_size(
|
|
per_partition_vocab_size, rank, world_size
|
|
):
|
|
index_f = rank * per_partition_vocab_size
|
|
index_l = index_f + per_partition_vocab_size
|
|
return index_f, index_l
|
|
|
|
@staticmethod
|
|
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
|
|
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
|
return VocabUtility.vocab_range_from_per_partition_vocab_size(
|
|
per_partition_vocab_size, rank, world_size
|
|
)
|