from typing import List

import paddle
import paddle.nn.functional as F


def get_ltor_masks_and_position_ids(
    data,
    eod_token,
    reset_position_ids,
    reset_attention_mask,
):
    """Build masks and position ids for a left-to-right model."""
    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.shape

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = paddle.tril(
        paddle.ones((att_mask_batch, seq_length, seq_length))
    ).reshape([att_mask_batch, 1, seq_length, seq_length])

    # Position ids.
    position_ids = paddle.arange(seq_length, dtype="int64")
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):
            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.shape[0]):
                i = eod_index[j]
                # Prevent attention across the EOD boundary.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, : (i + 1)] = 0
                # Reset positions so each document restarts at 0.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary (True marks positions to mask out).
    attention_mask = attention_mask < 0.5

    return attention_mask, position_ids
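

# Illustrative sketch (added for clarity, values assumed): with
# reset_position_ids=True, positions restart after every EOD token, e.g.
#
#   data         = [[5, 6, 2, 7, 8]]   # 2 is the EOD token
#   position_ids = [[0, 1, 2, 0, 1]]   # counting restarts after the EOD
#
# and with reset_attention_mask=True the lower-triangular mask is further
# zeroed so the last two tokens cannot attend back across the EOD boundary.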


def get_batch(
    context_tokens,
    micro_batch_size,
    eod_token,
    reset_position_ids=False,
    reset_attention_mask=False,
):
    """Generate a batch from context tokens."""
    tokens = context_tokens.reshape([micro_batch_size, -1]).cuda()
    # Get the attention mask and position ids.
    attention_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        eod_token,
        reset_position_ids,
        reset_attention_mask,
    )

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a batch of logits with top-k and/or nucleus (top-p) filtering.

    This function has been mostly taken from the HuggingFace conversational
    AI code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """
    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k.
        indices_to_remove = logits < paddle.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort logits in descending order along the vocabulary axis.
        # Note: paddle.sort returns values only; the matching indices come
        # from paddle.argsort.
        sorted_logits = paddle.sort(logits, descending=True, axis=-1)
        sorted_indices = paddle.argsort(logits, descending=True, axis=-1)
        cumulative_probs = paddle.cumsum(F.softmax(sorted_logits, axis=-1), axis=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        for i in range(sorted_indices.shape[0]):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits
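

# Illustrative sketch (added for clarity, values assumed): top-k keeps only
# the k highest logits per row; top-p keeps the smallest prefix of the sorted
# distribution whose cumulative probability exceeds p.
#
#   logits = paddle.to_tensor([[2.0, 1.0, 0.5, -1.0]])
#   top_k_logits(logits, top_k=2)     # row becomes [2.0, 1.0, -inf, -inf]
#
#   logits = paddle.to_tensor([[2.0, 1.0, 0.5, -1.0]])
#   top_k_logits(logits, top_p=0.9)   # lowest-probability tail becomes -inf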


def pad_batch(batch, pad_id, seq_length):
    """Pad each token list in `batch` with `pad_id` up to `seq_length`.

    Returns the (mutated) batch and the original context lengths.
    """
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < seq_length:
            tokens.extend([pad_id] * (seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
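

# Illustrative sketch (added for clarity): pad_batch mutates the token lists
# in place and reports the original lengths, e.g.
#
#   batch, lens = pad_batch([[11, 12], [13]], pad_id=0, seq_length=4)
#   # batch -> [[11, 12, 0, 0], [13, 0, 0, 0]], lens -> [2, 1]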


def forward_step(
    model,
    tokens,
    seq_length,
    position_ids,
    attention_mask,
    layer_past=None,
    get_key_value=None,
    prompt_length=None,
    context_length=None,
):
    """Run a single forward pass through the model."""
    output_tensor = model(
        tokens,
        position_ids,
        attention_mask,
        layer_past=layer_past,
        get_key_value=get_key_value,
        prompt_length=prompt_length,
        context_length=context_length,
    )

    # With KV caching enabled, the model also returns the updated cache.
    if get_key_value:
        output_tensor, layer_past = output_tensor
        return output_tensor, layer_past

    return output_tensor
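

# Illustrative sketch (the model interface here is assumed from the calls
# above): with get_key_value=True the model returns (logits, layer_past), so
# an incremental decode step can feed only the newest token, e.g.
#
#   logits, layer_past = forward_step(
#       model, tokens[:, -1:], seq_length, position_ids[:, -1:],
#       attention_mask, layer_past=layer_past, get_key_value=True,
#   )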


def get_token_stream(
    model,
    tokenizer,
    seq_length,
    out_seq_length,
    context_tokens,
    return_scores: bool = False,
    prompt_length: int = None,
    micro_batch_size: int = None,
    bad_ids: List = None,
    temperature: float = 1.0,
    topp: float = 1.0,
    topk: int = 0,
    greedy: bool = False,
    recompute: bool = False,
):
    """Yield progressively longer token sequences as generation proceeds."""
    context_tokens, context_lengths = pad_batch(
        context_tokens, tokenizer.eos_token_id, seq_length
    )

    context_tokens_tensor = paddle.to_tensor(context_tokens, dtype="int64")
    context_length_tensor = paddle.to_tensor(context_lengths, dtype="int64")
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(
        context_tokens_tensor,
        micro_batch_size,
        tokenizer.eos_token_id,
    )

    batch_token_iterator = sample_sequence_batch(
        model,
        tokenizer,
        context_tokens_tensor,
        context_length_tensor,
        attention_mask,
        position_ids,
        seq_length=seq_length,
        out_seq_length=out_seq_length,
        return_scores=return_scores,
        prompt_length=prompt_length,
        bad_ids=bad_ids,
        temperature=temperature,
        topp=topp,
        topk=topk,
        greedy=greedy,
        recompute=recompute,
    )

    # Stream the growing sequences one generated position at a time.
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):
    """Elementwise select: take val2 where `boolean` is set, else val1."""
    boolean = boolean.cast(val1.dtype)
    return (1 - boolean) * val1 + boolean * val2
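

# Illustrative sketch (added for clarity): switch blends two tensors with a
# boolean mask, e.g.
#
#   val1 = paddle.to_tensor([1, 2, 3])
#   val2 = paddle.to_tensor([7, 8, 9])
#   mask = paddle.to_tensor([False, True, False])
#   switch(val1, val2, mask)   # -> [1, 8, 3]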


def sample_sequence_batch(
    model,
    tokenizer,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    seq_length,
    out_seq_length,
    maxlen=None,
    return_scores: bool = False,
    prompt_length: int = None,
    bad_ids: List = None,
    temperature: float = 1.0,
    topp: float = 1.0,
    topk: int = 0,
    recompute: bool = False,
    greedy: bool = False,
):
    """Autoregressively sample one token per step for a batch of contexts."""
    model.eval()
    with paddle.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eos_token_id

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.shape[0]
        is_done = paddle.zeros([batch_size]).cast("uint8").cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = seq_length - 1
            if maxlen > (org_context_length + out_seq_length):
                maxlen = org_context_length + out_seq_length

        lengths = paddle.ones([batch_size]).cast("int64").cuda() * maxlen
        if return_scores:
            scores = paddle.zeros([batch_size]).cast("float32").cuda()

        while context_length <= maxlen:
            if recompute:
                # Recompute the full forward pass and read off the logits at
                # the current position.
                logits = model(
                    tokens,
                    position_ids,
                    attention_mask,
                    prompt_length=prompt_length,
                    context_length=context_length,
                )
                logits = logits[:, context_length - 1, :]
            else:
                # Incremental decoding: feed the full context once, then only
                # the newest token together with the cached key/values.
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].reshape(
                        [batch_size, -1])
                    positions2use = position_ids[:, context_length - 1].reshape(
                        [batch_size, -1])
                logits, layer_past = model(
                    tokens2use,
                    positions2use,
                    attention_mask,
                    layer_past=layer_past,
                    get_key_value=True,
                    prompt_length=prompt_length,
                    context_length=context_length,
                )
                logits = logits[:, -1].reshape([batch_size, -1])

            if bad_ids is not None:
                # Suppress banned token ids by pushing their logits far down.
                for bad_id in bad_ids:
                    logits[:, bad_id] = -10000
            if greedy:
                prev = paddle.argmax(logits, axis=-1).reshape([-1])
            else:
                logits = logits.cast("float32")
                if return_scores:
                    orig_log_probs = F.log_softmax(logits, axis=-1)
                logits /= temperature
                logits = top_k_logits(logits, top_k=topk, top_p=topp)
                probs = F.softmax(logits, axis=-1)
                prev = paddle.multinomial(probs, num_samples=1).reshape([-1])

            # A row has "started" generating once we are past its context.
            started = context_lengths <= context_length

            new_tokens = switch(tokens[:, context_length].reshape([-1]), prev, started)

            if not greedy and return_scores:
                indices = prev.reshape([-1, 1])
                # take_along_axis picks each sampled token's log-probability
                # (the torch-style gather(1, indices) of the original code).
                new_scores = paddle.take_along_axis(
                    orig_log_probs, indices, axis=1).reshape([-1])
                # Only accumulate scores for rows that have started and are
                # not yet finished.
                new_scores = new_scores * started.cast("float32")
                new_scores = new_scores * is_done.cast("bool").logical_not().cast("float32")
                scores += new_scores

            tokens[:, context_length] = new_tokens
            done_token = (prev == eos_id).cast("uint8") & started.cast("uint8")
            just_finished = (done_token & ~is_done).cast("bool")
            lengths[just_finished.reshape([-1])] = context_length
            is_done = is_done | done_token
            done = paddle.all(is_done.cast("bool"))

            if return_scores:
                yield tokens, (lengths, scores)
            else:
                yield tokens, lengths

            context_length += 1
            counter += 1
            if done:
                break
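

# Illustrative end-to-end sketch (hypothetical names: `model` and `tokenizer`
# stand in for a CodeGeeX Paddle model and tokenizer loaded elsewhere in the
# repository; the hyperparameters are examples only):
#
#   prompt = tokenizer.encode("def quick_sort(arr):")
#   stream = get_token_stream(
#       model, tokenizer,
#       seq_length=2048, out_seq_length=256,
#       context_tokens=[prompt],
#       micro_batch_size=1,
#       temperature=0.8, topp=0.95, topk=0,
#   )
#   for tokens, lengths in stream:
#       pass                    # tokens grows by one position per iteration
#   completion = tokens[0, :lengths[0]].tolist()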