# mirror of https://github.com/THUDM/CodeGeeX.git
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time
from typing import *

import torch
import torch.nn.functional as F
from dataclasses import dataclass

from codegeex.megatron import get_args
from codegeex.megatron import get_tokenizer
from codegeex.megatron import mpu
from codegeex.megatron.utils import get_ltor_masks_and_position_ids


def get_batch(context_tokens, micro_batch_size=None):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )

    return tokens, attention_mask, position_ids


def get_batch_(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """This function has been mostly taken from huggingface conversational
    ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
    conversational-ai-with-transfer-learning-2d818ac26313"""

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


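# Illustrative example (added here for clarity, not part of the original file):
# a minimal sketch of how `top_k_logits` filters a toy logits tensor. The
# numbers are made up for demonstration; `top_k_logits` modifies its input in
# place, so a clone is passed.
def _example_top_k_logits():
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
    # Keep only the 2 largest logits; the rest become -inf.
    filtered = top_k_logits(logits.clone(), top_k=2)
    # After softmax, only the surviving tokens have non-zero probability.
    return F.softmax(filtered, dim=-1)

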
def generate_samples_input_from_file(model):
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, "sample input file is not provided."
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print(
                "`sample-output-file` not specified, setting "
                "it to {}".format(sample_output_file)
            )
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if (
                mpu.is_pipeline_first_stage()
                and mpu.get_tensor_model_parallel_rank() == 0
            ):
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print(
                            "\nContext length",
                            context_length,
                            "\nPlease give smaller context (half of the "
                            "sequence length)!",
                            flush=True,
                        )
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(
                input_info_tensor, group=mpu.get_model_parallel_group()
            )
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if (
                mpu.get_tensor_model_parallel_rank() == 0
                and args.pipeline_model_parallel_size > 1
            ):
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(
                        context_length, dtype=torch.int64, device=torch.device("cuda")
                    )
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system("clear")
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[
                        raw_text_len:
                    ]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1


# We added this function to support evaluation of tasks such as SQuAD and DROP
# in the https://github.com/EleutherAI/lm-evaluation-harness codebase. The
# lm-evaluation-harness code can now call this function similarly to its
# existing generate call used for GPT-style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    # Generate samples for lm evaluation
    # NEED TO THINK ABOUT eos token

    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            if counter == args.out_seq_length:
                break

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]

    return trim_decode_tokens


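# Illustrative usage sketch (added for clarity, not part of the original file):
# how an evaluation harness might call `generate_samples_eval`. This assumes a
# fully initialized CodeGeeX/Megatron model and tokenizer; the prompt string is
# a made-up example.
def _example_generate_samples_eval(model):
    tokenizer = get_tokenizer()
    prompt = "Question: What is the capital of France?\nAnswer:"
    # Generate at most 32 new tokens, stopping at the end-of-document token.
    return generate_samples_eval(model, prompt, max_gen_length=32, eos_token_id=tokenizer.eod)

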
def generate_samples_interactive_code_contest(model, print_frequency=10):
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if (
                mpu.is_pipeline_first_stage()
                and mpu.get_tensor_model_parallel_rank() == 0
            ):
                # os.system("clear")
                raw_text = []
                input_line = input("\nContext prompt (EOF to exit) >>> ")

                if input_line == ":recompute":
                    args.recompute = True
                    print(f"set recompute: {args.recompute}")
                    continue

                if input_line == ":no-recompute":
                    args.recompute = False
                    print(f"set recompute: {args.recompute}")
                    continue

                while input_line != "EOF":
                    raw_text.append(input_line)
                    input_line = input("\nContext prompt (EOF to exit) >>> ")
                raw_text = "\n".join(raw_text)

                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    # terminate_runs = 1
                    pass
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print(
                            "\nContext length",
                            context_length,
                            "\nPlease give smaller context (half of the "
                            "sequence length)!",
                            flush=True,
                        )
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(
                input_info_tensor, group=mpu.get_model_parallel_group()
            )
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if (
                mpu.get_tensor_model_parallel_rank() == 0
                and args.pipeline_model_parallel_size > 1
            ):
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(
                        context_length, dtype=torch.int64, device=torch.device("cuda")
                    )
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens for _ in range(args.micro_batch_size)])

            for counter, decode_tokens in enumerate(token_stream):
                if (
                    counter % print_frequency != 0
                    or mpu.get_tensor_model_parallel_rank() != 0
                    or not mpu.is_pipeline_first_stage()
                ):
                    continue

                os.system("clear")
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
                print(f"\nMegatron-LM (gen len: {counter}):", trim_decode_tokens, flush=True)

            if (
                mpu.is_pipeline_first_stage()
                and mpu.get_tensor_model_parallel_rank() == 0
            ):
                os.system("clear")
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if (
                mpu.is_pipeline_first_stage()
                and mpu.get_tensor_model_parallel_rank() == 0
            ):
                os.system("clear")
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print("Prompt should not be empty!")
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print(
                            "\nContext length",
                            context_length,
                            "\nPlease give smaller context (half of the "
                            "sequence length)!",
                            flush=True,
                        )
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(
                input_info_tensor, group=mpu.get_model_parallel_group()
            )
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if (
                mpu.get_tensor_model_parallel_rank() == 0
                and args.pipeline_model_parallel_size > 1
            ):
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(
                        context_length, dtype=torch.int64, device=torch.device("cuda")
                    )
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if (
                    counter % print_frequency != 0
                    or mpu.get_tensor_model_parallel_rank() != 0
                    or not mpu.is_pipeline_first_stage()
                ):
                    continue

                os.system("clear")
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if (
                mpu.is_pipeline_first_stage()
                and mpu.get_tensor_model_parallel_rank() == 0
            ):
                os.system("clear")
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1


def generate_samples_unconditional(model):
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod] for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model, copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print(
                    "Avg s/batch:",
                    (time.time() - start_time) / min(args.log_interval, ctr + 1),
                )
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1: length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {"text": text, "length": length - 1, "finished": is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):
    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, "w") as f:
        for datum in generate_samples_unconditional(model):
            if (
                mpu.is_pipeline_last_stage()
                and mpu.get_tensor_model_parallel_rank() == 0
            ):
                f.write(json.dumps(datum) + "\n")


def pad_batch(batch, pad_id, args):
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


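# Illustrative example (added for clarity, not part of the original file):
# `pad_batch` pads each prompt in place to `args.seq_length` with the pad id
# and returns the original lengths. A SimpleNamespace stands in for megatron
# args here.
def _example_pad_batch():
    from types import SimpleNamespace

    toy_args = SimpleNamespace(seq_length=8)
    batch = [[11, 12, 13], [21, 22]]
    padded, lengths = pad_batch(batch, pad_id=0, args=toy_args)
    # padded == [[11, 12, 13, 0, 0, 0, 0, 0], [21, 22, 0, 0, 0, 0, 0, 0]]
    # lengths == [3, 2]
    return padded, lengths

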
def topk_sampling(logits: torch.FloatTensor, num_samples: int):
    """
    Returns the `num_samples` highest-probability tokens and their log-probabilities.

    Args:
        logits: A tensor of shape (batch_size, vocab_size) containing the logits.
        num_samples: The number of top tokens to return.
    """
    log_prob = F.log_softmax(logits, dim=-1)
    topk = torch.topk(log_prob, num_samples, dim=-1)
    topk_tokens = topk.indices
    topk_log_prob = topk.values

    return topk_tokens, topk_log_prob


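# Illustrative example (added for clarity, not part of the original file):
# `topk_sampling` on a toy batch of logits returns the indices of the k most
# likely tokens together with their log-probabilities.
def _example_topk_sampling():
    logits = torch.tensor([[0.1, 2.0, 0.3, 1.5]])
    tokens, log_probs = topk_sampling(logits, num_samples=2)
    # tokens == tensor([[1, 3]]); log_probs holds the matching log-softmax values.
    return tokens, log_probs

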
def nuclear_sampling(logits: torch.FloatTensor, temperature: float, top_p: float = None, top_k: int = None):
    orig_log_probs = F.log_softmax(logits, dim=-1)
    logits /= temperature
    logits = top_k_logits(logits, top_k, top_p)
    log_probs = F.softmax(logits, dim=-1)
    tokens = torch.multinomial(log_probs, num_samples=1).view(-1)

    indices = tokens.view(-1, 1)
    new_scores = orig_log_probs.gather(1, indices).view(-1)

    return tokens, new_scores


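# Illustrative example (added for clarity, not part of the original file):
# drawing one token per row with `nuclear_sampling` from toy logits. The scores
# returned are log-probabilities under the unfiltered distribution.
def _example_nuclear_sampling():
    logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]])
    tokens, scores = nuclear_sampling(logits, temperature=0.8, top_p=0.9, top_k=0)
    return tokens, scores

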
def sample_topk_tokens(model,
                       input_tokens, attention_mask, position_ids,
                       context_length: int, num_samples: int):
    assert context_length < input_tokens.shape[-1], "context_length must be smaller than seq_length"

    model.eval()
    with torch.no_grad():
        output = forward_step(
            model,
            input_tokens,
            position_ids,
            attention_mask,
            tokentype_ids=None,
            forward_method_parallel_output=False,
        )
        assert output is not None
        logits = output[:, context_length - 1, :]

    return topk_sampling(logits, num_samples)


def nuclear_sample_tokens(model,
                          input_tokens, attention_mask, position_ids,
                          context_length: int, temperature: float, top_p: float, top_k: int):
    assert context_length < input_tokens.shape[-1], "context_length must be smaller than seq_length"

    model.eval()
    with torch.no_grad():
        output = forward_step(
            model,
            input_tokens,
            position_ids,
            attention_mask,
            tokentype_ids=None,
            forward_method_parallel_output=False,
        )
        assert output is not None
        logits = output[:, context_length - 1, :]
    return nuclear_sampling(logits, temperature, top_p, top_k)


@dataclass
class Beam:
    tokens: List[int]
    score: float

    def __repr__(self):
        return f"<Beam {repr(get_tokenizer().detokenize(self.tokens))}, score={self.score}>"

    def get_code(self):
        return get_tokenizer().detokenize(self.tokens)


def expand_beams(beams: List[Beam], num_beams: int, model) -> List[Beam]:
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens = [b.tokens.copy() for b in beams]
    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)

    context_lengths = set(context_lengths)
    assert len(context_lengths) == 1, "context_lengths must be the same"
    context_length = list(context_lengths)[0]

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    tokens, attention_mask, position_ids = get_batch_(context_tokens_tensor)
    tokens, scores = sample_topk_tokens(model, tokens, attention_mask, position_ids, context_length, num_beams)
    tokens = tokens.detach().cpu().tolist()
    scores = scores.detach().cpu().tolist()
    assert len(tokens) == len(beams), "output tokens and input beams must have the same length"

    all_beams = []
    for i in range(len(beams)):
        this_tokens = tokens[i]
        this_scores = scores[i]

        for token, score in zip(this_tokens, this_scores):
            all_beams.append(Beam(beams[i].tokens + [token], beams[i].score + score))

    return all_beams


def beam_search(model, context_tokens, num_beams: int):
    """Beam search.

    Note that this function does not support model parallel!
    """
    args = get_args()
    tokenizer = get_tokenizer()

    assert not isinstance(context_tokens[0], list), "batched beam search not supported"

    initial_beam = Beam(context_tokens, 0.0)
    context_len = len(context_tokens)
    org_context_len = context_len
    finished_beams = []

    # first expansion
    beams = expand_beams([initial_beam], num_beams, model)
    context_len += 1

    # print(f"initial beam: {initial_beam}")

    while len(beams) > 0 and context_len < args.seq_length:
        expanded_beams = expand_beams(beams, num_beams, model)
        next_beams = []
        for beam in expanded_beams:
            if args.beam_warmup_length > 0:
                if len(beam.tokens) >= org_context_len + args.beam_warmup_length or beam.tokens[-1] == tokenizer.eod:
                    finished_beams.append(beam)
                else:
                    next_beams.append(beam)
            else:
                if beam.tokens[-1] == tokenizer.eod:
                    finished_beams.append(beam)
                else:
                    next_beams.append(beam)
        # only keep top-k beams
        next_beams.sort(key=lambda b: b.score, reverse=True)
        beams = next_beams[:num_beams]
        context_len += 1

        if len(finished_beams) >= num_beams:
            # first, only keep top-k beams
            finished_beams.sort(key=lambda b: b.score, reverse=True)
            finished_beams = finished_beams[:num_beams]
            return finished_beams  # return finished beams with highest scores
            # stop if all currently expanding beams have a score lower than the minimal score of finished ones
            min_score = min([b.score for b in finished_beams])
            if min_score >= beams[0].score:
                break
            else:
                print(f"we have got enough finished beams, but the minimal score is {min_score}")
                print(f"and the maximum searching score is {beams[0].score}")

    # return top-k finished and unfinished beams
    all_beams = finished_beams + beams
    all_beams.sort(key=lambda b: b.score, reverse=True)

    return all_beams[:num_beams]


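# Illustrative sketch (added for clarity, not part of the original file): the
# expand-and-prune loop that `beam_search` runs, shown with a made-up
# `toy_expand` helper in place of the model-backed `expand_beams`.
def _example_beam_pruning():
    def toy_expand(beam: Beam, num_beams: int) -> List[Beam]:
        # Pretend each beam can be extended by tokens 0..num_beams-1, each with a
        # slightly more negative log-probability.
        return [Beam(beam.tokens + [t], beam.score - 0.1 * (t + 1)) for t in range(num_beams)]

    beams = [Beam(tokens=[0], score=0.0)]
    for _ in range(3):  # three expansion steps
        candidates = [b for beam in beams for b in toy_expand(beam, num_beams=2)]
        candidates.sort(key=lambda b: b.score, reverse=True)
        beams = candidates[:2]  # keep only the top-k beams, as beam_search does
    return beams

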
@dataclass
class Handle:
    tokens: List[int]
    score: float

    def __repr__(self):
        return f"<Handle {repr(get_tokenizer().detokenize(self.tokens))}, score={self.score}>"

    def is_finished(self):
        return len(self.tokens) and self.tokens[-1] == get_tokenizer().eod

    def derived(self, new_token: int, log_prob: float):
        assert not self.is_finished(), "cannot derive from a finished handle"
        return Handle(self.tokens + [new_token], self.score + log_prob)


def expand_handles(handles: List[Handle], temperature: float, top_p: float, top_k: int, model):
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens = [b.tokens.copy() for b in handles]
    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)

    context_lengths = set(context_lengths)
    assert len(context_lengths) == 1, "context_lengths must be the same"
    context_length = list(context_lengths)[0]

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    tokens, attention_mask, position_ids = get_batch_(context_tokens_tensor)
    tokens, scores = nuclear_sample_tokens(model, tokens, attention_mask, position_ids, context_length, temperature,
                                           top_p, top_k)
    tokens = tokens.detach().cpu().tolist()
    scores = scores.detach().cpu().tolist()
    assert len(tokens) == len(handles), "output tokens and input must have the same length"

    all_beams = []
    for i in range(len(handles)):
        this_tokens = tokens[i]
        this_scores = scores[i]

        all_beams.append(handles[i].derived(this_tokens, this_scores))

    return all_beams


def generate_nuclear_sampling(model, context_tokens, num_samples: int, temperature: float, top_p: float, top_k: int):
    """Nucleus (top-p) sampling.

    Note that this function does not support model parallel!
    """
    args = get_args()
    tokenizer = get_tokenizer()

    assert not isinstance(context_tokens[0], list), "batched beam search not supported"

    handles = [Handle(tokens=context_tokens, score=0) for _ in range(num_samples)]
    context_len = len(context_tokens)
    finished_handles = []

    while len(handles) > 0 and context_len < args.seq_length:
        expanded_handles = expand_handles(handles, temperature, top_p, top_k, model)

        new_handles = []
        for h in expanded_handles:
            if h.is_finished():
                finished_handles.append(h)
            else:
                new_handles.append(h)

        context_len += 1
        handles = new_handles

    return handles + finished_handles


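# Illustrative usage sketch (added for clarity, not part of the original file):
# drawing several independent samples for one prompt with
# `generate_nuclear_sampling`. Assumes an initialized CodeGeeX/Megatron model
# and tokenizer and no model parallelism; the prompt is a made-up example.
def _example_generate_nuclear_sampling(model):
    tokenizer = get_tokenizer()
    context_tokens = tokenizer.tokenize("def fibonacci(n):")
    handles = generate_nuclear_sampling(
        model, context_tokens, num_samples=4, temperature=0.8, top_p=0.95, top_k=0
    )
    # Each handle carries the generated token ids and an accumulated log-probability.
    return [tokenizer.detokenize(h.tokens) for h in handles]

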
def forward_step(
    model,
    tokens,
    position_ids,
    attention_mask,
    tokentype_ids,
    layer_past=None,
    get_key_value=None,
    forward_method_parallel_output=None,
    prompt_length=None,
    context_length=None,
):
    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    # Forward pass through the model.
    output_tensor = model(
        tokens,
        position_ids,
        attention_mask,
        tokentype_ids=tokentype_ids,
        layer_past=layer_past,
        get_key_value=get_key_value,
        prompt_length=prompt_length,
        context_length=context_length,
    )

    if get_key_value:
        output_tensor, layer_past = output_tensor

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past

    return output_tensor


def get_token_stream(
    model,
    context_tokens,
    return_scores: bool = False,
    prompt_length: int = None,
    micro_batch_size: int = None,
    bad_ids: List = None,
    temperature: float = None,
    topp: float = None,
    topk: int = None,
    beam_warmup: bool = False,
):
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(
        context_length_tensor,
        mpu.get_tensor_model_parallel_src_rank(),
        group=mpu.get_tensor_model_parallel_group(),
    )
    torch.distributed.broadcast(
        context_tokens_tensor,
        mpu.get_tensor_model_parallel_src_rank(),
        group=mpu.get_tensor_model_parallel_group(),
    )

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, micro_batch_size)

    if beam_warmup:
        batch_token_iterator = sample_sequence_batch_beam(
            model,
            context_tokens_tensor,
            context_length_tensor,
            attention_mask,
            position_ids,
            return_scores=return_scores,
            prompt_length=prompt_length,
            bad_ids=bad_ids,
            temperature=temperature,
            topp=topp,
            topk=topk,
            beam_warmup=True,
        )
    else:
        batch_token_iterator = sample_sequence_batch(
            model,
            context_tokens_tensor,
            context_length_tensor,
            attention_mask,
            position_ids,
            return_scores=return_scores,
            prompt_length=prompt_length,
            bad_ids=bad_ids,
            temperature=temperature,
            topp=topp,
            topk=topk,
        )

    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


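# Illustrative usage sketch (added for clarity, not part of the original file):
# consuming the token stream for a single prompt and detokenizing the final
# result. Assumes an initialized CodeGeeX/Megatron model, args, and tokenizer;
# non-output pipeline stages receive (None, None) and are skipped here.
def _example_get_token_stream(model, prompt: str):
    tokenizer = get_tokenizer()
    context_tokens = tokenizer.tokenize(prompt)
    tokens, lengths = None, None
    for tokens, lengths in get_token_stream(model, [context_tokens]):
        pass  # iterate to completion; each step yields the tokens generated so far
    if tokens is None or lengths is None:
        return None
    generated = tokens[0].cpu().numpy().tolist()[: lengths[0].item()]
    return tokenizer.detokenize(generated)

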
def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


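# Illustrative example (added for clarity, not part of the original file):
# `switch` takes `val2` where the mask is set and `val1` elsewhere, which is how
# sampled tokens only overwrite positions whose prompt has already ended.
def _example_switch():
    val1 = torch.tensor([10, 20, 30])
    val2 = torch.tensor([-1, -2, -3])
    mask = torch.tensor([True, False, True])
    return switch(val1, val2, mask)  # tensor([-1, 20, -3])

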
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    maxlen=None,
    type_ids=None,
    return_scores: bool = False,
    prompt_length: int = None,
    bad_ids: List = None,
    temperature: float = None,
    topp: float = None,
    topk: int = None,
):
    args = get_args()
    tokenizer = get_tokenizer()
    temperature = temperature if temperature is not None else args.temperature
    topp = topp if topp is not None else args.top_p
    topk = topk if topk is not None else args.top_k

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, "eos_id"):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen
        if return_scores:
            scores = torch.zeros([batch_size]).float().cuda()

        while context_length <= (maxlen):

            if args.recompute:
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False,
                               prompt_length=prompt_length,
                               context_length=context_length,
                               )
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False,
                                           prompt_length=prompt_length,
                                           context_length=context_length,
                                           )
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if bad_ids is not None:
                    for bad_id in bad_ids:
                        logits[:, bad_id] = -10000
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    if return_scores:
                        orig_log_probs = torch.log_softmax(logits, dim=-1)
                    logits /= temperature
                    logits = top_k_logits(logits, top_k=topk, top_p=topp)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                if not args.greedy and return_scores:
                    indices = prev.view(-1, 1)
                    new_scores = orig_log_probs.gather(1, indices).view(-1)
                    new_scores = new_scores * started
                    new_scores = new_scores * is_done.bool().logical_not()
                    scores += new_scores

                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                if return_scores:
                    yield tokens, (lengths, scores)
                else:
                    yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break


def sample_sequence_batch_beam(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    maxlen=None,
    type_ids=None,
    return_scores: bool = False,
    prompt_length: int = None,
    bad_ids: List = None,
    temperature: float = None,
    topp: float = None,
    topk: int = None,
    beam_warmup: bool = False,
):
    args = get_args()
    tokenizer = get_tokenizer()
    temperature = temperature if temperature is not None else args.temperature
    topp = topp if topp is not None else args.top_p
    topk = topk if topk is not None else args.top_k

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, "eos_id"):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen
        if return_scores:
            scores = torch.zeros([batch_size]).float().cuda()

        if beam_warmup:
            beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length],
                                num_beams=args.num_beams)
            beam = beams[0]
            tokens_ = beam.tokens
            tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1])
            tokens_warmup = []
            for i in range(batch_size):
                tokens_warmup.append(tokens_.copy())
            tokens, context_lengths = pad_batch(tokens_warmup, tokenizer.eod, args)
            tokens = torch.cuda.LongTensor(tokens)
            context_lengths = torch.cuda.LongTensor(context_lengths)
            context_length = len(tokens_)
            org_context_length = context_length
            if maxlen is None:
                maxlen = args.seq_length - 1
                if maxlen > (org_context_length + args.out_seq_length):
                    maxlen = org_context_length + args.out_seq_length
            lengths = torch.ones([batch_size]).long().cuda() * maxlen
            tokens, attention_mask, position_ids = get_batch(tokens, batch_size)

        while context_length <= (maxlen):
            if args.recompute:
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False,
                               prompt_length=prompt_length,
                               context_length=context_length,
                               )
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False,
                                           prompt_length=prompt_length,
                                           context_length=context_length,
                                           )
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if bad_ids is not None:
                    for bad_id in bad_ids:
                        logits[:, bad_id] = -10000
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    if return_scores:
                        orig_log_probs = torch.log_softmax(logits, dim=-1)
                    logits /= temperature
                    logits = top_k_logits(logits, top_k=topk, top_p=topp)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                if not args.greedy and return_scores:
                    indices = prev.view(-1, 1)
                    new_scores = orig_log_probs.gather(1, indices).view(-1)
                    new_scores = new_scores * started
                    new_scores = new_scores * is_done.bool().logical_not()
                    scores += new_scores

                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                if return_scores:
                    yield tokens, (lengths, scores)
                else:
                    yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break