Release cross-platform source code and weights
@@ -0,0 +1,34 @@
# encoding:utf-8

import requests
import json

'''
Code Generation
'''
API_KEY = ""  # Get from the Tianqi console.
API_SECRET = ""  # Get from the Tianqi console.
PROMPT = "from typing import List\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"\n"
NUMBER = 3
LANG = "Python"
request_url = "https://tianqi.aminer.cn/api/v2/"
api = 'multilingual_code_generate'

# The request payload is JSON.
headers = {'Content-Type': 'application/json'}
request_url = request_url + api
data = {
    "apikey": API_KEY,
    "apisecret": API_SECRET,
    "prompt": PROMPT,
    "n": NUMBER,
    "lang": LANG,
}


def main():
    response = requests.post(request_url, headers=headers, data=json.dumps(data))
    if response:
        print(response.json())


if __name__ == '__main__':
    main()
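# A slightly more defensive variant (a sketch; the response schema beyond
# "it is JSON" is not documented here, so we only surface HTTP errors):
#
# def main():
#     response = requests.post(request_url, headers=headers,
#                              data=json.dumps(data), timeout=30)
#     response.raise_for_status()  # raise instead of silently printing nothing
#     print(response.json())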
@@ -1,139 +0,0 @@
# This file is for evaluating the budget distribution method
import json
import numpy as np

w = json.load(open("solve_rate_final.jsonl", 'r'))


def build_chart():
    fa = np.ones((201, 201))

    for i in range(1, 201):
        for j in range(201):
            fa[j, i] = fa[j, i - 1] * (201 - j - i) / (201 - i)

    return fa

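# A note on build_chart(): the recurrence fills fa[c, k] = C(200-c, k) / C(200, k),
# the probability that k samples drawn without replacement from 200 generations
# miss all c correct ones, so 1 - fa[c, k] is the unbiased pass@k estimate.
# The index order does not matter: the expression is symmetric in c and k.
# Sanity check, as a sketch:
# import math
# assert abs(build_chart()[3, 10] - math.comb(197, 10) / math.comb(200, 10)) < 1e-9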

languages = ['cpp', 'go', 'java', 'python', 'js']
models = ['codegeex', 'codegen16b', 'codegen6b', 'incoder']
fa = build_chart()


def compute(l, dist):
    # l: per-language solve counts for one model (5 languages x 164 problems);
    # dist: per-language sampling weights. Evaluates total budgets 10..200
    # (step 5) under a uniform split, the weighted split from distribute(),
    # and single-language baselines.
    budgets = []
    alldists = []
    for i in range(2, 41):
        budgets.append(i * 5)
        alldists.append(distribute(dist, i * 5))
    # print(alldists)
    sums = np.zeros(39)
    sumdists = np.zeros(39)
    sumop = np.zeros((39, 5))
    summax = np.zeros(39)
    for i in range(164):
        currents = np.ones(39)
        currentdists = np.ones(39)
        currentops = np.ones((39, 5))

        for k in range(5):
            num = int(l[k][i])
            for j in range(39):
                currents[j] = currents[j] * fa[j + 2, num]
                currentdists[j] = currentdists[j] * fa[alldists[j][k], num]
                currentops[j, k] = fa[(j + 2) * 5, num]

        sums = sums + (1 - currents)
        sumdists = sumdists + (1 - currentdists)
        sumop = sumop + (1 - currentops)
        summax = summax + (1 - np.min(currentops, axis=1))

    sumop = np.max(sumop, axis=1)
    return sums / 164, sumdists / 164, sumop / 164, summax / 164


def distribute(distribution, budget):
    # Split `budget` proportionally to `distribution`, then repair integer
    # rounding drift so the parts sum exactly to the budget.
    total = np.sum(distribution)
    di = np.array(distribution) / total * budget
    dis = []
    diff = []
    for i in range(len(di)):
        dis.append(int(di[i]))
        diff.append(dis[i] - di[i])
    # overflow assignment
    need = np.sum(dis) - budget
    while need > 0:
        g = np.argmax(diff)
        dis[g] -= 1
        diff[g] -= 1
        need -= 1
    while need < 0:
        g = np.argmin(diff)
        dis[g] += 1
        diff[g] += 1
        need += 1
    return dis


names = []
for i in range(39):
    names.append(str((i + 2) * 5) + " uniform")
for i in range(39):
    names.append(str((i + 2) * 5) + " weighted")
for i in range(39):
    names.append(str((i + 2) * 5) + " best")
for i in range(39):
    names.append(str((i + 2) * 5) + " max")

out = open("solution_output.txt", 'w')
for model in models:
    if 'codegeex' in model:
        dist = [33, 6, 20, 32, 9]
    if 'codegen' in model:
        dist = [38, 8, 29, 17, 8]
    if 'incoder' in model:
        dist = [12, 4, 5, 45, 34]
    avi_list = {}
    for pp in w:
        if (np.sum(w[pp]) > 1500):
            if model in pp:
                for l in languages:
                    if l in pp.replace('javascript', 'js'):
                        if l in avi_list:
                            avi_list[l].append(pp)
                        else:
                            avi_list[l] = [pp]
    # print(avi_list)
    maxsums = np.zeros(len(names))
    maxsumscomb = np.zeros((len(names), 5))
    current_marker = [0, 0, 0, 0, 0]
    # Odometer-style sweep over every combination of candidate runs per language.
    while current_marker[0] < len(avi_list[languages[0]]):
        aclist = []
        for i in range(5):
            aclist.append(w[avi_list[languages[i]][current_marker[i]]])
        sums, sumdists, sumop, summax = compute(aclist, dist)
        things = np.concatenate((sums, sumdists, sumop, summax))
        for i in range(len(names)):
            if (things[i] > maxsums[i]):
                # print(names[i], things[i], current_marker)
                maxsums[i] = things[i]
                maxsumscomb[i] = current_marker

        current_marker[-1] += 1
        p = 4
        while (current_marker[p] >= len(avi_list[languages[p]]) and p > 0):
            current_marker[p] = 0
            current_marker[p - 1] += 1
            p -= 1

    print(model)
    print(model, file=out)
    for i in range(len(names)):
        print(names[i], maxsums[i], maxsumscomb[i])
        print(names[i], maxsums[i], file=out)
    # use the best of mix100 for further purposes
    for i in range(5):
        print(languages[i], avi_list[languages[i]][int(maxsumscomb[2, i])])
out.close()
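# Usage sketch for distribute(): split a 50-sample budget across the five
# languages in proportion to CodeGeeX's weights; ties in the rounding repair
# go to the first index.
# distribute([33, 6, 20, 32, 9], 50)  -> [17, 3, 10, 16, 4], which sums to 50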
@@ -1,37 +0,0 @@
# This file is for gathering the solve rates from generated files
import json
import os

import numpy as np

language = ['cpp', 'java', 'js', 'python', 'go']
repo = "<directory_to_generated_jsonl_files>"

all_reps = os.listdir(repo)

# choose the ones
all_passes = {}
# Per-language budgets for codegeex / codegen / incoder; only the last
# assignment takes effect, and `assignment` is unused below (kept for reference).
assignment = [33, 6, 20, 32, 9]
assignment = [38, 8, 29, 17, 8]
assignment = [12, 4, 5, 45, 34]
for folder in all_reps:
    if not ("." in folder):
        q = os.listdir(repo + '/' + folder)
        for f in q:
            if 'result' in f and not ('example' in f):
                passed = np.zeros(164)
                all_p = 0
                fi = open(repo + '/' + folder + '/' + f, 'r')
                t = fi.readlines()
                for l in t:
                    if len(l.strip()) == 0:
                        continue
                    qq = json.loads(l)
                    if qq['passed']:
                        idx = int(qq['task_id'].split('/')[1])
                        passed[idx] += 1
                        all_p += 1
                all_passes[f] = list(passed)
                print(f, all_p)

json.dump(all_passes, open('solve_rate_final.jsonl', 'w'))
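# Expected input line format (a sketch inferred from the fields read above):
# {"task_id": "HumanEval/42", "passed": true, ...}
# The numeric part after "/" indexes one of the 164 problems; each result file
# contributes one solve-count vector keyed by its filename.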
@@ -0,0 +1,47 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def is_last_rank():
    return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)


def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized():
        if is_last_rank():
            print(message, flush=True)
    else:
        print(message, flush=True)
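# Usage sketch: under torch.distributed with N processes, a bare print() emits
# N copies of every message; these helpers gate on rank so it appears once.
# initialize_megatron()                    # argument plumbing omitted
# print_rank_0("> loading dataset ...")    # printed by rank 0 only
# print_rank_last("validation loss ...")   # printed by the last rank only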
@@ -0,0 +1,528 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np
from glob import glob

import torch

from megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils

_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    global _CHECKPOINT_VERSION
    return _CHECKPOINT_VERSION


def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from checkpoint."""
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        error_message = (
            "{} value from checkpoint ({}) is not equal to the "
            "input argument value ({}).".format(arg_name, checkpoint_value, args_value)
        )
        assert checkpoint_value == args_value, error_message

    _compare("num_layers")
    _compare("hidden_size")
    _compare("num_attention_heads")
    _compare("max_position_embeddings")
    if args.vocab_file:
        _compare("make_vocab_size_divisible_by")
        _compare("padded_vocab_size")
        _compare("tokenizer_type")
    if get_checkpoint_version() < 3.0:
        _compare("tensor_model_parallel_size", old_arg_name="model_parallel_size")
    if get_checkpoint_version() >= 3.0:
        _compare("tensor_model_parallel_size")
        _compare("pipeline_model_parallel_size")


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration, release=False):
    """A unified checkpoint name."""
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}".format(mpu.get_tensor_model_parallel_rank()),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}".format(
            mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank()
        ),
        "model_optim_rng.pt",
    )
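# Resulting paths (a sketch; iteration 5000, tensor-parallel rank 0):
#   no pipeline parallelism: <checkpoints_path>/iter_0005000/mp_rank_00/model_optim_rng.pt
#   pipeline rank 1:         <checkpoints_path>/iter_0005000/mp_rank_00_001/model_optim_rng.pt
#   release=True:            <checkpoints_path>/release/mp_rank_00/model_optim_rng.pt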

def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file records the latest checkpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if not args.deepspeed:
        model = utils.unwrap_model(model)

    print_rank_0(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    if (
        not torch.distributed.is_initialized()
        or mpu.get_data_parallel_rank() == 0
        or args.deepspeed
    ):

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict["args"] = args
        state_dict["checkpoint_version"] = 3.0
        state_dict["iteration"] = iteration
        state_dict["tokens"] = args.consumed_train_tokens

        # DeepSpeed saves the model/optimizer/scheduler
        if not args.deepspeed:
            if len(model) == 1:
                state_dict["model"] = model[0].state_dict_for_save_checkpoint()
            else:
                for i in range(len(model)):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    state_dict["model%d" % i] = model[
                        i
                    ].state_dict_for_save_checkpoint()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    state_dict["optimizer"] = optimizer.state_dict()
                if lr_scheduler is not None:
                    state_dict["lr_scheduler"] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict["random_rng_state"] = random.getstate()
            state_dict["np_rng_state"] = np.random.get_state()
            state_dict["torch_rng_state"] = torch.get_rng_state()
            state_dict["cuda_rng_state"] = torch.cuda.get_rng_state()
            state_dict["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        if not args.deepspeed:
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)

    if args.deepspeed:
        # The megatron model uses state_dict_for_save_checkpoint instead of the
        # standard state_dict; state_dict is used by deepspeed for module
        # saving, so it needs to point to the right function.
        if args.no_pipeline_parallel:
            original_state_dict = model[0].module.state_dict
            model[0].module.state_dict = model[0].module.state_dict_for_save_checkpoint

        # Saving is a collective communication
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        # Trim off the filename and mp_rank_* directory.
        for _ in range(3):
            checkpoint_name = os.path.dirname(checkpoint_name)
        model[0].save_checkpoint(checkpoint_name, client_state=state_dict)

        if args.no_pipeline_parallel:
            model[0].module.state_dict = original_state_dict

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(
        "  successfully saved checkpoint at iteration {:7d} to {}".format(
            iteration, args.save
        )
    )

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
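# The tracker file is the resume handshake: it holds either an integer
# iteration or the literal string "release". A sketch of reading it back:
# with open(get_checkpoint_tracker_filename(args.save)) as f:
#     tag = f.read().strip()  # e.g. "5000" during training, or "release"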

def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, "module"):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = (
        attention_module.num_attention_heads_per_partition
    )
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_splits,
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
            num_splits,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
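# Shape sketch with hypothetical sizes (num_splits=3, np=2 heads, hn=4, h=8):
# a [3*2*4, 8] tensor is viewed as [3, 2, 4, 8], transposed to [2, 3, 4, 8],
# and flattened back to [24, 8], turning split-major row order into head-major:
# t = torch.arange(24 * 8, dtype=torch.float32).reshape(24, 8)
# t = t.view(3, 2, 4, 8).transpose(0, 1).contiguous().view(24, 8)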

def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
        if isinstance(model, list):
            assert len(model) == 1
            model = model[0]
        for name, param in model.named_parameters():
            if name.endswith((".query_key_value.weight", ".query_key_value.bias")):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith((".key_value.weight", ".key_value.bias")):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
        print_rank_0(
            " successfully fixed query-key-values ordering for"
            " checkpoint version {}".format(checkpoint_version)
        )

def load_deepspeed_state(model):
    model = utils.unwrap_model(model)
    args = get_args()
    load_dir = args.load
    if os.path.isdir(load_dir):
        model_state_paths = glob(os.path.join(load_dir, "*model_states.pt"))
        assert len(model_state_paths) == 1, (
            "only support loading deepspeed checkpoint of model parallel size 1"
            ", but got {}".format(model_state_paths)
        )
        model_state_path = model_state_paths[0]
    else:
        model_state_path = load_dir
    state_dict = torch.load(model_state_path, map_location="cpu")
    state_dict = state_dict["module"]

    model[0].load_state_dict(state_dict, strict=True)


def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load", strict=True):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    if args.deepspeed:
        loaded_dir, state_dict = model[0].load_checkpoint(load_dir)
        if loaded_dir is None:
            print_rank_0(
                "WARNING: could not find the metadata file {} ".format(load_dir)
            )
            print_rank_0(
                "    will not load any checkpoints and will start from random"
            )
            return 0
        release = False
    else:
        model = utils.unwrap_model(model)

        # Read the tracker file and set the iteration.
        tracker_filename = get_checkpoint_tracker_filename(load_dir)

        # If no tracker file, return iteration zero.
        if not os.path.isfile(tracker_filename):
            print_rank_0(
                "WARNING: could not find the metadata file {} ".format(tracker_filename)
            )
            print_rank_0(
                "    will not load any checkpoints and will start from random"
            )
            return 0

        # Otherwise, read the tracker file and either set the iteration or
        # mark it as a release checkpoint.
        iteration = 0
        release = False
        with open(tracker_filename, "r") as f:
            metastring = f.read().strip()
            try:
                iteration = int(metastring)
            except ValueError:
                release = metastring == "release"
                if not release:
                    print_rank_0(
                        "ERROR: Invalid metadata file {}. Exiting".format(
                            tracker_filename
                        )
                    )
                    sys.exit()

        assert iteration > 0 or release, "error parsing metadata file {}".format(
            tracker_filename
        )

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
        print_rank_0(f" loading checkpoint from {args.load} at iteration {iteration}")

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_0(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except BaseException as e:
            print_rank_0("could not load the checkpoint")
            print_rank_0(e)
            sys.exit()

    # set checkpoint version
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
            if "tokens" in state_dict:
                args.consumed_train_tokens = state_dict["tokens"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_0(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(checkpoint_name)
                )
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_0("could not find arguments in the checkpoint ...")

    # Model.
    if not args.deepspeed:
        if len(model) == 1:
            model[0].load_state_dict(state_dict["model"], strict=strict)
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model[i].load_state_dict(state_dict["model%d" % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f" checkpoint version {checkpoint_version}")
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not args.deepspeed:
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(state_dict["optimizer"])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
            except KeyError:
                print_rank_0(
                    "Unable to load optimizer from checkpoint {}. "
                    "Specify --no-load-optim or --finetune to prevent "
                    "attempting to load the optimizer state, "
                    "exiting ...".format(checkpoint_name)
                )
                sys.exit()

        # rng states.
        if not release and not args.finetune and not args.no_load_rng:
            try:
                random.setstate(state_dict["random_rng_state"])
                np.random.set_state(state_dict["np_rng_state"])
                torch.set_rng_state(state_dict["torch_rng_state"])
                torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
                # Check for empty states array
                if not state_dict["rng_tracker_states"]:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
            except KeyError:
                print_rank_0(
                    "Unable to load rng state from checkpoint {}. "
                    "Specify --no-load-rng or --finetune to prevent "
                    "attempting to load the rng state, "
                    "exiting ...".format(checkpoint_name)
                )
                sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    # if torch.distributed.is_initialized():
    #     torch.distributed.barrier()

    print_rank_0(
        f"  successfully loaded checkpoint from {args.load} "
        f"at iteration {iteration}"
    )

    return iteration
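# Typical call sequence (a sketch; model/optimizer construction omitted):
# iteration = load_checkpoint(model, optimizer, lr_scheduler)
# A return value of 0 means a fresh start: missing tracker file, --finetune,
# or a "release" checkpoint.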

def load_biencoder_checkpoint(
    model, only_query_model=False, only_context_model=False, custom_load_path=None
):
    """
    selectively load retrieval models for indexing/retrieving
    from saved checkpoints
    """

    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, "r") as f:
        iteration = int(f.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print(
            "global rank {} is loading checkpoint {}".format(
                torch.distributed.get_rank(), checkpoint_name
            )
        )

    state_dict = torch.load(checkpoint_name, map_location="cpu")
    ret_state_dict = state_dict["model"]

    if only_query_model:
        ret_state_dict.pop("context_model")
    if only_context_model:
        ret_state_dict.pop("query_model")

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return model
@@ -0,0 +1,256 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron global variables."""

import os
import sys
import time
import torch

from codegeex.megatron.tokenizer import build_tokenizer
from codegeex.megatron.arguments import parse_args

_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None


def get_args():
    """Return arguments."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, "args")
    return _GLOBAL_ARGS


def get_num_microbatches():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)


def get_tokenizer():
    """Return tokenizer."""
    _ensure_var_is_initialized(_GLOBAL_TOKENIZER, "tokenizer")
    return _GLOBAL_TOKENIZER


def get_tensorboard_writer():
    """Return tensorboard writer. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER


def get_adlr_autoresume():
    """ADLR autoresume object. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_ADLR_AUTORESUME


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
    return _GLOBAL_TIMERS


def set_global_variables(
    extra_args_provider=None, args_defaults={}, ignore_unknown_args=False
):
    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
    args = _parse_args(
        extra_args_provider=extra_args_provider,
        defaults=args_defaults,
        ignore_unknown_args=ignore_unknown_args,
    )
    if args.vocab_file or args.tokenizer_path:
        _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()


def _parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False):
    """Parse entire arguments."""
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, "args")
    _GLOBAL_ARGS = parse_args(
        extra_args_provider=extra_args_provider,
        defaults=defaults,
        ignore_unknown_args=ignore_unknown_args,
    )
    return _GLOBAL_ARGS


def _build_tokenizer(args):
    """Initialize tokenizer."""
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, "tokenizer")
    _GLOBAL_TOKENIZER = build_tokenizer(args)
    return _GLOBAL_TOKENIZER


def rebuild_tokenizer(args):
    global _GLOBAL_TOKENIZER
    _GLOBAL_TOKENIZER = None
    return _build_tokenizer(args)


def _set_tensorboard_writer(args):
    """Set tensorboard writer."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer")

    if (
        hasattr(args, "tensorboard_dir")
        and args.tensorboard_dir
        and args.rank == (args.world_size - 1)
    ):
        try:
            from torch.utils.tensorboard import SummaryWriter

            print("> setting tensorboard ...")
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
                log_dir=args.tensorboard_dir, max_queue=args.tensorboard_queue_size
            )
        except ModuleNotFoundError:
            print(
                "WARNING: TensorBoard writing requested but is not "
                "available (are you using PyTorch 1.1.0 or later?), "
                "no TensorBoard logs will be written.",
                flush=True,
            )


def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume."""
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, "adlr autoresume")

    if args.adlr_autoresume:
        if args.rank == 0:
            print("enabling autoresume ...", flush=True)
        sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
        try:
            from userlib.auto_resume import AutoResume
        except BaseException:
            print("ADLR autoresume is not available, exiting ...")
            sys.exit()

        _GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
    _GLOBAL_TIMERS = Timers()


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, "{} is not initialized.".format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None."""
    assert var is None, "{} is already initialized.".format(name)


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, "timer has already been started"
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        torch.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + "-time", value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = "time (ms)"
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += " | {}: {:.2f}".format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
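# Usage sketch for the timers (get_timers() returns the Timers singleton):
# timers = get_timers()
# timers("forward").start()
# ...  # run the forward pass
# timers("forward").stop()
# timers.log(["forward"])  # e.g. "time (ms) | forward: 12.34" on the last rank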
@@ -0,0 +1,244 @@
import copy
import json
import random
import traceback
from typing import *

import numpy
import torch
import zmq

from codegeex.benchmark.utils import is_code_generation_finished, cleanup_code
from codegeex.megatron import get_args, get_tokenizer
from codegeex.megatron import mpu
from codegeex.megatron.code_generation_utils import get_token_stream
from codegeex.megatron.model import CodeGeeXModel


def model_provider():
    """Build the model."""

    model = CodeGeeXModel(num_tokentypes=0,
                          parallel_output=False)

    return model


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)


def run_generation_distributed(model):
    args = get_args()
    if hasattr(args, "language_tgt_type"):
        language_type = args.language_tgt_type
    else:
        language_type = args.language_type
    print(f"Connecting to tcp://{args.channel_ip}:{args.channel_port}")
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect(f"tcp://{args.channel_ip}:{args.channel_port}")
    output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
    unfinished_output_file_path = args.output_prefix + f"_unfinished_rank{args.gen_rank}.jsonl"
    problems = {}
    print("Building tokenizer...")
    tokenizer = get_tokenizer()

    with open(output_file_path, "w") as f:
        with open(unfinished_output_file_path, "w") as unfinished_f:
            while True:
                socket.send_json({"rank": args.gen_rank, "action": "pull"})
                resp = socket.recv_json()
                try:
                    if "codecontest" in args.dataset.lower():
                        if resp["contest_name"] is None:
                            break
                    elif resp["task_id"] is None:
                        break

                    if "codecontest" in args.dataset.lower():
                        current_spec = problems[resp["contest_name"]]
                        prompt = current_spec.prompt
                    else:
                        current_spec = resp["task_id"]
                        prompt = current_spec["prompt"]

                    temperature = None if "temperature" not in resp else resp["temperature"]
                    topp = None if "topp" not in resp else resp["topp"]

                    f.flush()
                    unfinished_f.flush()
                    tokens = tokenizer.tokenize(prompt)
                    n_token_prompt = len(tokens)
                    if n_token_prompt >= args.seq_length:
                        continue
                    if "micro_batch_size" in resp:
                        micro_batch_size = resp["micro_batch_size"]
                    else:
                        micro_batch_size = args.micro_batch_size
                    if args.beam_search:
                        beams = get_token_stream(
                            model,
                            [
                                copy.deepcopy(tokens)
                                for _ in range(micro_batch_size)
                            ],
                            return_scores=args.return_scores,
                            prompt_length=n_token_prompt,
                            micro_batch_size=micro_batch_size,
                            bad_ids=args.bad_ids,
                            temperature=temperature,
                            topp=topp,
                            beam_warmup=args.beam_warmup,
                        )
                        for beam in beams:
                            generated_tokens_ = beam.tokens
                            generated_tokens_ = (
                                generated_tokens_
                                if generated_tokens_[-1] != tokenizer.eod
                                else generated_tokens_[:-1]
                            )
                            generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                            generated_code = cleanup_code(generated_code,
                                                          language_type=language_type,
                                                          dataset=args.dataset)
                            f.write(
                                json.dumps(
                                    {
                                        "task_id": current_spec['task_id'],
                                        "prompt": prompt,
                                        "generation": generated_code,
                                        "scores": beam.score,
                                        # Fixed: the original referenced an undefined
                                        # `generated_tokens[i]` here; the beam's own
                                        # tokens carry the end-of-document marker.
                                        "finish": 2 if beam.tokens[-1] == tokenizer.eod else 1,
                                        "output": beam.tokens,
                                    }
                                )
                                + "\n"
                            )
                        socket.send_json(
                            {
                                "rank": args.gen_rank,
                                "action": "success",
                                "task_id": current_spec['task_id']
                            }
                        )
                        socket.recv()
                        continue

                    token_stream = get_token_stream(
                        model,
                        [
                            copy.deepcopy(tokens)
                            for _ in range(micro_batch_size)
                        ],
                        return_scores=args.return_scores,
                        prompt_length=n_token_prompt,
                        micro_batch_size=micro_batch_size,
                        bad_ids=args.bad_ids,
                        temperature=temperature,
                        topp=topp,
                        beam_warmup=args.beam_warmup,
                    )
                    is_finished = [False for _ in range(micro_batch_size)]
                    for generated in token_stream:
                        generated_tokens = generated[0]
                        if args.return_scores:
                            scores = generated[1][1]
                        else:
                            scores = None

                        for i in range(micro_batch_size):
                            if is_finished[i]:
                                continue

                            generated_tokens_ = generated_tokens[i].cpu().numpy().tolist()
                            generated_tokens_ = (
                                generated_tokens_
                                if generated_tokens_[-1] != tokenizer.eod
                                else generated_tokens_[:-1]
                            )
                            generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                            if generated_tokens[i].cpu().numpy()[-1] == tokenizer.eod or \
                                    is_code_generation_finished(
                                        generated_code,
                                        language_type=language_type,
                                        dataset=args.dataset,
                                    ):
                                is_finished[i] = True
                                generated_code = cleanup_code(generated_code,
                                                              language_type=language_type,
                                                              dataset=args.dataset)
                                f.write(
                                    json.dumps(
                                        {
                                            "task_id": current_spec['task_id'],
                                            "prompt": prompt,
                                            "generation": generated_code,
                                            "scores": 0.0 if scores is None else scores[i].detach().cpu().item(),
                                            "finish": 2 if generated_tokens[i].cpu().numpy()[-1] == tokenizer.eod else 1,
                                            "output": generated_tokens[i].cpu().numpy().tolist(),
                                        }
                                    )
                                    + "\n"
                                )

                            if len(generated_tokens[i]) >= args.out_seq_length:
                                break

                        if all(is_finished):
                            break

                    for i in range(micro_batch_size):
                        if not is_finished[i]:
                            generated_tokens_ = generated_tokens[i].cpu().numpy().tolist()
                            generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                            unfinished_f.write(
                                json.dumps(
                                    {
                                        "task_id": current_spec['task_id'],
                                        "prompt": prompt,
                                        "generation": generated_code,
                                        "scores": 0.0 if scores is None else scores[i].detach().cpu().item(),
                                        "finish": 0,
                                        "output": generated_tokens_,
                                    }
                                )
                                + "\n"
                            )

                    socket.send_json(
                        {
                            "rank": args.gen_rank,
                            "action": "success",
                            "task_id": current_spec['task_id']
                        }
                    )
                    socket.recv()

                except Exception as e:
                    print(f"*** (rank={args.gen_rank}) crashed.")
                    print(f"    error: {repr(e)}")
                    traceback.print_exc()
                    if args.dataset.lower() == "codecontest":
                        socket.send_json({
                            "rank": args.gen_rank,
                            "action": "fail",
                            "contest_name": current_spec.name,
                            "micro_batch_size": micro_batch_size
                        })
                    else:
                        socket.send_json(
                            {
                                "rank": args.gen_rank,
                                "action": "fail",
                                "task_id": current_spec['task_id']
                            }
                        )
                    socket.recv()
                    continue
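# Protocol sketch between this worker and the (unshown) dispatch server over
# the ZMQ REQ/REP socket:
#   worker -> {"rank": R, "action": "pull"}
#   server -> {"task_id": {...}} plus optional "temperature", "topp",
#             "micro_batch_size"; {"task_id": None} tells the worker to stop
#   worker -> {"rank": R, "action": "success" | "fail", "task_id": ...}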
@@ -0,0 +1,337 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
import time
import datetime

import numpy as np
import torch

from codegeex.megatron import get_adlr_autoresume
from codegeex.megatron import get_args
from codegeex.megatron import get_tensorboard_writer
from codegeex.megatron import mpu
from codegeex.megatron.global_vars import set_global_variables
from codegeex.megatron.mpu import (
    set_tensor_model_parallel_rank,
    set_tensor_model_parallel_world_size,
)

try:
    import wandb
except ImportError:
    wandb = None

import deepspeed


def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(
        extra_args_provider=extra_args_provider,
        args_defaults=args_defaults,
        ignore_unknown_args=ignore_unknown_args,
    )

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(args.seed)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()

        # Autoresume.
        _init_autoresume()

        # No continuation function
        return None
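# Usage sketch (add_generation_args is a hypothetical argument provider): call
# once at program start, then read the parsed arguments through the accessor.
# initialize_megatron(extra_args_provider=add_generation_args)
# args = get_args()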

def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling dataset index builder ...")
        # from megatron.data.dataset_utils import compile_helper
        # compile_helper()
        print(
            ">>> done with dataset index builder. Compilation time: {:.3f} "
            "seconds".format(time.time() - start_time),
            flush=True,
        )

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = (
        args.num_attention_heads / args.tensor_model_parallel_size
    ) * args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = (
        seq_len > 16
        and seq_len <= 2048
        and seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
    # Print a warning.
    if not (
        (args.fp16 or args.bf16)
        and custom_kernel_constraint
        and args.masked_softmax_fusion
    ):
        if args.rank == 0:
            print(
                "WARNING: constraints for invoking optimized"
                " fused softmax kernel are not met. We default"
                " back to unfused kernel invocations.",
                flush=True,
            )

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            ">>> done with compiling and loading fused kernels. "
            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
            flush=True,
        )


def setup_deepspeed_random_and_activation_checkpointing(args):
    """Optional DeepSpeed Activation Checkpointing features.
    Gives access to partition activations, contiguous memory optimizations
    and cpu checkpointing.
    Activation checkpointing requires keeping track of the random states
    and setting the random seed for each MP process. Megatron uses
    mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed
    for keeping track of the random states and setting the random seeds.
    Since they are used in places outside of activation checkpointing,
    we overwrite them to maintain consistency.
    This must be called before all the calls to mpu.model_parallel_cuda_manual_seed
    """
    num_layers = args.num_layers // args.checkpoint_num_layers
    num_layers = (
        num_layers
        if args.num_layers % args.checkpoint_num_layers == 0
        else num_layers + 1
    )
    if args.split_transformers:
        num_layers *= 2

    deepspeed.checkpointing.configure(
        mpu,
        partition_activations=args.partition_activations,
        contiguous_checkpointing=args.contigious_checkpointing,
        num_checkpoints=num_layers,
        checkpoint_in_cpu=args.checkpoint_in_cpu,
        synchronize=args.synchronize_each_layer,
        profile=args.profile_backward,
    )

    mpu.checkpoint = deepspeed.checkpointing.checkpoint
    mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    mpu.model_parallel_cuda_manual_seed = (
        deepspeed.checkpointing.model_parallel_cuda_manual_seed
    )


def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert (
                    args.local_rank == device
                ), "expected local-rank to be the same as rank % device-count."
            else:
                args.local_rank = device
            if args.force_device is not None:
                print(
                    f"  > forcefully set the device to {args.force_device}, originally {device}"
                )
                device = args.force_device
            torch.cuda.set_device(device)
        # Call the init process
        init_method = "tcp://"
        master_ip = os.getenv("MASTER_ADDR", "localhost")
        master_port = os.getenv("MASTER_PORT", "6000")
        init_method += master_ip + ":" + master_port
        print(
            f"  > (rank={args.rank}) initializing process group: "
            f"world_size={args.world_size} "
            f"backend={args.distributed_backend} "
            f"init_method={init_method}",
            flush=True,
        )
        timeout = datetime.timedelta(minutes=args.dist_timeout)
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            init_method=init_method,
            timeout=timeout,
        )
        print(f"  > (rank={args.rank}) process group initialized")

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
            )

    if args.deepspeed and args.deepspeed_activation_checkpointing:
        setup_deepspeed_random_and_activation_checkpointing(args)


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        # Fixed: the original formatted the undefined name `seed` here.
        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)


def initialize_wandb_experiment():
    """Initialize wandb experiment."""
    assert wandb is not None, "Failed to import wandb"

    args = get_args()
    config = args.__dict__

    wandb_id_path = os.path.join(args.save, "wandb_id.txt")
    if os.path.exists(wandb_id_path):
        wandb_id = open(wandb_id_path, "r").read().strip()
    else:
        wandb_id = wandb.util.generate_id()
        open(wandb_id_path, "w").write(wandb_id)

    wandb.init(id=wandb_id, project="megatron", config=config, resume="allow")


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()

    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        mpu.init_checkpointed_activations_memory_buffer()
@@ -0,0 +1,150 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]


class MemoryBuffer:
    """Contiguous memory buffer.
    Allocate a contiguous memory of type `dtype` and size `numel`. It is
    used to reduce memory fragmentation.

    Usage: After the allocation, the `_start` index is set to the first
    index of the memory. A memory chunk starting from the `_start` index
    can be `allocated` for an input tensor, with the elements of the
    tensor being copied. The buffer can be reused by resetting the
    `_start` index.

    """

    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print(
                "> building the {} memory buffer with {} num elements "
                "and {} dtype ({:.1f} MB)...".format(
                    name, numel, dtype, numel * element_size / 1024 / 1024
                ),
                flush=True,
            )
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(
            self.numel,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0

    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0

    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0

    def numel_in_use(self):
        """Return number of elements in use."""
        return self._start

    def add(self, tensor):
        """Allocate a chunk of memory from the buffer to tensor and copy
        the values."""
        assert (
            tensor.dtype == self.dtype
        ), "Input tensor type {} different from buffer type {}".format(
            tensor.dtype, self.dtype
        )
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert (
            new_start <= self.numel
        ), "Not enough memory left in the buffer ({} > {})".format(
            tensor_numel, self.numel - self._start
        )
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start : new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor

    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[: self._start]

    def print_average_usage(self):
        """Print memory usage averaged over time. We would like this value
        to be as high as possible."""
        assert self.track_usage, "You need to enable track usage."
        if torch.distributed.get_rank() == 0:
            print(
                " > usage of {} memory buffer: {:.2f} %".format(
                    self.name, self.in_use_value * 100.0 / self.total_value
                ),
                flush=True,
            )
||||
|
||||
|
||||
class RingMemBuffer:
|
||||
"""A ring of memory buffers."""
|
||||
|
||||
def __init__(self, name, num_buffers, numel, dtype, track_usage):
|
||||
self.num_buffers = num_buffers
|
||||
self.buffers = [
|
||||
allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage)
|
||||
for i in range(num_buffers)
|
||||
]
|
||||
self._index = -1
|
||||
|
||||
def get_next_buffer(self):
|
||||
self._index += 1
|
||||
self._index = self._index % self.num_buffers
|
||||
buff = self.buffers[self._index]
|
||||
assert not buff.is_in_use(), "buffer is already in use."
|
||||
return buff
|
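
# Usage sketch (illustrative, not part of the original file): assumes a CUDA
# device is available and torch.distributed is initialized, since MemoryBuffer
# allocates on the current device and prints from rank 0.
#
#   buf = allocate_mem_buff("activations", numel=1 << 20,
#                           dtype=torch.float16, track_usage=True)
#   chunk = buf.add(torch.randn(256, 1024, dtype=torch.float16,
#                               device=torch.cuda.current_device()))
#   ...  # `chunk` is a view into buf.data; use it in place of the input tensor
#   buf.reset()  # reclaim the whole buffer for the next iteration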
@ -0,0 +1,319 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Merge model parallel partitions."""

import os
import random
import sys

import numpy as np
import torch

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists


def get_change_ckpt_args(parser):
    """Provide extra arguments required for merging."""
    group = parser.add_argument_group(title='Mindspore to megatron')
    group.add_argument(
        '--npy-ckpt-path',
        type=str,
        required=True,
        help='path of the npy checkpoint.',
    )
    group.add_argument(
        '--save-ckpt-path',
        type=str,
        required=True,
        help='path to save the checkpoint.',
    )

    return parser


def loadModelFromNp(sd, args):
    num_layers = args.num_layers
    npCkptPath = args.npy_ckpt_path
    languageModel = sd['module']['language_model']
    loadEmbeddingFromNp(npCkptPath, languageModel)
    transformer = sd['module']['language_model']['transformer']
    for layerID in range(num_layers):
        loadAttentionLayerFromNp(npCkptPath, transformer, layerID)
    loadQueryLayerFromNp(npCkptPath, transformer)

    transformer['final_layernorm.weight'][:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.layernorm.gamma.npy')
        ).float()
    transformer['final_layernorm.bias'][:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.layernorm.beta.npy')
        ).float()


def loadEmbeddingFromNp(npCkptPath, languageModel, vocabSize=52224):
    word_embedding_np = \
        np.load(npCkptPath + 'backbone.embedding.word_embedding.embedding_table.npy')
    languageModel['embedding']['word_embeddings']['weight'][:vocabSize, :] = \
        torch.tensor(word_embedding_np).float()

    position_embeddings_np = \
        np.load(npCkptPath + 'backbone.embedding.position_embedding.embedding_table.npy')
    languageModel['embedding']['position_embeddings']['weight'][:, :] = \
        torch.tensor(position_embeddings_np).float()

    topQueryEmbedding_np = \
        np.load(npCkptPath + 'backbone.top_query_embedding.embedding_table.npy')
    languageModel['topQueryEmbedding']['top_query_embeddings']['weight'][:, :] = \
        torch.tensor(topQueryEmbedding_np).float()


def loadAttentionLayerFromNp(npCkptPath, transformer, layerID):
    attention_dense1_weight_np = \
        np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense1.weight.npy')
    attention_dense2_weight_np = \
        np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense2.weight.npy')
    attention_dense3_weight_np = \
        np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense3.weight.npy')

    attention_dense1_bias_np = \
        np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense1.bias.npy')
    attention_dense2_bias_np = \
        np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense2.bias.npy')
    attention_dense3_bias_np = \
        np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense3.bias.npy')

    query_weight = transformer[f'layers.{layerID}.attention.query.weight']
    key_weight = transformer[f'layers.{layerID}.attention.key.weight']
    value_weight = transformer[f'layers.{layerID}.attention.value.weight']

    query_weight[:] = torch.tensor(attention_dense1_weight_np).float()
    key_weight[:] = torch.tensor(attention_dense2_weight_np).float()
    value_weight[:] = torch.tensor(attention_dense3_weight_np).float()

    query_bias = transformer[f'layers.{layerID}.attention.query.bias']
    key_bias = transformer[f'layers.{layerID}.attention.key.bias']
    value_bias = transformer[f'layers.{layerID}.attention.value.bias']

    query_bias[:] = torch.tensor(attention_dense1_bias_np).float()
    key_bias[:] = torch.tensor(attention_dense2_bias_np).float()
    value_bias[:] = torch.tensor(attention_dense3_bias_np).float()

    att_dense_weight = transformer[f'layers.{layerID}.attention.dense.weight']
    att_dense_weight[:, :] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.projection.weight.npy').transpose()
        ).float()
    att_dense_bias = transformer[f'layers.{layerID}.attention.dense.bias']
    att_dense_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.projection.bias.npy')
        ).float()

    mlp_dense_h_to_4h_weight = transformer[f'layers.{layerID}.mlp.dense_h_to_4h.weight']
    mlp_dense_h_to_4h_weight[:, :] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.output.mapping.weight.npy').transpose()
        ).float()
    mlp_dense_h_to_4h_bias = transformer[f'layers.{layerID}.mlp.dense_h_to_4h.bias']
    mlp_dense_h_to_4h_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.output.mapping.bias.npy')
        ).float()

    mlp_dense_4h_to_h_weight = transformer[f'layers.{layerID}.mlp.dense_4h_to_h.weight']
    mlp_dense_4h_to_h_weight[:, :] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.output.projection.weight.npy').transpose()
        ).float()
    mlp_dense_4h_to_h_bias = transformer[f'layers.{layerID}.mlp.dense_4h_to_h.bias']
    mlp_dense_4h_to_h_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.output.projection.bias.npy')
        ).float()

    input_layernorm_weight = transformer[f'layers.{layerID}.input_layernorm.weight']
    input_layernorm_weight[:] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.gamma.npy')
        ).float()
    input_layernorm_bias = transformer[f'layers.{layerID}.input_layernorm.bias']
    input_layernorm_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.beta.npy')
        ).float()

    post_attention_layernorm_weight = transformer[f'layers.{layerID}.post_attention_layernorm.weight']
    post_attention_layernorm_weight[:] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.gamma.npy')
        ).float()
    post_attention_layernorm_bias = transformer[f'layers.{layerID}.post_attention_layernorm.bias']
    post_attention_layernorm_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.beta.npy')
        ).float()


def loadQueryLayerFromNp(npCkptPath, transformer):
    attention_dense1_weight_np = \
        np.load(npCkptPath + 'backbone.top_query_layer.attention.dense1.weight.npy')
    attention_dense1_bias_np = \
        np.load(npCkptPath + 'backbone.top_query_layer.attention.dense1.bias.npy')
    attention_dense2_weight_np = \
        np.load(npCkptPath + 'backbone.top_query_layer.attention.dense2.weight.npy')
    attention_dense2_bias_np = \
        np.load(npCkptPath + 'backbone.top_query_layer.attention.dense2.bias.npy')
    attention_dense3_weight_np = \
        np.load(npCkptPath + 'backbone.top_query_layer.attention.dense3.weight.npy')
    attention_dense3_bias_np = \
        np.load(npCkptPath + 'backbone.top_query_layer.attention.dense3.bias.npy')

    query_weight = transformer['topQueryLayer.attention.query.weight']
    query_weight[:, :] = \
        torch.tensor(attention_dense1_weight_np).float()
    query_bias = transformer['topQueryLayer.attention.query.bias']
    query_bias[:] = torch.tensor(attention_dense1_bias_np).float()

    key_weight = transformer['topQueryLayer.attention.key.weight']
    key_weight[:, :] = \
        torch.tensor(attention_dense2_weight_np).float()
    key_bias = transformer['topQueryLayer.attention.key.bias']
    key_bias[:] = torch.tensor(attention_dense2_bias_np).float()

    value_weight = transformer['topQueryLayer.attention.value.weight']
    value_weight[:, :] = \
        torch.tensor(attention_dense3_weight_np).float()
    value_bias = transformer['topQueryLayer.attention.value.bias']
    value_bias[:] = torch.tensor(attention_dense3_bias_np).float()

    att_dense_weight = transformer['topQueryLayer.attention.dense.weight']
    att_dense_weight[:, :] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.attention.projection.weight.npy')
            .transpose()
        ).float()
    att_dense_bias = transformer['topQueryLayer.attention.dense.bias']
    att_dense_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.attention.projection.bias.npy')
        ).float()

    mlp_dense_h_to_4h_weight = transformer['topQueryLayer.mlp.dense_h_to_4h.weight']
    mlp_dense_h_to_4h_weight[:, :] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.output.mapping.weight.npy')
            .transpose()
        ).float()
    mlp_dense_h_to_4h_bias = transformer['topQueryLayer.mlp.dense_h_to_4h.bias']
    mlp_dense_h_to_4h_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.output.mapping.bias.npy')
        ).float()

    mlp_dense_4h_to_h_weight = transformer['topQueryLayer.mlp.dense_4h_to_h.weight']
    mlp_dense_4h_to_h_weight[:, :] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.output.projection.weight.npy')
            .transpose()
        ).float()
    mlp_dense_4h_to_h_bias = transformer['topQueryLayer.mlp.dense_4h_to_h.bias']
    mlp_dense_4h_to_h_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.output.projection.bias.npy')
        ).float()

    input_layernorm_weight = transformer['topQueryLayer.input_layernorm.weight']
    input_layernorm_weight[:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.layernorm1.gamma.npy')
        ).float()
    input_layernorm_bias = transformer['topQueryLayer.input_layernorm.bias']
    input_layernorm_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.layernorm1.beta.npy')
        ).float()

    post_attention_layernorm_weight = transformer['topQueryLayer.post_attention_layernorm.weight']
    post_attention_layernorm_weight[:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.layernorm2.gamma.npy')
        ).float()
    post_attention_layernorm_bias = transformer['topQueryLayer.post_attention_layernorm.bias']
    post_attention_layernorm_bias[:] = \
        torch.tensor(
            np.load(npCkptPath + 'backbone.top_query_layer.layernorm2.beta.npy')
        ).float()


def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))

    initialize_megatron(
        extra_args_provider=get_change_ckpt_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )

    args = get_args()
    model = CodeGeeXModel()
    # print(dir(model))
    print(model.state_dict)

    # Save the model.
    sd = {}
    sd['module'] = model.state_dict_for_save_checkpoint()
    ensure_directory_exists(args.save_ckpt_path)
    loadModelFromNp(sd, args)
    print('> saving merged model to {}'.format(args.save_ckpt_path))
    torch.save(sd, args.save_ckpt_path)
    print(f"Converted checkpoint saved in {args.save_ckpt_path}.")


if __name__ == '__main__':
    main()
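
# Example invocation (illustrative; the script name, paths, and model sizes are
# placeholders — only the two --*-ckpt-path flags are registered above in
# get_change_ckpt_args, the rest are the usual Megatron model arguments):
#
#   python merge_ckpt.py \
#       --npy-ckpt-path /path/to/npy_ckpt/ \
#       --save-ckpt-path /path/to/merged_ckpt.pt \
#       --num-layers 39 --hidden-size 5120 --num-attention-heads 40 ...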
@ -0,0 +1,19 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .distributed import DistributedDataParallel
from .codegeex_model import CodeGeeXModel
from .language_model import get_language_model
from .module import Float16Module
@ -0,0 +1,109 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from codegeex.megatron import get_args
from codegeex.megatron import mpu
from .module import MegatronModule

from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


class CodeGeeXModel(MegatronModule):
    """Code Generative Model for Multilingual Program Synthesis."""

    def __init__(self, num_tokentypes=0, parallel_output=False):
        super(CodeGeeXModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        labels=None,
        tokentype_ids=None,
        layer_past=None,
        get_key_value=False,
        forward_method_parallel_output=None,
        prompt_length=None,
        context_length=None,
    ):
        # Language model.
        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        attention_mask,
                                        tokentype_ids=tokentype_ids,
                                        layer_past=layer_past,
                                        get_key_value=get_key_value,
                                        prompt_length=prompt_length,
                                        context_length=context_length)

        if get_key_value:
            lm_output, presents = lm_output

        lm_output = torch.add(lm_output, 0)
        # Output.
        parallel_output = self.parallel_output
        if forward_method_parallel_output is not None:
            parallel_output = forward_method_parallel_output
        output = parallel_lm_logits(
            lm_output,
            self.language_model.embedding.word_embeddings.weight,
            parallel_output)

        if get_key_value:
            output = [output, presents]

        if labels is None:
            return output
        else:
            if self.fp16_lm_cross_entropy:
                assert output.dtype == torch.half
                loss = mpu.vocab_parallel_cross_entropy(output, labels)
            else:
                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)

            return loss

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
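
# Minimal usage sketch (illustrative): CodeGeeXModel reads global args in
# __init__, so megatron/mpu must already be initialized, e.g. via
# initialize_megatron(...).
#
#   model = CodeGeeXModel(num_tokentypes=0, parallel_output=False)
#   logits = model(input_ids, position_ids, attention_mask)  # no labels -> logits
#   loss = model(input_ids, position_ids, attention_mask, labels=labels)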
@ -0,0 +1,215 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from codegeex.megatron import mpu
from .module import MegatronModule


class MemoryBuffer:
    def __init__(self, numel, dtype):
        self.numel = numel
        self.dtype = dtype
        self.data = torch.zeros(
            self.numel,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, "requested tensor is out of the buffer range."
        buffer_tensor = self.data[start_index:end_index]
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor
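    # For example, on a buffer with numel=12, get(torch.Size([2, 3]), start_index=6)
    # returns a (2, 3) view over data[6:12]; no copy is made, so writes through
    # the returned tensor land in the shared buffer.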


class DistributedDataParallelBase(MegatronModule, ABC):
    """Abstract class for DDP."""

    def __init__(self, module):
        super(DistributedDataParallelBase, self).__init__()
        # Keep a pointer to the model.
        self.module = module

    @abstractmethod
    def allreduce_gradients(self):
        pass

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(
        self, destination=None, prefix="", keep_vars=False
    ):
        return self.module.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)


class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with contiguous buffer options to store and accumulate gradients.
    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32).

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true, do the gradient accumulation
            and the gradient all-reduce all in float32. If this option is
            true, we require `use_contiguous_buffers` to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """

    def __init__(
        self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers
    ):
        super(DistributedDataParallel, self).__init__(module)

        self.accumulate_allreduce_grads_in_fp32 = accumulate_allreduce_grads_in_fp32
        self.use_contiguous_buffers = use_contiguous_buffers
        # If we are using fp32-accumulate-allreduce explicitly,
        # this means we need main grads in a contiguous buffer.
        if self.accumulate_allreduce_grads_in_fp32:
            assert self.use_contiguous_buffers

        # ===================================
        # Rest of this part applies only to
        # the case we use contiguous buffers.
        # ===================================
        self._grad_buffers = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}

            # Simple function to define buffer type.
            def _get_buffer_type(param):
                return (
                    torch.float
                    if self.accumulate_allreduce_grads_in_fp32
                    else param.dtype
                )

            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = (
                        type_num_elements.get(dtype, 0) + param.data.nelement()
                    )

            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():
                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)

            # Assume the back-prop order is the reverse of the params order;
            # store the start index for the gradients.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype]
                    )

            # Backward hook.
            # Accumulation function for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""

        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad.data is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None

        return param_hook

    def zero_grad_buffer(self):
        """Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
        assert self._grad_buffers is not None, "buffers are not initialized."
        for _, buffer_ in self._grad_buffers.items():
            buffer_.zero()

    def allreduce_gradients(self):
        """Reduce gradients across data-parallel ranks."""
        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for _, buffer_ in self._grad_buffers.items():
                buffer_.data /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    buffer_.data, group=mpu.get_data_parallel_group()
                )
        else:
            # Otherwise, bucketize and all-reduce.
            buckets = {}
            # Pack the buckets.
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    param.main_grad = param.grad

            # For each bucket, all-reduce and copy the all-reduced grads.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    coalesced, group=mpu.get_data_parallel_group()
                )
                for buf, synced in zip(
                    grads, _unflatten_dense_tensors(coalesced, grads)
                ):
                    buf.copy_(synced)
@ -0,0 +1,503 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from codegeex.megatron import get_args
from codegeex.megatron import mpu
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.transformer import ParallelTransformer
from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight.half())
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight.half(), bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
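
# Note: with parallel_output=True each tensor-parallel rank keeps only its own
# vocab partition of the logits; with parallel_output=False the partitions are
# gathered so every rank sees the full [s, b, vocab] logits.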


def get_language_model(
    num_tokentypes,
    add_pooler,
    init_method=None,
    scaled_init_method=None,
):
    """Build the language model and return it along with the key used to save it."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        num_tokentypes=num_tokentypes,
        add_pooler=add_pooler)
    # Key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
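
# Usage sketch (illustrative; assumes megatron args are initialized): the init
# methods are optional and default to the normal/scaled-normal initializers
# built from args.init_method_std.
#
#   lm, lm_key = get_language_model(num_tokentypes=0, add_pooler=False)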


class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. A value of 0
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size, init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'
        self.vocab_size = vocab_size

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self.position_embeddings = self.position_embeddings.half()
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # a method call so we can load a pretrained model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout.
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have them.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings are already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] \
                        = state_dict[key]
        state_dict_["weight"] = state_dict_["weight"][:self.vocab_size]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] \
                        = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # For backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] \
                            = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find them', flush=True)


class QueryEmbedding(MegatronModule):
    """Top query embeddings for the language model.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. A value of 0
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(QueryEmbedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Top query position embedding (serial).
        self.top_query_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self.top_query_embeddings = self.top_query_embeddings.half()
        self._top_query_embeddings_key = 'top_query_embeddings'
        # Initialize the top query position embeddings.
        self.init_method(self.top_query_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # a method call so we can load a pretrained model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout.
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have them.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings are already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, position_ids, tokentype_ids=None):
        # Embeddings.
        embeddings = self.top_query_embeddings(position_ids)
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._top_query_embeddings_key] \
            = self.top_query_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Position embedding.
        if self._top_query_embeddings_key in state_dict:
            state_dict_ = state_dict[self._top_query_embeddings_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'top_query_embeddings' in key:
                    state_dict_[key.split('top_query_embeddings.')[1]] \
                        = state_dict[key]
        self.top_query_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # For backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] \
                            = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find them', flush=True)


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
                masked-attention-scores = attention_mask_func(
                    unmasked-attention-scores, attention-mask)
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. A value of 0
                        will ignore this embedding
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 add_pooler=False):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_pooler = add_pooler

        # Embeddings.
        self.embedding = Embedding(self.hidden_size,
                                   args.padded_vocab_size,
                                   args.max_position_embeddings,
                                   args.hidden_dropout,
                                   self.init_method,
                                   self.num_tokentypes)
        self._embedding_key = 'embedding'

        # Query embeddings.
        self.topQueryEmbedding = QueryEmbedding(self.hidden_size,
                                                args.padded_vocab_size,
                                                args.max_position_embeddings,
                                                args.hidden_dropout,
                                                self.init_method,
                                                self.num_tokentypes)
        self._topQueryEmbedding_key = 'topQueryEmbedding'

        # Transformer.
        self.transformer = ParallelTransformer(
            self.init_method,
            output_layer_init_method)
        self._transformer_key = 'transformer'

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        tokentype_ids=None,
        layer_past=None,
        get_key_value=False,
        pooling_sequence_index=0,
        prompt_length=None,
        context_length=None,
    ):
        # Embeddings.
        embedding_output = self.embedding(input_ids, position_ids,
                                          tokentype_ids=tokentype_ids)
        query_position_ids = position_ids
        queryEmbedding_out = self.topQueryEmbedding(query_position_ids,
                                                    tokentype_ids=tokentype_ids)

        # Transformer.
        transformer_output = self.transformer(embedding_output,
                                              queryEmbedding_out,
                                              attention_mask,
                                              layer_past=layer_past,
                                              get_key_value=get_key_value,
                                              prompt_length=prompt_length,
                                              context_length=context_length)

        return transformer_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._embedding_key] \
            = self.embedding.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._topQueryEmbedding_key] \
            = self.topQueryEmbedding.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._transformer_key] \
            = self.transformer.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.add_pooler:
            state_dict_[self._pooler_key] \
                = self.pooler.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self._embedding_key in state_dict:
            state_dict_ = state_dict[self._embedding_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if '_embeddings' in key:
                    state_dict_[key] = state_dict[key]
        self.embedding.load_state_dict(state_dict_, strict=strict)

        # Top query embedding.
        if self._topQueryEmbedding_key in state_dict:
            state_dict_ = state_dict[self._topQueryEmbedding_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if '_embeddings' in key:
                    state_dict_[key] = state_dict[key]
        self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict)

        # Transformer.
        if self._transformer_key in state_dict:
            state_dict_ = state_dict[self._transformer_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
        self.transformer.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.add_pooler:
            assert 'pooler' in state_dict, \
                'could not find data for pooler in the checkpoint'
            self.pooler.load_state_dict(state_dict[self._pooler_key],
                                        strict=strict)
@ -0,0 +1,199 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron Module"""

import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from codegeex.megatron import get_args
from codegeex.megatron import mpu


_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)


def param_is_not_shared(param):
    return not hasattr(param, "shared") or not param.shared


class MegatronModule(torch.nn.Module):
    """Megatron-specific extensions of torch Module with support
    for pipelining."""

    def __init__(self, share_word_embeddings=True):
        super(MegatronModule, self).__init__()
        self.share_word_embeddings = share_word_embeddings

    def state_dict_for_save_checkpoint(
        self, destination=None, prefix="", keep_vars=False
    ):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)

    def word_embeddings_weight(self):
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            if not self.share_word_embeddings:
                raise Exception(
                    "word_embeddings_weight() called for last "
                    "stage, but share_word_embeddings is false"
                )
            return self.word_embeddings.weight
        raise Exception(
            "word_embeddings_weight() should be called for first and last stage only"
        )

    def initialize_word_embeddings(self, init_method_normal):
        args = get_args()
        if not self.share_word_embeddings:
            raise Exception(
                "initialize_word_embeddings() was called but "
                "share_word_embeddings is false"
            )

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. If we aren't using pipeline
        # parallelism, there is nothing to do.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layer and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, do an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if mpu.is_pipeline_last_stage():
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = "word_embeddings_for_head"
            # Set word_embeddings weights to 0 here, then copy the first
            # stage's weights using the all_reduce below.
            self.word_embeddings = mpu.VocabParallelEmbedding(
                args.padded_vocab_size,
                args.hidden_size,
                init_method=init_method_normal(args.init_method_std),
            )
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

        # Ensure that the first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
                torch.distributed.all_reduce(
                    self.word_embeddings_weight().data, group=mpu.get_embedding_group()
                )
        else:
            print(
                "WARNING! Distributed processes aren't initialized, so "
                "word embeddings in the last layer are not initialized. "
                "If you are just manipulating a model this is fine, but "
                "this needs to be handled manually. If you are training, "
                "something is definitely wrong."
            )


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
    is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn
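
# For example, conversion_helper((x, [y, z]), f) returns (f(x), [f(y), f(z)]),
# preserving the tuple/list nesting of the input.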
||||


def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16"""

    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = float16_convertor(val)
        return val

    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32"""

    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val

    return conversion_helper(val, float_conversion)


class Float16Module(MegatronModule):
    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        if args.fp16:
            self.add_module("module", module.half())

            def float16_convertor(val):
                return val.half()

        elif args.bf16:
            self.add_module("module", module.bfloat16())

            def float16_convertor(val):
                return val.bfloat16()

        else:
            raise Exception("should not be here")

        self.float16_convertor = float16_convertor

    def forward(self, *inputs, **kwargs):
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(
        self, destination=None, prefix="", keep_vars=False
    ):
        return self.module.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
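

# Illustrative round-trip sketch (not part of the original file; it assumes
# the `_FLOAT_TYPES` / `_HALF_TYPES` tuples defined earlier in this file cover
# CPU tensor types as well). The two converters walk nested structures and
# only touch floating-point tensors.
def _demo_float16_round_trip():
    x = (torch.randn(2, 2), [torch.randn(3)])
    halved = fp32_to_float16(x, lambda t: t.half())
    restored = float16_to_fp32(halved)
    assert restored[0].dtype == torch.float32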
@ -0,0 +1,970 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math
import torch
from torch.nn import LayerNorm

from codegeex.megatron import get_args
from codegeex.megatron import mpu
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.utils import fast_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
    h: hidden size
    n: number of attention heads
    p: number of model parallel partitions
    np: n/p
    hp: h/p
    hn: h/n
    b: batch size
    s: sequence length
    l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
            masked-attention-scores = attention_mask_func(
                unmasked-attention-scores, attention-mask)
"""
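

# Illustrative sketch of the `attention_mask_func` contract described above
# (not part of the original file; `_demo_attention_mask_func` is a
# hypothetical name). True positions in the mask receive a large negative
# score so they vanish after the softmax.
def _demo_attention_mask_func(attention_scores, attention_mask):
    return attention_scores.masked_fill(attention_mask, -10000.0)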


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            # skip_bias_add=True,
        )

        self.activation_func = fast_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            # skip_bias_add=True,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)

        return output, output_bias
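

# Dense, single-partition sketch of the MLP's shape contract (not part of the
# original file): [s, b, h] -> [s, b, 4h] -> [s, b, h]. Under tensor
# parallelism the 4h dimension is split across p partitions (4hp per rank).
def _demo_mlp_shapes(s=4, b=2, h=8):
    x = torch.randn(s, b, h)
    w1 = torch.randn(4 * h, h)
    w2 = torch.randn(h, 4 * h)
    y = torch.nn.functional.linear(fast_gelu(torch.nn.functional.linear(x, w1)), w2)
    assert y.shape == (s, b, h)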


class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)
        if hasattr(args, 'attention_upweight'):
            self.attention_upweight = args.attention_upweight
        else:
            self.attention_upweight = None
        # Strided linear layer.
        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.key = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.value = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        query_layer, _ = self.query(hidden_states)
        key_layer, _ = self.key(hidden_states)
        value_layer, _ = self.value(hidden_states)

        new_query_layer_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_query_layer_shape)

        new_query_layer_shape = key_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        key_layer = key_layer.view(*new_query_layer_shape)

        new_query_layer_shape = value_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        value_layer = value_layer.view(*new_query_layer_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
        key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.matmul(query_layer.transpose(0, 1),
                                     key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        if self.attention_upweight is not None and layer_past is None:
            log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
                                                device=torch.cuda.current_device(),
                                                dtype=torch.half if self.fp16 else torch.float32)
            if prompt_length is None:
                log_attention_weights = self.attention_upweight
            else:
                log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
            attention_scores += log_attention_weights

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if context_length is not None:
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        # attention scores and attention mask [b, np, sq, sk]
        # attention_scores = attention_mask_func(attention_scores, attention_mask)
        attention_scores = attention_scores - attention_mask * 10000.0
        if self.attention_softmax_in_fp32:
            attention_probs = self.softmax(attention_scores.float()).half()
        else:
            attention_probs = self.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sq, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sq, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias
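

# Dense reference for the score computation above (not part of the original
# file): fold the heads into the batch dimension and scale by sqrt(hn).
def _demo_attention_scores(sq=5, sk=5, b=2, heads=3, hn=4):
    q = torch.randn(sq, b, heads, hn)
    k = torch.randn(sk, b, heads, hn)
    q_ = q.contiguous().view(sq, b * heads, hn).transpose(0, 1)  # [b*np, sq, hn]
    k_ = k.contiguous().view(sk, b * heads, hn).transpose(0, 1)  # [b*np, sk, hn]
    scores = torch.matmul(q_, k_.transpose(1, 2)) / math.sqrt(hn)
    assert scores.shape == (b * heads, sq, sk)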


class ParallelTopQuerySelfAttention(MegatronModule):
    """Parallel top query self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        super(ParallelTopQuerySelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        if hasattr(args, 'attention_upweight_top'):
            self.attention_upweight = args.attention_upweight_top
        else:
            self.attention_upweight = None
        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.key = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.value = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):

        # hidden_states: [sq, b, h]

        query_layer, _ = self.query(query_hidden_state)
        key_layer, _ = self.key(hidden_states)
        value_layer, _ = self.value(hidden_states)

        new_query_layer_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_query_layer_shape)

        new_query_layer_shape = key_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        key_layer = key_layer.view(*new_query_layer_shape)

        new_query_layer_shape = value_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        value_layer = value_layer.view(*new_query_layer_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [s, b, np, hn] -> [s, b * np, hn]
        query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
        key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.matmul(query_layer.transpose(0, 1),
                                     key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

        # change view to [b, np, s, s]
        attention_scores = matmul_result.view(*output_size)

        if self.attention_upweight is not None and layer_past is None:
            log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
                                                device=torch.cuda.current_device(),
                                                dtype=torch.half if self.fp16 else torch.float32)
            if prompt_length is None:
                log_attention_weights = self.attention_upweight
            else:
                log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
            attention_scores += log_attention_weights

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if context_length is not None:
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        # attention scores and attention mask [b, np, sq, sk]
        # attention_scores = attention_mask_func(attention_scores, attention_mask)
        attention_scores = attention_scores - attention_mask * 10000.0
        if self.attention_softmax_in_fp32:
            attention_probs = self.softmax(attention_scores.float()).half()
        else:
            attention_probs = self.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sq, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sq, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
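

# Illustrative check (not part of the original file): in eval mode dropout is
# the identity, so the fused inference variant reduces to residual + (x + bias).
def _demo_bias_dropout_add():
    x, bias, residual = torch.ones(2), torch.zeros(2), torch.ones(2)
    out = bias_dropout_add_fused_inference(x, bias, residual, 0.1)
    assert torch.allclose(out, torch.full((2,), 2.0))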


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelSelfAttention(init_method,
                                               output_layer_init_method,
                                               layer_number)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)
        if hasattr(args, 'attention_upweight'):
            self.attention_upweight = args.attention_upweight
        else:
            self.attention_upweight = None
        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False
        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [b, s, h]
        if self.ln_fp16:
            layernorm_output = self.input_layernorm(hidden_states)
        else:
            layernorm_output = self.input_layernorm(hidden_states.float()).half()

        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value,
                           prompt_length=prompt_length,
                           context_length=context_length)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        if self.ln_fp16:
            layernorm_output = self.post_attention_layernorm(layernorm_input)
        else:
            layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()

        # MLP.
        mlp_output, _ = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = mlp_output + residual

        if get_key_value:
            output = [output, presents]

        return output
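

# Minimal wiring sketch (not part of the original file): with
# apply_residual_connection_post_layernorm=False and dropout ignored, the
# layer above reduces to the standard pre-LayerNorm residual pattern.
def _demo_pre_ln_wiring(x, ln1, attn, ln2, mlp):
    x = x + attn(ln1(x))
    return x + mlp(ln2(x))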


class ParallelTopQueryLayer(MegatronModule):
    """A single top query layer.

    Top query layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTopQueryLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelTopQuerySelfAttention(init_method,
                                                       output_layer_init_method,
                                                       layer_number)

        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [b, s, h]
        assert query_hidden_state is not None

        # Layer norm at the beginning of the transformer layer.
        if self.ln_fp16:
            layernorm_output = self.input_layernorm(hidden_states)
        else:
            layernorm_output = self.input_layernorm(hidden_states.float()).half()

        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           query_hidden_state,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value,
                           prompt_length=prompt_length,
                           context_length=context_length)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        if self.ln_fp16:
            layernorm_output = self.post_attention_layernorm(layernorm_input)
        else:
            layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()

        # MLP.
        mlp_output, _ = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = mlp_output + residual

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers:
        self.num_layers = args.num_layers
        self.num_unique_layers = None

        #################
        assert self.num_unique_layers is None
        #################

        if self.num_unique_layers is None:
            self.num_unique_layers = self.num_layers
        assert self.num_layers % self.num_unique_layers == 0, \
            'number of layers should be divisible by number of unique layers'
        self.param_sharing_style = 'grouped'

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method, layer_number)

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_unique_layers)])

        self.topQueryLayer = ParallelTopQueryLayer(
            init_method,
            output_layer_init_method, self.num_unique_layers)

        # Final layer norm before output.
        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False

        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

    def _get_layer_index(self, layer_number):
        if self.param_sharing_style == 'grouped':
            return layer_number % self.num_unique_layers
        if self.param_sharing_style == 'spaced':
            return layer_number // (self.num_layers // self.num_unique_layers)
        assert False, 'should not be here'

    def _get_layer(self, layer_number):
        return self.layers[self._get_layer_index(layer_number)]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):

        # Checks
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        # data format change to avoid explicit transposes: [b s h] --> [s b h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value,
                                      prompt_length=prompt_length,
                                      context_length=context_length)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        if self.ln_fp16:
            hidden_states_ = self.final_layernorm(hidden_states)
        else:
            hidden_states_ = self.final_layernorm(hidden_states.float()).half()

        #################################
        # top query layer
        #################################
        past = None
        if layer_past is not None:
            past = layer_past[self.num_layers]
        hidden_states = self.topQueryLayer(hidden_states_,
                                           query_hidden_state,
                                           attention_mask,
                                           layer_past=past,
                                           get_key_value=get_key_value,
                                           prompt_length=prompt_length,
                                           context_length=context_length)

        if get_key_value:
            hidden_states, present = hidden_states
            presents.append(present)

        # reverting data format change [s b h] --> [b s h]
        output = hidden_states.transpose(0, 1).contiguous()

        if get_key_value:
            output = [output, presents]

        return output
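

# Illustrative sketch of `_get_layer_index` (not part of the original file):
# 'grouped' sharing cycles through the unique layers, while 'spaced' repeats
# each unique layer in consecutive blocks.
def _demo_param_sharing_index(num_layers=8, num_unique_layers=4):
    grouped = [i % num_unique_layers for i in range(num_layers)]
    spaced = [i // (num_layers // num_unique_layers) for i in range(num_layers)]
    assert grouped == [0, 1, 2, 3, 0, 1, 2, 3]
    assert spaced == [0, 0, 1, 1, 2, 2, 3, 3]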
@ -0,0 +1,83 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for models."""

import math
import torch


def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)

    return attention_scores


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
    init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer


def fast_gelu(x):
    """Mindspore's fast gelu implementation."""
    return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x)))


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return (
        0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
    )


def openai_gelu(x):
    return gelu_impl(x)


# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
    return (
        x
        * 0.5
        * (
            torch.erf(x / 1.41421).to(dtype=x.dtype)
            + torch.ones_like(x).to(dtype=x.dtype)
        )
    )
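

# Illustrative comparison (not part of the original file): the sigmoid-based
# `fast_gelu`, the tanh-based `openai_gelu`, and the erf-based `erf_gelu`
# agree to within a few hundredths over a moderate range (tolerance assumed).
def _demo_gelu_variants():
    x = torch.linspace(-3.0, 3.0, 7)
    assert torch.allclose(fast_gelu(x), erf_gelu(x), atol=5e-2)
    assert torch.allclose(openai_gelu(x), erf_gelu(x), atol=5e-2)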
@ -0,0 +1,81 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model parallel utility interface."""

from .cross_entropy import vocab_parallel_cross_entropy

from .data import broadcast_data

from .initialize import is_unitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_group
from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import (
    get_pipeline_model_parallel_rank,
    set_pipeline_model_parallel_rank,
)
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_pipeline_model_parallel_next_rank
from .initialize import get_pipeline_model_parallel_prev_rank
from .initialize import (
    get_tensor_model_parallel_world_size,
    set_tensor_model_parallel_world_size,
)
from .initialize import (
    get_pipeline_model_parallel_world_size,
    set_pipeline_model_parallel_world_size,
)
from .initialize import (
    get_virtual_pipeline_model_parallel_rank,
    set_virtual_pipeline_model_parallel_rank,
)
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .initialize import get_model_parallel_world_size, get_model_parallel_rank

from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (
    set_tensor_model_parallel_attributes,
    set_defaults_if_not_set_tensor_model_parallel_attributes,
    copy_tensor_model_parallel_attributes,
)

from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region

from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks

from .utils import divide
from .utils import split_tensor_along_last_dim
@ -0,0 +1,115 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .utils import VocabUtility


class _VocabParallelCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):

        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(
            logits_max,
            op=torch.distributed.ReduceOp.MAX,
            group=get_tensor_model_parallel_group(),
        )
        # Subtract the maximum value.
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

        # Get the partition's vocab indices.
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size
        )

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(
            start=0, end=logits_2d.size()[0], device=logits_2d.device
        )
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(
            predicted_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(
            sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Store softmax, target-mask and masked-target for backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):

        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None


def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
    """Helper function for the cross entropy."""
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
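

# Single-rank sanity sketch (not part of the original file): with one
# partition, the loss formula above -- log(sum(exp(logits))) - logits[target]
# after max subtraction -- matches torch's unreduced cross entropy.
def _demo_single_rank_cross_entropy():
    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    shifted = logits - logits.max(dim=-1, keepdim=True)[0]
    loss = torch.log(torch.exp(shifted).sum(-1)) - shifted[torch.arange(4), target]
    ref = torch.nn.functional.cross_entropy(logits, target, reduction="none")
    assert torch.allclose(loss, ref, atol=1e-5)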
@ -0,0 +1,125 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank


_MAX_DATA_DIM = 5


def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
        assert (
            data[key].dtype == target_dtype
        ), "{} has data type {} which " "is different than {}".format(
            key, data[key].dtype, target_dtype
        )


def _build_key_size_numel_dictionaries(keys, data):
    """Build the size on rank 0 and broadcast."""
    max_dim = _MAX_DATA_DIM
    sizes = [0 for _ in range(max_dim) for _ in keys]

    # Pack the sizes on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        offset = 0
        for key in keys:
            assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
            size = data[key].size()
            for i, s in enumerate(size):
                sizes[i + offset] = s
            offset += max_dim

    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(
        sizes_cuda,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Move back to cpu and unpack.
    sizes_cpu = sizes_cuda.cpu()
    key_size = {}
    key_numel = {}
    total_numel = 0
    offset = 0
    for key in keys:
        i = 0
        size = []
        numel = 1
        while sizes_cpu[offset + i] > 0:
            this_size = sizes_cpu[offset + i]
            size.append(this_size)
            numel *= this_size
            i += 1
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel
        offset += max_dim

    return key_size, key_numel, total_numel


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcast
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0
        ).cuda()
    else:
        flatten_data = torch.empty(
            total_numel, device=torch.cuda.current_device(), dtype=datatype
        )

    # Broadcast
    torch.distributed.broadcast(
        flatten_data,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Unpack
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output
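

# Illustrative sketch of the pack/unpack round trip used by `broadcast_data`
# (not part of the original file; runs on CPU, no process group needed).
def _demo_pack_unpack():
    data = {"a": torch.arange(6).view(2, 3), "b": torch.arange(4)}
    flat = torch.cat([v.contiguous().view(-1) for v in data.values()])
    a = flat.narrow(0, 0, 6).view(2, 3)
    b = flat.narrow(0, 6, 4)
    assert torch.equal(a, data["a"]) and torch.equal(b, data["b"])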
@ -0,0 +1,403 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Model and data parallel groups."""
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import ensure_divisibility
|
||||
|
||||
|
||||
# Intra-layer model parallel group that the current rank belongs to.
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
# Inter-layer model parallel group that the current rank belongs to.
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
# Embedding group.
|
||||
_EMBEDDING_GROUP = None
|
||||
# Data parallel group that the current rank belongs to.
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
|
||||
# These values enable us to change the mpu sizes on the fly.
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
|
||||
# A list of global ranks for each pipeline group to ease calculation of the source
|
||||
# rank when broadcasting from the first or last pipeline stage
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
|
||||
def is_unitialized():
|
||||
"""Useful for code segments that may be accessed with or without mpu initialization"""
|
||||
return _DATA_PARALLEL_GROUP is None
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size_=1,
|
||||
pipeline_model_parallel_size_=1,
|
||||
virtual_pipeline_model_parallel_size_=None,
|
||||
):
|
||||
"""
|
||||
Initialize model data parallel groups.
|
||||
|
||||
Arguments:
|
||||
tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
|
||||
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
|
||||
|
||||
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
|
||||
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
||||
the model pipeline. The present function will
|
||||
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
|
||||
and 8 data-parallel groups as:
|
||||
8 data_parallel groups:
|
||||
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
|
||||
8 tensor model-parallel groups:
|
||||
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
|
||||
4 pipeline model-parallel groups:
|
||||
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
|
||||
Note that for efficiency, the caller should make sure adjacent ranks
|
||||
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
||||
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
||||
ranks 8 to 15 belong to the second box.
|
||||
"""
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(
|
||||
"> initializing tensor model parallel with size {}".format(
|
||||
tensor_model_parallel_size_
|
||||
)
|
||||
)
|
||||
print(
|
||||
"> initializing pipeline model parallel with size {}".format(
|
||||
pipeline_model_parallel_size_
|
||||
)
|
||||
)
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
|
||||
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
|
||||
ensure_divisibility(
|
||||
world_size, tensor_model_parallel_size * pipeline_model_parallel_size
|
||||
)
|
||||
data_parallel_size = world_size // (
|
||||
tensor_model_parallel_size * pipeline_model_parallel_size
|
||||
)
|
||||
|
||||
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
|
||||
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
|
||||
num_data_parallel_groups = world_size // data_parallel_size
|
||||
|
||||
if virtual_pipeline_model_parallel_size_ is not None:
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
|
||||
virtual_pipeline_model_parallel_size_
|
||||
)
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
# Build the data-parallel groups.
|
||||
global _DATA_PARALLEL_GROUP
|
||||
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
|
||||
all_data_parallel_group_ranks = []
|
||||
for i in range(pipeline_model_parallel_size):
|
||||
start_rank = i * num_pipeline_model_parallel_groups
|
||||
end_rank = (i + 1) * num_pipeline_model_parallel_groups
|
||||
for j in range(tensor_model_parallel_size):
|
||||
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
|
||||
all_data_parallel_group_ranks.append(list(ranks))
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_DATA_PARALLEL_GROUP = group
|
||||
|
||||
# Build the model-parallel groups.
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
|
||||
for i in range(data_parallel_size):
|
||||
ranks = [
|
||||
data_parallel_group_ranks[i]
|
||||
for data_parallel_group_ranks in all_data_parallel_group_ranks
|
||||
]
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the tensor model-parallel groups.
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
assert (
|
||||
_TENSOR_MODEL_PARALLEL_GROUP is None
|
||||
), "tensor model parallel group is already initialized"
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
ranks = range(
|
||||
i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size
|
||||
)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the pipeline model-parallel groups and embedding groups
|
||||
# (first and last rank in each pipeline model-parallel group).
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
assert (
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP is None
|
||||
), "pipeline model parallel group is already initialized"
|
||||
global _EMBEDDING_GROUP
|
||||
assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
|
||||
for i in range(num_pipeline_model_parallel_groups):
|
||||
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
# Setup embedding group (to exchange gradients between
|
||||
# first and last stages).
|
||||
if len(ranks) > 1:
|
||||
embedding_ranks = [ranks[0], ranks[-1]]
|
||||
else:
|
||||
embedding_ranks = ranks
|
||||
group = torch.distributed.new_group(embedding_ranks)
|
||||
if rank in embedding_ranks:
|
||||
_EMBEDDING_GROUP = group
|
||||
|
||||
|
||||
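# Worked example (an illustrative sketch, not part of the original source):
# with world_size = 8, tensor_model_parallel_size = 2 and
# pipeline_model_parallel_size = 2, the code above yields data_parallel_size = 2
# and builds the following rank groups:
#     tensor model parallel  : [0, 1], [2, 3], [4, 5], [6, 7]
#     pipeline model parallel: [0, 4], [1, 5], [2, 6], [3, 7]
#     data parallel          : [0, 2], [1, 3], [4, 6], [5, 7]
#     model parallel         : [0, 1, 4, 5], [2, 3, 6, 7]
# Each pipeline group here has length 2, so its first and last ranks coincide
# with the whole group and the embedding groups equal the pipeline groups.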
def model_parallel_is_initialized():
|
||||
"""Check if model and data parallel groups are initialized."""
|
||||
if (
|
||||
_TENSOR_MODEL_PARALLEL_GROUP is None
|
||||
or _PIPELINE_MODEL_PARALLEL_GROUP is None
|
||||
or _DATA_PARALLEL_GROUP is None
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_model_parallel_group():
|
||||
"""Get the model parallel group the caller rank belongs to."""
|
||||
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
|
||||
return _MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_tensor_model_parallel_group():
|
||||
"""Get the tensor model parallel group the caller rank belongs to."""
|
||||
assert (
|
||||
_TENSOR_MODEL_PARALLEL_GROUP is not None
|
||||
), "intra_layer_model parallel group is not initialized"
|
||||
return _TENSOR_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_group():
|
||||
"""Get the pipeline model parallel group the caller rank belongs to."""
|
||||
assert (
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP is not None
|
||||
), "pipeline_model parallel group is not initialized"
|
||||
return _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_data_parallel_group():
|
||||
"""Get the data parallel group the caller rank belongs to."""
|
||||
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
|
||||
return _DATA_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_embedding_group():
|
||||
"""Get the embedding group the caller rank belongs to."""
|
||||
assert _EMBEDDING_GROUP is not None, "embedding group is not initialized"
|
||||
return _EMBEDDING_GROUP
|
||||
|
||||
|
||||
def set_tensor_model_parallel_world_size(world_size):
|
||||
"""Set the tensor model parallel size"""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_world_size(world_size):
|
||||
"""Set the pipeline model parallel size"""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_model_parallel_world_size():
|
||||
assert (
|
||||
get_pipeline_model_parallel_world_size() == 1
|
||||
), "legacy get_model_parallel_world_size is only supported if PP is disabled"
|
||||
return get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_world_size():
|
||||
"""Return world size for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def set_tensor_model_parallel_rank(rank):
|
||||
"""Set tensor model parallel rank."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_rank(rank):
|
||||
"""Set pipeline model parallel rank."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def get_tensor_model_parallel_rank():
|
||||
"""Return my rank for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_model_parallel_rank():
|
||||
assert (
|
||||
get_pipeline_model_parallel_world_size() == 1
|
||||
), "legacy get_model_parallel_rank is only supported if PP is disabled"
|
||||
return get_tensor_model_parallel_rank()
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_rank():
|
||||
"""Return my rank for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def is_pipeline_first_stage(ignore_virtual=False):
|
||||
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
if (
|
||||
get_virtual_pipeline_model_parallel_world_size() is not None
|
||||
and get_virtual_pipeline_model_parallel_rank() != 0
|
||||
):
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == 0
|
||||
|
||||
|
||||
def is_pipeline_last_stage(ignore_virtual=False):
|
||||
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
virtual_pipeline_model_parallel_world_size = (
|
||||
get_virtual_pipeline_model_parallel_world_size()
|
||||
)
|
||||
if (
|
||||
virtual_pipeline_model_parallel_world_size is not None
|
||||
and get_virtual_pipeline_model_parallel_rank()
|
||||
!= (virtual_pipeline_model_parallel_world_size - 1)
|
||||
):
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == (
|
||||
get_pipeline_model_parallel_world_size() - 1
|
||||
)
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_rank():
|
||||
"""Return the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
|
||||
|
||||
def set_virtual_pipeline_model_parallel_rank(rank):
|
||||
"""Set the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_world_size():
|
||||
"""Return the virtual pipeline-parallel world size."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
|
||||
|
||||
def get_tensor_model_parallel_src_rank():
|
||||
"""Calculate the global rank corresponding to the first local rank
|
||||
in the tensor model parallel group."""
|
||||
global_rank = torch.distributed.get_rank()
|
||||
local_world_size = get_tensor_model_parallel_world_size()
|
||||
return (global_rank // local_world_size) * local_world_size
|
||||
|
||||
|
||||
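# Example (illustrative): with a tensor model parallel world size of 4, global
# ranks 8..11 form one tensor parallel group, so a caller on global rank 10
# gets (10 // 4) * 4 = 8 back, i.e. the first rank of its group.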
def get_pipeline_model_parallel_first_rank():
|
||||
assert (
|
||||
_PIPELINE_GLOBAL_RANKS is not None
|
||||
), "Pipeline parallel group is not initialized"
|
||||
return _PIPELINE_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_last_rank():
|
||||
assert (
|
||||
_PIPELINE_GLOBAL_RANKS is not None
|
||||
), "Pipeline parallel group is not initialized"
|
||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_next_rank():
|
||||
assert (
|
||||
_PIPELINE_GLOBAL_RANKS is not None
|
||||
), "Pipeline parallel group is not initialized"
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_prev_rank():
|
||||
assert (
|
||||
_PIPELINE_GLOBAL_RANKS is not None
|
||||
), "Pipeline parallel group is not initialized"
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
||||
|
||||
|
||||
def get_data_parallel_world_size():
|
||||
"""Return world size for the data parallel group."""
|
||||
return torch.distributed.get_world_size(group=get_data_parallel_group())
|
||||
|
||||
|
||||
def get_data_parallel_rank():
|
||||
"""Return my rank for the data parallel group."""
|
||||
return torch.distributed.get_rank(group=get_data_parallel_group())
|
||||
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none."""
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
global _DATA_PARALLEL_GROUP
|
||||
_DATA_PARALLEL_GROUP = None
|
@ -0,0 +1,480 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .initialize import get_tensor_model_parallel_rank
|
||||
from .initialize import get_tensor_model_parallel_world_size
|
||||
from .mappings import copy_to_tensor_model_parallel_region
|
||||
from .mappings import gather_from_tensor_model_parallel_region
|
||||
from .mappings import reduce_from_tensor_model_parallel_region
|
||||
from .mappings import scatter_to_tensor_model_parallel_region
|
||||
from .random import get_cuda_rng_tracker
|
||||
from .utils import divide
|
||||
from .utils import split_tensor_along_last_dim
|
||||
from .utils import VocabUtility
|
||||
from codegeex.megatron import get_args
|
||||
import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing
|
||||
|
||||
|
||||
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
|
||||
"tensor_model_parallel": False,
|
||||
"partition_dim": -1,
|
||||
"partition_stride": 1,
|
||||
}
|
||||
|
||||
|
||||
def param_is_not_tensor_parallel_duplicate(param):
|
||||
return (
|
||||
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
|
||||
) or (get_tensor_model_parallel_rank() == 0)
|
||||
|
||||
|
||||
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
|
||||
# Make sure the attributes are not set.
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
assert not hasattr(tensor, attribute)
|
||||
# Set the attributes.
|
||||
setattr(tensor, "tensor_model_parallel", is_parallel)
|
||||
setattr(tensor, "partition_dim", dim)
|
||||
setattr(tensor, "partition_stride", stride)
|
||||
|
||||
|
||||
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
|
||||
def maybe_set(attribute, value):
|
||||
if not hasattr(tensor, attribute):
|
||||
setattr(tensor, attribute, value)
|
||||
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
|
||||
|
||||
|
||||
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
||||
def maybe_copy(attribute):
|
||||
if hasattr(source_tensor, attribute):
|
||||
setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
|
||||
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
maybe_copy(attribute)
|
||||
|
||||
|
||||
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
|
||||
"""Initialize affine weight for model parallel on GPU."""
|
||||
|
||||
set_tensor_model_parallel_attributes(
|
||||
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
|
||||
)
|
||||
|
||||
if ds_checkpointing.is_configured():
|
||||
global get_cuda_rng_tracker
|
||||
get_cuda_rng_tracker = ds_checkpointing.get_cuda_rng_tracker
|
||||
|
||||
with get_cuda_rng_tracker().fork():
|
||||
init_method(weight)
|
||||
|
||||
|
||||
def _initialize_affine_weight_cpu(
|
||||
weight,
|
||||
output_size,
|
||||
input_size,
|
||||
per_partition_size,
|
||||
partition_dim,
|
||||
init_method,
|
||||
stride=1,
|
||||
return_master_weight=False,
|
||||
):
|
||||
"""Initialize affine weight for model parallel.
|
||||
|
||||
Build the master weight on all processes and scatter
|
||||
the relevant chunk."""
|
||||
|
||||
set_tensor_model_parallel_attributes(
|
||||
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
|
||||
)
|
||||
|
||||
# Initialize master weight
|
||||
master_weight = torch.empty(
|
||||
output_size, input_size, dtype=torch.float, requires_grad=False
|
||||
)
|
||||
init_method(master_weight)
|
||||
args = get_args()
|
||||
master_weight = master_weight.to(dtype=args.params_dtype)
|
||||
|
||||
# Split and copy
|
||||
per_partition_per_stride_size = divide(per_partition_size, stride)
|
||||
weight_list = torch.split(
|
||||
master_weight, per_partition_per_stride_size, dim=partition_dim
|
||||
)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
my_weight_list = weight_list[rank::world_size]
|
||||
|
||||
with torch.no_grad():
|
||||
torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
||||
if return_master_weight:
|
||||
return master_weight
|
||||
return None
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
This is mainly adapted from torch.nn.Embedding and all the default
|
||||
values are kept.
|
||||
Arguments:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
init_method: method to initialize weights.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings, embedding_dim, init_method=init.xavier_normal_):
|
||||
super(VocabParallelEmbedding, self).__init__()
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
# Set the defaults for compatibility.
|
||||
self.padding_idx = None
|
||||
self.max_norm = None
|
||||
self.norm_type = 2.0
|
||||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
(
|
||||
self.vocab_start_index,
|
||||
self.vocab_end_index,
|
||||
) = VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings,
|
||||
get_tensor_model_parallel_rank(),
|
||||
self.tensor_model_parallel_size,
|
||||
)
|
||||
self.num_embeddings_per_partition = (
|
||||
self.vocab_end_index - self.vocab_start_index
|
||||
)
|
||||
|
||||
# Allocate weights and initialize.
|
||||
args = get_args()
|
||||
if args.use_cpu_initialization:
|
||||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
self.num_embeddings_per_partition,
|
||||
self.embedding_dim,
|
||||
dtype=args.params_dtype,
|
||||
# dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
_initialize_affine_weight_cpu(
|
||||
self.weight,
|
||||
self.num_embeddings,
|
||||
self.embedding_dim,
|
||||
self.num_embeddings_per_partition,
|
||||
0,
|
||||
init_method,
|
||||
)
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
self.num_embeddings_per_partition,
|
||||
self.embedding_dim,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=args.params_dtype,
|
||||
# dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
_initialize_affine_weight_gpu(
|
||||
self.weight, init_method, partition_dim=0, stride=1
|
||||
)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = (input_ < self.vocab_start_index) | (
|
||||
input_ >= self.vocab_end_index
|
||||
)
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(
|
||||
masked_input,
|
||||
self.weight,
|
||||
self.padding_idx,
|
||||
self.max_norm,
|
||||
self.norm_type,
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
||||
# Mask the output embedding.
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
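# Usage sketch (a hedged example, not from the original file): with a
# 50304-token vocabulary and a tensor model parallel size of 4, each rank holds
# a 12576-row shard of the table. Token ids outside this rank's
# [vocab_start_index, vocab_end_index) window are clamped to 0 for the local
# lookup, their rows are zeroed afterwards, and the all-reduce sums the partial
# results so every rank ends up with the full embeddings:
#
#     emb = VocabParallelEmbedding(50304, 2048)  # inside an initialized group
#     out = emb(token_ids)                       # [..., 2048] on every rank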
class ColumnParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
init_method: method to initialize weights. Note that bias is always set
|
||||
to zero.
|
||||
stride: For the strided linear layers.
|
||||
keep_master_weight_for_test: This was added for testing and should be
|
||||
set to False. It returns the master weights
|
||||
used for initialization.
|
||||
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. We skip
|
||||
adding bias but instead return it.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
output_size,
|
||||
bias=True,
|
||||
gather_output=True,
|
||||
init_method=init.xavier_normal_,
|
||||
stride=1,
|
||||
keep_master_weight_for_test=False,
|
||||
skip_bias_add=False,
|
||||
):
|
||||
super(ColumnParallelLinear, self).__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
args = get_args()
|
||||
if args.use_cpu_initialization:
|
||||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
self.output_size_per_partition,
|
||||
self.input_size,
|
||||
dtype=args.params_dtype,
|
||||
)
|
||||
)
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight,
|
||||
self.output_size,
|
||||
self.input_size,
|
||||
self.output_size_per_partition,
|
||||
0,
|
||||
init_method,
|
||||
stride=stride,
|
||||
return_master_weight=keep_master_weight_for_test,
|
||||
)
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
self.output_size_per_partition,
|
||||
self.input_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=args.params_dtype,
|
||||
)
|
||||
)
|
||||
_initialize_affine_weight_gpu(
|
||||
self.weight, init_method, partition_dim=0, stride=stride
|
||||
)
|
||||
|
||||
if bias:
|
||||
if args.use_cpu_initialization:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition, dtype=args.params_dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = Parameter(
|
||||
torch.empty(
|
||||
self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=args.params_dtype,
|
||||
)
|
||||
)
|
||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, input_):
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = copy_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
output_parallel = F.linear(input_parallel, self.weight, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
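# Sketch of the column split with illustrative numbers: for A of shape
# [1024, 4096] and a tensor model parallel size of 4, each rank stores a
# [1024, 1024] column block A_i and computes Y_i = X A_i. gather_output=True
# concatenates the Y_i into the full [..., 4096] output; gather_output=False
# keeps the shards local, the usual choice when a RowParallelLinear follows.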
class RowParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
init_method: method to initialize weights. Note that bias is always set
|
||||
to zero.
|
||||
stride: For the strided linear layers.
|
||||
keep_master_weight_for_test: This was added for testing and should be
|
||||
set to False. It returns the master weights
|
||||
used for initialization.
|
||||
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. We skip
|
||||
adding bias but instead return it.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
output_size,
|
||||
bias=True,
|
||||
input_is_parallel=False,
|
||||
init_method=init.xavier_normal_,
|
||||
stride=1,
|
||||
keep_master_weight_for_test=False,
|
||||
skip_bias_add=False,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
args = get_args()
|
||||
if args.use_cpu_initialization:
|
||||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
self.output_size,
|
||||
self.input_size_per_partition,
|
||||
dtype=args.params_dtype,
|
||||
)
|
||||
)
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight,
|
||||
self.output_size,
|
||||
self.input_size,
|
||||
self.input_size_per_partition,
|
||||
1,
|
||||
init_method,
|
||||
stride=stride,
|
||||
return_master_weight=keep_master_weight_for_test,
|
||||
)
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
self.output_size,
|
||||
self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=args.params_dtype,
|
||||
)
|
||||
)
|
||||
_initialize_affine_weight_gpu(
|
||||
self.weight, init_method, partition_dim=1, stride=stride
|
||||
)
|
||||
if bias:
|
||||
if args.use_cpu_initialization:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=args.params_dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = Parameter(
|
||||
torch.empty(
|
||||
self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=args.params_dtype,
|
||||
)
|
||||
)
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, input_):
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = F.linear(input_parallel, self.weight)
|
||||
# All-reduce across all the partitions.
|
||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
|
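# Note (a sketch of the standard pairing used in Megatron-style MLP blocks):
# a ColumnParallelLinear with gather_output=False feeding a RowParallelLinear
# with input_is_parallel=True keeps the intermediate activations sharded, so
# the only communication per block is the single all-reduce in
# RowParallelLinear.forward:
#
#     h, _ = ColumnParallelLinear(d, 4 * d, gather_output=False)(x)
#     y, b = RowParallelLinear(4 * d, d, input_is_parallel=True)(F.gelu(h))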
@ -0,0 +1,164 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from .initialize import (
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
)
|
||||
from .utils import split_tensor_along_last_dim
|
||||
|
||||
|
||||
def _reduce(input_):
|
||||
"""All-reduce the the input tensor across model parallel group."""
|
||||
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if get_tensor_model_parallel_world_size() == 1:
|
||||
return input_
|
||||
|
||||
# All-reduce.
|
||||
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
|
||||
|
||||
return input_
|
||||
|
||||
|
||||
def _split(input_):
|
||||
"""Split the tensor along its last dimension and keep the
|
||||
corresponding slice."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along last dimension.
|
||||
input_list = split_tensor_along_last_dim(input_, world_size)
|
||||
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
output = input_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather(input_):
|
||||
"""Gather tensors and concatinate along the last dimension."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Size and dimension.
|
||||
last_dim = input_.dim() - 1
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
torch.distributed.all_gather(
|
||||
tensor_list, input_, group=get_tensor_model_parallel_group()
|
||||
)
|
||||
|
||||
# Note: torch.cat already creates a contiguous tensor.
|
||||
output = torch.cat(tensor_list, dim=last_dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _CopyToModelParallelRegion(torch.autograd.Function):
|
||||
"""Pass the input to the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _reduce(grad_output)
|
||||
|
||||
|
||||
class _ReduceFromModelParallelRegion(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
|
||||
class _ScatterToModelParallelRegion(torch.autograd.Function):
|
||||
"""Split the input and keep only the corresponding chuck to the rank."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _split(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather(grad_output)
|
||||
|
||||
|
||||
class _GatherFromModelParallelRegion(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatinate."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _gather(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _gather(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output)
|
||||
|
||||
|
||||
# -----------------
|
||||
# Helper functions.
|
||||
# -----------------
|
||||
|
||||
|
||||
def copy_to_tensor_model_parallel_region(input_):
|
||||
return _CopyToModelParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def reduce_from_tensor_model_parallel_region(input_):
|
||||
return _ReduceFromModelParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def scatter_to_tensor_model_parallel_region(input_):
|
||||
return _ScatterToModelParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def gather_from_tensor_model_parallel_region(input_):
|
||||
return _GatherFromModelParallelRegion.apply(input_)
|
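# Summary of the forward/backward duality implemented by the four autograd
# functions above:
#     copy_to_...     forward: identity      backward: all-reduce
#     reduce_from_... forward: all-reduce    backward: identity
#     scatter_to_...  forward: split         backward: gather
#     gather_from_... forward: gather        backward: split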
@ -0,0 +1,342 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.cuda import _lazy_call, device as device_ctx_manager
|
||||
from torch.utils.checkpoint import detach_variable
|
||||
|
||||
from codegeex.megatron import get_args
|
||||
from codegeex.megatron.memory import allocate_mem_buff
|
||||
|
||||
from .initialize import get_data_parallel_rank
|
||||
from .initialize import get_tensor_model_parallel_group
|
||||
from .initialize import get_tensor_model_parallel_rank
|
||||
from .initialize import get_tensor_model_parallel_world_size
|
||||
|
||||
|
||||
# Default name for the model parallel rng tracker.
|
||||
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
|
||||
|
||||
|
||||
# Whether to apply model parallelism to checkpointed hidden states.
|
||||
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
|
||||
|
||||
|
||||
def init_checkpointed_activations_memory_buffer():
|
||||
"""Initializ the memory buffer for the checkpointed activations."""
|
||||
args = get_args()
|
||||
|
||||
per_layer = (
|
||||
args.micro_batch_size
|
||||
* args.max_position_embeddings
|
||||
* args.hidden_size
|
||||
// args.tensor_model_parallel_size
|
||||
)
|
||||
assert (
|
||||
args.num_layers % args.checkpoint_num_layers == 0
|
||||
), "number of layers is not divisible by checkpoint-num-layers"
|
||||
num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
|
||||
numel = per_layer * num_checkpointer_layers
|
||||
dtype = torch.half
|
||||
if not args.fp16:
|
||||
dtype = torch.float
|
||||
|
||||
global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
|
||||
assert (
|
||||
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None
|
||||
), "checkpointed activations memory buffer is already allocated."
|
||||
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
|
||||
"checkpointed activations", numel, dtype, track_usage=False
|
||||
)
|
||||
|
||||
|
||||
def reset_checkpointed_activations_memory_buffer():
|
||||
"""Reset the memory used for checkpointing."""
|
||||
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
|
||||
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
|
||||
|
||||
|
||||
def _set_cuda_rng_state(new_state, device=-1):
|
||||
"""Sets the random number generator state of the current GPU.
|
||||
|
||||
Arguments:
|
||||
new_state (torch.ByteTensor): The desired state
|
||||
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
|
||||
with a single change: the input state is not cloned. Cloning caused
|
||||
major performance issues in cases with 4 or more GPUs.
|
||||
"""
|
||||
if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState):
|
||||
# older PyTorch
|
||||
def cb():
|
||||
with device_ctx_manager(device):
|
||||
_C._cuda_setRNGState(new_state)
|
||||
|
||||
else:
|
||||
# newer PyTorch
|
||||
if device == -1:
|
||||
device = torch.device("cuda")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device("cuda", device)
|
||||
|
||||
def cb():
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = torch.cuda.current_device()
|
||||
default_generator = torch.cuda.default_generators[idx]
|
||||
default_generator.set_state(new_state)
|
||||
|
||||
_lazy_call(cb)
|
||||
|
||||
|
||||
def split_tensor_into_1d_equal_chunks(tensor):
|
||||
"""Break a tensor into equal 1D chunks."""
|
||||
data = tensor.view(-1)
|
||||
partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
|
||||
start_index = partition_size * get_tensor_model_parallel_rank()
|
||||
end_index = start_index + partition_size
|
||||
return data[start_index:end_index]
|
||||
|
||||
|
||||
def gather_split_1d_tensor(tensor):
|
||||
"""Opposite of above function, gather values from model parallel ranks."""
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
numel = torch.numel(tensor)
|
||||
numel_gathered = world_size * numel
|
||||
gathered = torch.empty(
|
||||
numel_gathered,
|
||||
dtype=tensor.dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
requires_grad=False,
|
||||
)
|
||||
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
|
||||
torch.distributed.all_gather(
|
||||
chunks, tensor, group=get_tensor_model_parallel_group()
|
||||
)
|
||||
return gathered
|
||||
|
||||
|
||||
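# Example (illustrative): with a tensor model parallel world size of 4, a
# 16-element tensor is flattened and each rank keeps its own 4-element slice;
# gather_split_1d_tensor reassembles the flat tensor, and the caller restores
# the original shape with .view(...), as CheckpointFunction does below.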
class CudaRNGStatesTracker:
|
||||
"""Tracker for the cuda RNG states.
|
||||
|
||||
Using the `add` method, a cuda rng state is initialized based on
|
||||
the input `seed` and is assigned to `name`. Later, by forking the
|
||||
rng state, we can perform operations and return to our starting
|
||||
cuda state.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Map from a string name to the cuda rng state.
|
||||
self.states_ = {}
|
||||
# Seeds are just for bookkeeping and to ensure no seed is set twice.
|
||||
self.seeds_ = set()
|
||||
|
||||
def reset(self):
|
||||
"""Set to the initial state (no tracker)."""
|
||||
self.states_ = {}
|
||||
self.seeds_ = set()
|
||||
|
||||
def get_states(self):
|
||||
"""Get rng states. Copy the dictionary so we have direct
|
||||
pointers to the states, not just a pointer to the dictionary."""
|
||||
states = {}
|
||||
for name in self.states_:
|
||||
states[name] = self.states_[name]
|
||||
return states
|
||||
|
||||
def set_states(self, states):
|
||||
"""Set the rng states. For efficiency purposes, we do not check
|
||||
the size of seed for compatibility."""
|
||||
self.states_ = states
|
||||
|
||||
def add(self, name, seed):
|
||||
"""Track the rng state."""
|
||||
# Check seed is not already used.
|
||||
if seed in self.seeds_:
|
||||
raise Exception("seed {} already exists".format(seed))
|
||||
self.seeds_.add(seed)
|
||||
# Check that state is not already defined.
|
||||
if name in self.states_:
|
||||
raise Exception("cuda rng state {} already exists".format(name))
|
||||
# Get the current rng state.
|
||||
orig_rng_state = torch.cuda.get_rng_state()
|
||||
# Set the new state and store it.
|
||||
torch.cuda.manual_seed(seed)
|
||||
self.states_[name] = torch.cuda.get_rng_state()
|
||||
# Reset rng state to what it was.
|
||||
_set_cuda_rng_state(orig_rng_state)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
|
||||
"""Fork the cuda rng state, perform operations, and exit with
|
||||
the original state."""
|
||||
# Check if we have added the state
|
||||
if name not in self.states_:
|
||||
print(name, self.states_)
|
||||
raise Exception("cuda rng state {} is not added".format(name))
|
||||
# Store current rng state.
|
||||
orig_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
# Set rng state to the desired one
|
||||
_set_cuda_rng_state(self.states_[name])
|
||||
# Do the stuff we wanted to do.
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Update the current rng state for later use.
|
||||
self.states_[name] = torch.cuda.get_rng_state()
|
||||
# And set the state to the original state we started with.
|
||||
_set_cuda_rng_state(orig_cuda_rng_state)
|
||||
|
||||
|
||||
# RNG tracker object.
|
||||
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
|
||||
|
||||
|
||||
def get_cuda_rng_tracker():
|
||||
"""Get cuda rng tracker."""
|
||||
return _CUDA_RNG_STATE_TRACKER
|
||||
|
||||
|
||||
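# Usage sketch (illustrative, not part of the original file):
#
#     model_parallel_cuda_manual_seed(1234)              # defined below
#     with get_cuda_rng_tracker().fork():
#         x = torch.nn.functional.dropout(x, 0.1)        # tensor-model-parallel RNG
#     # execution resumes here on the default (data parallel) RNG state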
def model_parallel_cuda_manual_seed(seed):
|
||||
"""Initialize model parallel cuda seed.
|
||||
|
||||
This function should be called after the model parallel is
|
||||
initialized. Also, no torch.cuda.manual_seed should be called
|
||||
after this function. Basically, this is a replacement for that
|
||||
function.
|
||||
Two sets of RNG states are tracked:
|
||||
default state: This is for data parallelism and is the same among a
|
||||
set of model parallel GPUs but different across
|
||||
different model parallel groups. This is used for
|
||||
example for dropout in the non-tensor-model-parallel regions.
|
||||
tensor-model-parallel state: This state is different among a set of model
|
||||
parallel GPUs, but the same across data parallel
|
||||
groups. This is used for example for dropout in
|
||||
model parallel regions.
|
||||
"""
|
||||
# 2718 is just for fun and any POSITIVE value will work.
|
||||
offset = seed + 2718
|
||||
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
|
||||
# Data parallel gets the original seed.
|
||||
data_parallel_seed = seed
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(
|
||||
"> initializing model parallel cuda seeds on global rank {}, "
|
||||
"model parallel rank {}, and data parallel rank {} with "
|
||||
"model parallel seed: {} and data parallel seed: {}".format(
|
||||
torch.distributed.get_rank(),
|
||||
get_tensor_model_parallel_rank(),
|
||||
get_data_parallel_rank(),
|
||||
tensor_model_parallel_seed,
|
||||
data_parallel_seed,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
_CUDA_RNG_STATE_TRACKER.reset()
|
||||
# Set the default state.
|
||||
torch.cuda.manual_seed(data_parallel_seed)
|
||||
# and model parallel state.
|
||||
_CUDA_RNG_STATE_TRACKER.add(
|
||||
_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed
|
||||
)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
"""This function is adapted from torch.utils.checkpoint with
|
||||
two main changes:
|
||||
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
|
||||
2) the states in the model parallel tracker are also properly
|
||||
tracked/set/reset.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, *args):
|
||||
ctx.run_function = run_function
|
||||
|
||||
# Copy the rng states.
|
||||
ctx.fwd_cpu_rng_state = torch.get_rng_state()
|
||||
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = run_function(*args)
|
||||
|
||||
# Divide hidden states across model parallel group and only keep
|
||||
# the chunk corresponding to the current rank.
|
||||
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
|
||||
ctx.input_0_shape = args[0].data.shape
|
||||
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
|
||||
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
|
||||
|
||||
# Store everything.
|
||||
ctx.save_for_backward(*args)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError(
|
||||
"Checkpointing is not compatible with .grad(), "
|
||||
"please use .backward() if possible"
|
||||
)
|
||||
inputs = ctx.saved_tensors
|
||||
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
|
||||
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
|
||||
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
|
||||
|
||||
# Store the current states.
|
||||
bwd_cpu_rng_state = torch.get_rng_state()
|
||||
bwd_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
|
||||
|
||||
# Set the states to what they were before the forward pass.
|
||||
torch.set_rng_state(ctx.fwd_cpu_rng_state)
|
||||
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
|
||||
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
|
||||
|
||||
# Compute the forward pass.
|
||||
detached_inputs = detach_variable(inputs)
|
||||
with torch.enable_grad():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
||||
# Set the states back to what they were at the start of this function.
|
||||
torch.set_rng_state(bwd_cpu_rng_state)
|
||||
_set_cuda_rng_state(bwd_cuda_rng_state)
|
||||
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
torch.autograd.backward(outputs, args)
|
||||
grads = tuple(
|
||||
inp.grad if isinstance(inp, torch.Tensor) else inp
|
||||
for inp in detached_inputs
|
||||
)
|
||||
return (None,) + grads
|
||||
|
||||
|
||||
def checkpoint(function, *args):
|
||||
"""Checkpoint a model or part of the model.
|
||||
This has been directly copied from torch.utils.checkpoint."""
|
||||
return CheckpointFunction.apply(function, *args)
|
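# Usage sketch (a hedged example): recompute a block in the backward pass
# instead of storing its activations,
#
#     def custom_forward(hidden_states):
#         return transformer_block(hidden_states)  # hypothetical module
#
#     hidden_states = checkpoint(custom_forward, hidden_states)
#
# The RNG bookkeeping above makes dropout patterns match between the original
# and the recomputed forward pass.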
@ -0,0 +1,72 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
numerator, denominator
|
||||
)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
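# Example: divide(4096, 8) returns 512; divide(4096, 5) raises an
# AssertionError reading "4096 is not divisible by 5".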
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
|
||||
"""Split a tensor along its last dimension.
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
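# Example: a tensor of shape [4, 6] split with num_partitions=3 yields three
# views of shape [4, 2]; pass contiguous_split_chunks=True when downstream ops
# need contiguous memory.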
class VocabUtility:
|
||||
"""Split the vocabulary into `world_size` chunks amd return the
|
||||
first and last index of the vocabulary belonging to the `rank`
|
||||
partition: Note that indecies in [fist, last)"""
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size, rank, world_size
|
||||
):
|
||||
index_f = rank * per_partition_vocab_size
|
||||
index_l = index_f + per_partition_vocab_size
|
||||
return index_f, index_l
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return VocabUtility.vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size, rank, world_size
|
||||
)
|
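# Example: VocabUtility.vocab_range_from_global_vocab_size(50304, rank=1,
# world_size=4) returns (12576, 25152), i.e. rank 1 owns vocabulary indices
# [12576, 25152).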
@ -0,0 +1,296 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import reduce
|
||||
import operator
|
||||
import torch
|
||||
|
||||
from codegeex.megatron import get_args
|
||||
from codegeex.megatron import mpu
|
||||
|
||||
|
||||
def _communicate(
|
||||
tensor_send_next, tensor_send_prev, recv_prev, recv_next, use_ring_exchange=False
|
||||
):
|
||||
"""Communicate tensors between stages. Used as helper method in other
|
||||
communication methods that are used in megatron/schedules.py.
|
||||
|
||||
Takes the following arguments:
|
||||
tensor_send_next: tensor to send to next rank (no tensor sent if
|
||||
set to None).
|
||||
tensor_send_prev: tensor to send to prev rank (no tensor sent if
|
||||
set to None).
|
||||
recv_prev: boolean for whether tensor should be received from
|
||||
previous rank.
|
||||
recv_next: boolean for whether tensor should be received from
|
||||
next rank.
|
||||
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
|
||||
API should be used.
|
||||
|
||||
Returns:
|
||||
(tensor_recv_prev, tensor_recv_next)
|
||||
"""
|
||||
args = get_args()
|
||||
|
||||
# Create placeholder tensors for receive in forward and backward directions
|
||||
# if needed.
|
||||
tensor_recv_prev = None
|
||||
tensor_recv_next = None
|
||||
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
|
||||
if args.scatter_gather_tensors_in_pipeline:
|
||||
tensor_chunk_shape = (
|
||||
reduce(operator.mul, tensor_shape, 1)
|
||||
// mpu.get_tensor_model_parallel_world_size()
|
||||
)
|
||||
else:
|
||||
tensor_chunk_shape = tensor_shape
|
||||
dtype = args.params_dtype
|
||||
if args.fp32_residual_connection:
|
||||
dtype = torch.float
|
||||
if recv_prev:
|
||||
tensor_recv_prev = torch.empty(
|
||||
tensor_chunk_shape,
|
||||
requires_grad=True,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
if recv_next:
|
||||
tensor_recv_next = torch.empty(
|
||||
tensor_chunk_shape,
|
||||
requires_grad=True,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Split tensor into smaller chunks if using scatter-gather optimization.
|
||||
if args.scatter_gather_tensors_in_pipeline:
|
||||
if tensor_send_next is not None:
|
||||
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
|
||||
|
||||
if tensor_send_prev is not None:
|
||||
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
|
||||
|
||||
# Send tensors in both the forward and backward directions as appropriate.
|
||||
if use_ring_exchange:
|
||||
torch.distributed.ring_exchange(
|
||||
tensor_send_prev=tensor_send_prev,
|
||||
tensor_recv_prev=tensor_recv_prev,
|
||||
tensor_send_next=tensor_send_next,
|
||||
tensor_recv_next=tensor_recv_next,
|
||||
group=mpu.get_pipeline_model_parallel_group(),
|
||||
)
|
||||
else:
|
||||
ops = []
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend,
|
||||
tensor_send_prev,
|
||||
mpu.get_pipeline_model_parallel_prev_rank(),
|
||||
)
|
||||
ops.append(send_prev_op)
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv,
|
||||
tensor_recv_prev,
|
||||
mpu.get_pipeline_model_parallel_prev_rank(),
|
||||
)
|
||||
ops.append(recv_prev_op)
|
||||
if tensor_send_next is not None:
|
||||
send_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend,
|
||||
tensor_send_next,
|
||||
mpu.get_pipeline_model_parallel_next_rank(),
|
||||
)
|
||||
ops.append(send_next_op)
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv,
|
||||
tensor_recv_next,
|
||||
mpu.get_pipeline_model_parallel_next_rank(),
|
||||
)
|
||||
ops.append(recv_next_op)
|
||||
if len(ops) > 0:
|
||||
reqs = torch.distributed.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against a race condition when using batch_isend_irecv().
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# If using scatter-gather optimization, gather smaller chunks.
|
||||
if args.scatter_gather_tensors_in_pipeline:
|
||||
if recv_prev:
|
||||
tensor_recv_prev = (
|
||||
mpu.gather_split_1d_tensor(tensor_recv_prev)
|
||||
.view(tensor_shape)
|
||||
.requires_grad_()
|
||||
)
|
||||
|
||||
if recv_next:
|
||||
tensor_recv_next = (
|
||||
mpu.gather_split_1d_tensor(tensor_recv_next)
|
||||
.view(tensor_shape)
|
||||
.requires_grad_()
|
||||
)
|
||||
|
||||
return tensor_recv_prev, tensor_recv_next
|
||||
|
||||
|
||||
def recv_forward(timers=None):
|
||||
"""Receive tensor from previous rank in pipeline (forward receive)."""
|
||||
if mpu.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
if timers is not None:
|
||||
timers("forward-recv").start()
|
||||
input_tensor, _ = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=True,
|
||||
recv_next=False,
|
||||
)
|
||||
if timers is not None:
|
||||
timers("forward-recv").stop()
|
||||
return input_tensor
|
||||
|
||||
|
||||
def recv_backward(timers=None):
|
||||
"""Receive tensor from next rank in pipeline (backward receive)."""
|
||||
if mpu.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
if timers is not None:
|
||||
timers("backward-recv").start()
|
||||
_, output_tensor_grad = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=True,
|
||||
)
|
||||
    if timers is not None:
        timers("backward-recv").stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers("forward-send").start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
        )
        if timers is not None:
            timers("forward-send").stop()


def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers("backward-send").start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
        )
        if timers is not None:
            timers("backward-send").stop()


def send_forward_recv_backward(output_tensor, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers("forward-send-backward-recv").start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
        )
        if timers is not None:
            timers("forward-send-backward-recv").stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers("backward-send-forward-recv").start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
        )
        if timers is not None:
            timers("backward-send-forward-recv").stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers("forward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
    )
    if timers is not None:
        timers("forward-send-forward-recv").stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers("backward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
    )
    if timers is not None:
        timers("backward-send-backward-recv").stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
    output_tensor, input_tensor_grad, recv_prev, recv_next, timers=None
):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
    )
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").stop()
    return input_tensor, output_tensor_grad
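For orientation, a minimal sketch of how these point-to-point helpers compose on a middle pipeline stage in a plain (non-interleaved) schedule. `recv_forward` is the matching receive helper defined earlier in this file; `forward_step` and `backward_step` are hypothetical stand-ins for the schedule code that would call them:

def run_one_microbatch(model, timers=None):
    # Receive activations for one microbatch from the previous stage.
    input_tensor = recv_forward(timers=timers)
    output_tensor = forward_step(model, input_tensor)  # hypothetical
    # Ship activations forward and wait for the matching gradient.
    output_tensor_grad = send_forward_recv_backward(output_tensor, timers=timers)
    input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)  # hypothetical
    # Propagate gradients to the previous stage.
    send_backward(input_tensor_grad, timers=timers)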
@ -0,0 +1,17 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .tokenizer import build_tokenizer
@ -0,0 +1,375 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tokenization classes for OpenAI GPT."""

from __future__ import absolute_import, division, print_function, unicode_literals

import sys
import json
import logging
import os
import regex as re
from io import open

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE
    # tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func


logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    "gpt2": 1024,
}
VOCAB_NAME = "vocab.json"
MERGES_NAME = "merges.txt"
SPECIAL_TOKENS_NAME = "special_tokens.txt"


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))
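A quick illustration of the table this builds (easy to verify by running the function): printable ASCII bytes map to themselves, while bytes the BPE code would choke on are remapped to dedicated printable characters.

table = bytes_to_unicode()
assert table[ord("A")] == "A"       # printable bytes are unchanged
assert table[ord(" ")] == "\u0120"  # space (0x20) is remapped to 'Ġ'
assert len(table) == 256            # every byte value gets a character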
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
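A tiny sanity check of the pair extraction (illustrative only, not part of the original file):

assert get_pairs(("l", "o", "w")) == {("l", "o"), ("o", "w")}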
class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
    """

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs
    ):
        """
        Instantiate a GPT2Tokenizer from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
            special_tokens_file = os.path.join(
                pretrained_model_name_or_path, SPECIAL_TOKENS_NAME
            )
            if not os.path.exists(special_tokens_file):
                special_tokens_file = None
            else:
                logger.info(
                    "loading special tokens file {}".format(special_tokens_file)
                )
        # redirect to the cache, if necessary
        try:
            from .file_utils import cached_path

            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ", ".join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file,
                    merges_file,
                )
            )
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info(
                "loading vocabulary file {} from cache at {}".format(
                    vocab_file, resolved_vocab_file
                )
            )
            logger.info(
                "loading merges file {} from cache at {}".format(
                    merges_file, resolved_merges_file
                )
            )
        if (
            pretrained_model_name_or_path
            in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP
        ):
            # if we're using a pretrained model, ensure the tokenizer won't index
            # sequences longer than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[
                pretrained_model_name_or_path
            ]
            kwargs["max_len"] = min(kwargs.get("max_len", int(1e12)), max_len)
        # Instantiate tokenizer.
        if special_tokens_file and "special_tokens" not in kwargs:
            special_tokens = (
                open(special_tokens_file, encoding="utf-8").read().split("\n")[:-1]
            )
        else:
            special_tokens = kwargs.pop("special_tokens", [])
        tokenizer = cls(
            resolved_vocab_file,
            resolved_merges_file,
            special_tokens=special_tokens,
            *inputs,
            **kwargs
        )
        return tokenizer
    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        special_tokens=None,
        max_len=None,
    ):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding="utf-8").read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)

    def __len__(self):
        return len(self.encoder) + len(self.special_tokens)
    def set_special_tokens(self, special_tokens):
        """Add a list of additional tokens to the encoder.
        The additional tokens are indexed starting from the last index of the
        current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict(
            (tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)
        )
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # `first` does not occur in the rest of the word
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word
    def tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            if sys.version_info[0] == 2:
                token = "".join(self.byte_encoder[ord(b)] for b in token)
            else:
                token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        if isinstance(tokens, str) or (
            sys.version_info[0] == 2 and isinstance(tokens, unicode)
        ):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum"
                " sequence length for this OpenAI GPT model ({} > {}). Running this"
                " sequence through the model will result in indexing errors".format(
                    len(ids), self.max_len
                )
            )
        return ids

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids in BPE tokens using the vocab."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens

    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", errors=self.errors
        )
        return text
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(vocab_path):
            logger.error(
                "Vocabulary path ({}) should be a directory".format(vocab_path)
            )
            return
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)
        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(
                self.bpe_ranks.items(), key=lambda kv: kv[1]
            ):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!".format(
                            merge_file
                        )
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        index = len(self.encoder)
        with open(special_tokens_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(
                self.special_tokens.items(), key=lambda kv: kv[1]
            ):
                if index != token_index:
                    logger.warning(
                        "Saving special tokens vocabulary to {}: BPE indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!".format(
                            special_tokens_file
                        )
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1

        return vocab_file, merge_file, special_tokens_file
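A hedged usage sketch of the class above; the vocab paths are placeholders for real GPT-2 vocab/merges files. For ordinary ASCII text, encode followed by decode should round-trip:

tok = GPT2Tokenizer("vocab.json", "merges.txt")  # placeholder local files
ids = tok.encode("def add(a, b):\n    return a + b")
assert tok.decode(ids) == "def add(a, b):\n    return a + b"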
@ -0,0 +1,280 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CodeGeeX tokenizers."""

from abc import ABC
from abc import abstractmethod

from .gpt2_tokenization import GPT2Tokenizer
from transformers import AutoTokenizer
def encode_whitespaces(text, start_extra_id: int, max_len: int):
    """Encode whitespaces to extra tokens in GPT-J.

    >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
    'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
    """

    def push_acc_space(acc_len: int, text: str):
        if acc_len == 0:
            return text
        if acc_len == 1:
            return text + " "
        assert (
            acc_len <= max_len
        ), f"Max whitespace run length {max_len}, but found {acc_len}"
        extra_id = start_extra_id - 2 + acc_len
        extra_token = f"<|extratoken_{extra_id}|>"
        return text + extra_token

    acc_len = 0
    res = ""
    for ch in text:
        if ch == " ":
            acc_len += 1
            if acc_len == max_len:
                res = push_acc_space(acc_len, res)
                acc_len = 0
        else:
            res = push_acc_space(acc_len, res)
            acc_len = 0
            res = res + ch

    res = push_acc_space(acc_len, res)

    return res
def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
    """Decode the whitespace-encoded strings produced by encode_whitespaces.

    >>> text = 'a\\n  b\\n   c'
    >>> s, l = 10, 10
    >>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
    True
    """
    for l in range(2, max_len + 1):
        token_id = start_extra_id - 2 + l
        token = f"<|extratoken_{token_id}|>"
        text = text.replace(token, " " * l)
    return text
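A worked example of the id arithmetic: with start_extra_id=10, a run of k >= 2 spaces becomes <|extratoken_{10 + k - 2}|>, so a 4-space Python indent maps to a single <|extratoken_12|>:

assert encode_whitespaces("a    b", 10, 10) == "a<|extratoken_12|>b"
assert decode_whitespaces("a<|extratoken_12|>b", 10, 10) == "a    b"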
def build_hgf_tokenizer(args):
    """Initialize a Hugging Face tokenizer."""
    tokenizer_path = args.tokenizer_path
    if args.rank == 0:
        print(f"> building huggingface tokenizer from {tokenizer_path} ...", flush=True)
    assert tokenizer_path is not None, "Tokenizer path must be provided."

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if args.rank == 0:
        print(f" > eos_token = {tokenizer.eos_token}", flush=True)

    ws_start_id = args.ws_encoding_start_id if "ws_encoding_start_id" in args else None
    ws_len = args.ws_encoding_length if "ws_encoding_length" in args else None

    return HgfTokenizerWrapper(tokenizer, ws_start=ws_start_id, ws_len=ws_len)


def build_tokenizer(args):
    """Initialize tokenizer."""
    if "tokenizer_path" in args and args.tokenizer_path is not None:
        # build huggingface tokenizer
        tokenizer = build_hgf_tokenizer(args)
        args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
        return tokenizer

    if args.rank == 0:
        print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True)

    # Select and instantiate the tokenizer.
    assert args.vocab_file is not None
    if args.tokenizer_type == "GPT2BPETokenizer":
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    else:
        raise NotImplementedError(
            "{} tokenizer is not implemented.".format(args.tokenizer_type)
        )

    # Add vocab size.
    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)

    return tokenizer
def _vocab_size_with_padding(orig_vocab_size, args):
    """Pad the vocab size so that it is divisible by the model-parallel size
    and still has a GPU-friendly size."""

    after = orig_vocab_size
    multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size
    while (after % multiple) != 0:
        after += 1
    if args.rank == 0:
        print(
            " > padded vocab (size: {}) with {} dummy tokens "
            "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
            flush=True,
        )
    return after
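A worked example (using the CodeGeeX-13B setting of --make-vocab-size-divisible-by 52224 with tensor-parallel size 1): a 50,257-token GPT-2 vocab is padded up to 52,224.

class _Args:  # minimal stand-in for the Megatron args namespace
    make_vocab_size_divisible_by = 52224
    tensor_model_parallel_size = 1
    rank = 1  # non-zero rank keeps the example from printing

assert _vocab_size_with_padding(50257, _Args()) == 52224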
class AbstractTokenizer(ABC):
    """Abstract class for tokenizer."""

    def __init__(self, name):
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError(
            "detokenizer is not implemented for {} tokenizer".format(self.name)
        )

    @property
    def cls(self):
        raise NotImplementedError(
            "CLS is not provided for {} tokenizer".format(self.name)
        )

    @property
    def sep(self):
        raise NotImplementedError(
            "SEP is not provided for {} tokenizer".format(self.name)
        )

    @property
    def pad(self):
        raise NotImplementedError(
            "PAD is not provided for {} tokenizer".format(self.name)
        )

    @property
    def eod(self):
        raise NotImplementedError(
            "EOD is not provided for {} tokenizer".format(self.name)
        )

    @property
    def mask(self):
        raise NotImplementedError(
            "MASK is not provided for {} tokenizer".format(self.name)
        )
class _GPT2BPETokenizer(AbstractTokenizer):
    """Original GPT2 BPE tokenizer."""

    def __init__(self, vocab_file, merge_file):
        name = "GPT2 BPE"
        super().__init__(name)

        self.tokenizer = GPT2Tokenizer(
            vocab_file, merge_file, errors="replace", special_tokens=[], max_len=None
        )
        self.eod_id = self.tokenizer.encoder["<|endoftext|>"]

    @property
    def vocab_size(self):
        return len(self.tokenizer.encoder)

    @property
    def vocab(self):
        return self.tokenizer.encoder

    @property
    def inv_vocab(self):
        return self.tokenizer.decoder

    def tokenize(self, text):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.eod_id
class HgfTokenizerWrapper(AbstractTokenizer):
    """Wrapper for Hugging Face tokenizer."""

    def __init__(
        self,
        tokenizer,
        ws_start: int = None,
        ws_len: int = None,
    ):
        super(HgfTokenizerWrapper, self).__init__(tokenizer.__class__.__name__)
        self.tokenizer = tokenizer
        self.ws_start = ws_start
        self.ws_len = ws_len

    def tokenize(self, text):
        if self.ws_start:
            text = encode_whitespaces(text, self.ws_start, self.ws_len)
        input_ids = self.tokenizer(text, is_split_into_words=False).input_ids

        return input_ids

    def detokenize(self, token_ids):
        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
        if self.ws_start:
            text = decode_whitespaces(text, self.ws_start, self.ws_len)
        return text

    @property
    def eod(self):
        return self.tokenizer.eos_token_id

    @property
    def inv_vocab(self):
        # Return the id-to-token mapping promised by AbstractTokenizer
        # (the original returned len(...) here, which is a size, not a dict).
        return {v: k for k, v in self.tokenizer.vocab.items()}

    @property
    def vocab(self):
        return self.tokenizer.vocab

    @property
    def vocab_size(self):
        return len(self.vocab)
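A hedged usage sketch of the wrapper; the checkpoint name is illustrative (GPT-J is assumed here because its vocab ships the <|extratoken_*|> ids used by the whitespace encoding):

from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")  # illustrative
wrapper = HgfTokenizerWrapper(hf_tok, ws_start=10, ws_len=10)
ids = wrapper.tokenize("def f():\n    return 1")
print(wrapper.detokenize(ids))  # whitespace runs survive the round trip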
@ -0,0 +1,219 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""

import sys

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from codegeex.megatron import get_args
from codegeex.megatron import print_rank_0
from codegeex.megatron import get_adlr_autoresume
from codegeex.megatron import mpu
from codegeex.megatron.model.module import param_is_not_shared
from codegeex.megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def unwrap_model(model, module_instances=(torchDDP,)):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model
def calc_params_l2_norm(model):
    """Calculate the l2 norm of parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False,  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        norm_2, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group()
    )
    return norm_2.item() ** 0.5


def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group()
    )

    return averaged_losses
def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + " memory (MB)"
    string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
    string += " | max allocated: {}".format(
        torch.cuda.max_memory_allocated() / mega_bytes
    )
    string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
    string += " | max reserved: {}".format(
        torch.cuda.max_memory_reserved() / mega_bytes
    )
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n"
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group["params"]:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format(
                iteration, rank, index, int(param.tensor_model_parallel)
            )
            string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm)
    print(string, flush=True)
def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from codegeex.megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)
def get_ltor_masks_and_position_ids(
    data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss
):
    """Build masks and position ids for a left-to-right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(
        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
    ).view(att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, : (i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5

    return attention_mask, loss_mask, position_ids
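A small worked example of the reset logic (eod_token=0 is arbitrary): a 1 x 6 batch with an EOD at position 2 restarts position ids after it and zeroes the loss on the EOD token.

import torch

data = torch.tensor([[5, 6, 0, 7, 8, 9]])
attn, loss_mask, pos = get_ltor_masks_and_position_ids(
    data, eod_token=0, reset_position_ids=True,
    reset_attention_mask=True, eod_mask_loss=True,
)
print(pos)        # tensor([[0, 1, 2, 0, 1, 2]])
print(loss_mask)  # tensor([[1., 1., 0., 1., 1., 1.]])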
def get_parameters_in_billions(model):
    gpus_per_model = torch.distributed.get_world_size(
        group=mpu.get_model_parallel_group()
    )

    approx_parameters_in_billions = sum(
        [
            sum(
                [
                    p.ds_numel if hasattr(p, "ds_id") else p.nelement()
                    for p in model_module.parameters()
                ]
            )
            for model_module in model
        ]
    )

    return approx_parameters_in_billions * gpus_per_model / (1e9)
@ -0,0 +1,16 @@
# CodeGeeX-13B configuration

CHECKPOINT_PATH="<path where you put the checkpoint (e.g., XXX/codegeex_13b.pt)>"

MODEL_ARGS="--num-layers 39 \
            --hidden-size 5120 \
            --num-attention-heads 40 \
            --max-position-embeddings 2048 \
            --attention-softmax-in-fp32 \
            --load "$CHECKPOINT_PATH" \
            --layernorm-epsilon 1e-5 \
            --fp16 \
            --ws-encoding-start-id 10 \
            --ws-encoding-length 10 \
            --make-vocab-size-divisible-by 52224 \
            --seq-length 2048"
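As a sanity check on the model name, a back-of-the-envelope parameter count for these settings (a rough Python estimate; it ignores CodeGeeX's extra top query layer, biases, and layernorms):

hidden, layers, vocab, seq = 5120, 39, 52224, 2048
per_layer = 12 * hidden ** 2              # attention (~4h^2) + MLP (~8h^2)
total = layers * per_layer + vocab * hidden + seq * hidden
print(f"{total / 1e9:.1f}B parameters")   # ~12.5B, consistent with "13B"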
After Width: | Height: | Size: 482 KiB |
After Width: | Height: | Size: 282 KiB |
After Width: | Height: | Size: 234 KiB |
After Width: | Height: | Size: 394 KiB |
After Width: | Height: | Size: 431 KiB |
Before Width: | Height: | Size: 203 KiB |
Before Width: | Height: | Size: 1.1 MiB After Width: | Height: | Size: 846 KiB |
Before Width: | Height: | Size: 73 KiB After Width: | Height: | Size: 45 KiB |
Before Width: | Height: | Size: 195 KiB |
Before Width: | Height: | Size: 1.1 MiB After Width: | Height: | Size: 713 KiB |
@ -0,0 +1,28 @@
# This script is used to convert a MindSpore checkpoint to the Megatron format.

NPY_CKPT_PATH=$1   # Path to MindSpore-exported weights in .npy format.
SAVE_CKPT_PATH=$2  # Path to save the output .pt checkpoint.
GPU=$3

SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"

# import model configuration (provides MODEL_ARGS used below)
source "$MAIN_DIR/configs/codegeex_13b.sh"

# export CUDA settings
if [ -z "$GPU" ]; then
    GPU=0
fi

export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU

CMD="python $MAIN_DIR/codegeex/megatron/mindspore_to_megatron.py \
        --npy-ckpt-path $NPY_CKPT_PATH \
        --save-ckpt-path $SAVE_CKPT_PATH \
        --tokenizer-path $TOKENIZER_PATH \
        $MODEL_ARGS"

echo "$CMD"
eval "$CMD"
@ -0,0 +1,95 @@
# This script is used to generate solutions of HumanEval-X.

LANGUAGE=$1     # Target programming language, currently one of ["python", "java", "cpp", "js", "go"]
OUTPUT_PATH=$2  # Output path of the generated programs.
HOSTLIST=$3     # Provide a hostfile if generating on multiple nodes.

SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"

# export CUDA settings
export CUDA_HOME=/usr/local/cuda-11.1/

# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"

# nccl options
OPTIONS_NCCL="export NCCL_DEBUG=warn; export NCCL_IB_DISABLE=0; export NCCL_IB_GID_INDEX=3"
OPTIONS_PATH="export PATH=$PATH; export LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
CWD=$(pwd)

# set master ip for zmq server
if [ -z "$HOSTLIST" ]; then
    ZMQ_ADDR=$(hostname -i)
    echo "$ZMQ_ADDR" > "./hostfile"
    HOSTLIST="./hostfile"
else
    ZMQ_ADDR=$(head -n 1 "$HOSTLIST")
fi
echo "master_ip: $ZMQ_ADDR"

NUM_SAMPLES=1
MICRO_BSZ=1
WORLD_SIZE=1
TEMP=0.8
TOPP=0.95
SEED=42
DATASET=humaneval
TODAY=$(date +%y%m%d)
CHANNEL_PORT=$(expr $RANDOM + 5000)
MASTER_PORT=$(expr $RANDOM + 8000)

# save log file
LOG_DIR=$MAIN_DIR/log
mkdir -p "$LOG_DIR"
LOG_PATH="$LOG_DIR/$TODAY-generation.log"

if [ -z "$LANGUAGE" ]; then
    LANGUAGE=python
fi

if [ -z "$INPUT_PATH" ]; then
    INPUT_PATH=$MAIN_DIR/codegeex/benchmark/humaneval-x/$LANGUAGE/data/humaneval_$LANGUAGE.jsonl.gz
fi

if [ -z "$OUTPUT_PATH" ]; then
    OUTPUT_PATH=$MAIN_DIR/codegeex/benchmark/output/humaneval-x/codegeex/
    mkdir -p "$OUTPUT_PATH"
fi

JOB_ID=codegeex-ns$NUM_SAMPLES-t$TEMP-topp$TOPP-seed$SEED-$LANGUAGE

RUN_CMD="python \
        $MAIN_DIR/codegeex/benchmark/humaneval-x/generate_humaneval_x.py \
        --hostfile $HOSTLIST \
        --channel-ip $ZMQ_ADDR \
        --channel-port $CHANNEL_PORT \
        --master-port $MASTER_PORT \
        --tokenizer-path $TOKENIZER_PATH \
        --load-deepspeed \
        --temperature $TEMP \
        --top-p $TOPP \
        --out-seq-length 1024 \
        --micro-batch-size $MICRO_BSZ \
        --samples-per-problem $NUM_SAMPLES \
        --language-type $LANGUAGE \
        --dataset $DATASET \
        --input-path $INPUT_PATH \
        --output-prefix $OUTPUT_PATH/$JOB_ID \
        --gen-node-world-size $WORLD_SIZE \
        --seed $SEED \
        $MODEL_ARGS"

RUN_CMD="$OPTIONS_NCCL; $OPTIONS_PATH; $RUN_CMD"
RUN_CMD="cd $CWD; $RUN_CMD"

if (( WORLD_SIZE != 1 )); then
    RUN_CMD="pdsh -R ssh -w ^$HOSTLIST \"$RUN_CMD\""
fi

echo "$RUN_CMD"
echo "Writing log to $LOG_PATH"
eval "$RUN_CMD" > "$LOG_PATH"
bash $MAIN_DIR/scripts/gather_output.sh $OUTPUT_PATH $JOB_ID 1
@ -0,0 +1,39 @@
# This script is used to test the inference of CodeGeeX.

GPU=$1
PROMPT_FILE=$2

SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"

# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"

# export CUDA settings
if [ -z "$GPU" ]; then
    GPU=0
fi

export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU

if [ -z "$PROMPT_FILE" ]; then
    PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi

# remove --greedy if using sampling
CMD="python $MAIN_DIR/tests/test_inference.py \
        --prompt-file $PROMPT_FILE \
        --tokenizer-path $TOKENIZER_PATH \
        --micro-batch-size 1 \
        --out-seq-length 1024 \
        --temperature 0.8 \
        --top-p 0.95 \
        --top-k 100 \
        --greedy \
        $MODEL_ARGS"

echo "$CMD"
eval "$CMD"
@ -0,0 +1,110 @@
# This script is used to translate solutions of HumanEval-X.

LANG_SRC_TYPE=$1  # Source programming language, currently one of ["python", "java", "cpp", "js", "go"]
LANG_TGT_TYPE=$2  # Target programming language, currently one of ["python", "java", "cpp", "js", "go"]
OUTPUT_PATH=$3    # Output path of the generated programs.
HOSTLIST=$4       # Provide a hostfile if generating on multiple nodes.

SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"

# export CUDA settings
export CUDA_HOME=/usr/local/cuda-11.1/

# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"

# nccl options
OPTIONS_NCCL="export NCCL_DEBUG=warn; export NCCL_IB_DISABLE=0; export NCCL_IB_GID_INDEX=3"
OPTIONS_PATH="export PATH=$PATH; export LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
CWD=$(pwd)

# set master ip for zmq server
if [ -z "$HOSTLIST" ]; then
    ZMQ_ADDR=$(hostname -i)
    echo "$ZMQ_ADDR" > "./hostfile"
    HOSTLIST="./hostfile"
else
    ZMQ_ADDR=$(head -n 1 "$HOSTLIST")
fi
echo "master_ip: $ZMQ_ADDR"

NUM_SAMPLES=1
MICRO_BSZ=1
WORLD_SIZE=1
TEMP=0.8
TOPP=0.95
SEED=42
DATASET=humaneval
TODAY=$(date +%y%m%d)
CHANNEL_PORT=$(expr $RANDOM + 5000)
MASTER_PORT=$(expr $RANDOM + 8000)

# save log file
LOG_DIR=$MAIN_DIR/log
mkdir -p "$LOG_DIR"
LOG_PATH="$LOG_DIR/$TODAY-translation.log"

if [ -z "$LANG_SRC_TYPE" ]; then
    LANG_SRC_TYPE=python
fi

if [ -z "$LANG_TGT_TYPE" ]; then
    LANG_TGT_TYPE=java
fi

if [ -z "$INPUT_SRC_PATH" ]; then
    INPUT_SRC_PATH=$MAIN_DIR/codegeex/benchmark/humaneval-x/$LANG_SRC_TYPE/data/humaneval_$LANG_SRC_TYPE.jsonl.gz
fi

if [ -z "$INPUT_TGT_PATH" ]; then
    INPUT_TGT_PATH=$MAIN_DIR/codegeex/benchmark/humaneval-x/$LANG_TGT_TYPE/data/humaneval_$LANG_TGT_TYPE.jsonl.gz
fi

if [ -z "$OUTPUT_PATH" ]; then
    OUTPUT_PATH=$MAIN_DIR/codegeex/benchmark/output/humaneval-x/codegeex/
    mkdir -p "$OUTPUT_PATH"
fi

JOB_ID=codegeex-ns$NUM_SAMPLES-t$TEMP-topp$TOPP-seed$SEED-$LANG_SRC_TYPE-$LANG_TGT_TYPE

RUN_CMD="python \
        $MAIN_DIR/codegeex/benchmark/humaneval-x/translate_humaneval_x.py \
        --hostfile $HOSTLIST \
        --channel-ip $ZMQ_ADDR \
        --channel-port $CHANNEL_PORT \
        --master-port $MASTER_PORT \
        --tokenizer-path $TOKENIZER_PATH \
        --load-deepspeed \
        --temperature $TEMP \
        --top-p $TOPP \
        --out-seq-length 1024 \
        --micro-batch-size $MICRO_BSZ \
        --samples-per-problem $NUM_SAMPLES \
        --language-src-type $LANG_SRC_TYPE \
        --language-tgt-type $LANG_TGT_TYPE \
        --src-path $INPUT_SRC_PATH \
        --tgt-path $INPUT_TGT_PATH \
        --dataset $DATASET \
        --output-prefix $OUTPUT_PATH/$JOB_ID \
        --gen-node-world-size $WORLD_SIZE \
        --seed $SEED \
        $MODEL_ARGS"

RUN_CMD="$OPTIONS_NCCL; $OPTIONS_PATH; $RUN_CMD"
RUN_CMD="cd $CWD; $RUN_CMD"

if (( WORLD_SIZE != 1 )); then
    RUN_CMD="pdsh -R ssh -w ^$HOSTLIST \"$RUN_CMD\""
fi

echo "$RUN_CMD"
echo "Writing log to $LOG_PATH"
eval "$RUN_CMD" > "$LOG_PATH"
bash $MAIN_DIR/scripts/gather_output.sh $OUTPUT_PATH $JOB_ID 1
@ -0,0 +1,199 @@
import os
import copy
import time
import torch
import random
import numpy as np

from codegeex.megatron import get_tokenizer, get_args
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.code_generation_utils import get_token_stream

torch.set_printoptions(precision=8)


def set_random_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def model_provider():
    """Build the model."""

    model = CodeGeeXModel(num_tokentypes=0,
                          parallel_output=False)

    return model
def add_code_generation_args(parser):
    """Code generation arguments."""
    group = parser.add_argument_group(title="code generation")

    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top p sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=2048,
        help="Size of the output generated text.",
    )
    group.add_argument(
        "--recompute",
        action="store_true",
        help="During generation recompute all attention "
        "instead of using previously computed keys/values.",
    )
    group.add_argument(
        "--ws-encoding-start-id",
        type=int,
        default=10,
        help="Start id for whitespace encoding.",
    )
    group.add_argument(
        "--ws-encoding-length",
        type=int,
        default=80,
        help="Length of whitespace encoding.",
    )
    group.add_argument(
        "--n-generation",
        type=int,
        default=10,
    )
    group.add_argument(
        "--eos-id",
        type=int,
        default=50256,
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        default="./test_prompt.txt",
    )
    group.add_argument(
        "--perf-file",
        type=str,
        default="./perf_out.txt",
    )
    group.add_argument(
        "--perf-trace",
        type=str,
        default="./perf_out.txt",
    )
    group.add_argument(
        "--use-torch-profile",
        action="store_true",
    )
    group.add_argument(
        "--ln-fp32",
        action="store_true",
    )
    group.add_argument(
        "--bad-ids",
        nargs="*",
        type=int,
        default=None,
        help="Ids of tokens that should not be generated.",
    )

    return parser
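A standalone check (outside Megatron's initialize flow, illustrative only) that the extra argument group parses as expected:

import argparse

parser = argparse.ArgumentParser()
parser = add_code_generation_args(parser)
args = parser.parse_args(["--temperature", "0.2", "--top-p", "0.95", "--greedy"])
assert args.greedy and args.temperature == 0.2 and args.top_p == 0.95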
def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))

    initialize_megatron(
        extra_args_provider=add_code_generation_args,
    )

    args = get_args()
    set_random_seed(args.seed)

    print("Loading tokenizer ...")
    tokenizer = get_tokenizer()

    print("Loading state dict ...")
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider()
    model.load_state_dict(state_dict)
    model.eval()
    if args.fp16 and args.ln_fp16:
        model.half()
    model.cuda()

    with open(args.prompt_file, "r") as f:
        prompt = f.readlines()
        prompt = "".join(prompt)

    print("Generating ...")
    t0 = time.perf_counter()
    for prompt in [prompt]:
        tokens = tokenizer.tokenize(prompt)
        print(tokens)
        print("Current prompt:")
        print(prompt)
        n_token_prompt = len(tokens)
        print("N_token_prompt:", n_token_prompt)
        token_stream = get_token_stream(
            model,
            [copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
            micro_batch_size=args.micro_batch_size,
            bad_ids=args.bad_ids,
        )
        is_finished = [False for _ in range(args.micro_batch_size)]
        for i, generated in enumerate(token_stream):
            generated_tokens = generated[0]
            for j in range(args.micro_batch_size):
                if is_finished[j]:
                    continue
                if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
                        generated_tokens[j]) >= args.out_seq_length:
                    is_finished[j] = True
                    generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                    generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                    t1 = time.perf_counter()
                    print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
                    print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
                    print("================================= Generated code:")
                    print(generated_code)
                    t0 = time.perf_counter()
            if all(is_finished):
                break

    print("Generation finished.")


if __name__ == "__main__":
    main()
@ -0,0 +1,15 @@
code translation
Java:
public class Solution {
    public static boolean hasCloseElements(int[] nums, int threshold) {
        for (int i = 0; i < nums.length - 1; i++) {
            for (int j = i + 1; j < nums.length; j++) {
                if (Math.abs(nums[i] - nums[j]) < threshold) {
                    return true;
                }
            }
        }
        return false;
    }
}
Python: