Merge pull request #65 from BBuf/main

CodeGeeX inference: support the OneFlow backend
Authored by Qinkai 2 years ago, committed by GitHub
commit d980c9d03e

.gitignore

@@ -0,0 +1,2 @@
__pycache__/
codegeex.egg-info/

@@ -0,0 +1 @@
from .codegeex_model import CodeGeeXModel

File diff suppressed because it is too large

@@ -0,0 +1,326 @@
import copy
import json
import os
import time
from typing import *
# OneFlow exposes a PyTorch-compatible API, so it is imported under the `torch` alias.
import oneflow as torch
import oneflow.nn.functional as F
from dataclasses import dataclass
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()
            # Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, position_ids
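# Editor's sketch (not part of the original PR): a tiny sanity check of the mask and
# position-id construction above. It is wrapped in a helper so nothing runs at import
# time; the token values and the EOD id (2) are made up for illustration.
def _demo_ltor_masks_and_position_ids():
    data = torch.tensor([[5, 7, 2, 9, 4]], dtype=torch.long)
    attention_mask, position_ids = get_ltor_masks_and_position_ids(
        data,
        eod_token=2,
        reset_position_ids=True,
        reset_attention_mask=True,
    )
    # attention_mask is True where attention is blocked; position ids restart after the EOD.
    print(attention_mask.shape)  # (1, 1, 5, 5)
    print(position_ids)          # [[0, 1, 2, 0, 1]]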
def get_batch(
context_tokens,
micro_batch_size,
eod_token,
reset_position_ids=False,
reset_attention_mask=False,
):
"""Generate batch from context tokens."""
tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
attention_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_token,
reset_position_ids,
reset_attention_mask,
)
return tokens, attention_mask, position_ids
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Adapted from the Hugging Face conversational AI example:
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
        # Sort logits so cumulative probabilities can be computed.
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right so the first token above the
        # threshold is also kept.
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
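# Editor's sketch (not part of the original PR): illustrates the top-k branch of
# top_k_logits on a hand-made batch of logits; the values are arbitrary.
def _demo_top_k_logits():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_logits(logits.clone(), top_k=2)
    # The two largest logits survive; the rest are replaced by -inf.
    print(filtered)  # [[2.0, 1.0, -inf, -inf]]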
def pad_batch(batch, pad_id, seq_length):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
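# Editor's note (worked example): pad_batch([[1, 2, 3]], pad_id=0, seq_length=5)
# returns ([[1, 2, 3, 0, 0]], [3]) -- each prompt is padded in place to seq_length
# and its original length is recorded.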
def forward_step(
model,
tokens,
seq_length,
position_ids,
attention_mask,
layer_past=None,
get_key_value=None,
prompt_length=None,
context_length=None,
):
# Forward pass through the model.
output_tensor = model(
tokens,
position_ids,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length,
)
    if get_key_value:
        output_tensor, layer_past = output_tensor
        return output_tensor, layer_past
    return output_tensor
def get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
context_tokens,
return_scores: bool = False,
prompt_length: int = None,
micro_batch_size: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
    topk: int = 0,
greedy: bool = False,
recompute: bool = False,
):
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(
context_tokens_tensor,
micro_batch_size,
tokenizer.eos_token_id,
)
batch_token_iterator = sample_sequence_batch(
model,
tokenizer,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
seq_length=seq_length,
out_seq_length=out_seq_length,
return_scores=return_scores,
prompt_length=prompt_length,
bad_ids=bad_ids,
temperature=temperature,
topp=topp,
topk=topk,
greedy=greedy,
recompute=recompute,
)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean):
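    """Return val1 where `boolean` is False and val2 where it is True (elementwise)."""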
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
def sample_sequence_batch(
model,
tokenizer,
context_tokens,
context_lengths,
attention_mask,
position_ids,
seq_length,
out_seq_length,
maxlen=None,
return_scores: bool = False,
prompt_length: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
    topk: int = 0,
recompute: bool = False,
greedy: bool = False,
):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.eos_token_id
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = seq_length - 1
if maxlen > (org_context_length + out_seq_length):
maxlen = org_context_length + out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
if return_scores:
scores = torch.zeros([batch_size]).float().cuda()
while context_length <= (maxlen):
if recompute:
logits = model(tokens,
position_ids,
attention_mask,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, context_length - 1, :]
else:
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if bad_ids is not None:
for bad_id in bad_ids:
logits[:, bad_id] = -10000
if greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
if return_scores:
orig_log_probs = torch.log_softmax(logits, dim=-1)
logits /= temperature
logits = top_k_logits(logits, top_k=topk, top_p=topp)
                probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
if not greedy and return_scores:
indices = prev.view(-1, 1)
new_scores = orig_log_probs.gather(1, indices).view(-1)
new_scores = new_scores * started
new_scores = new_scores * is_done.bool().logical_not()
scores += new_scores
tokens[:, context_length] = new_tokens
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
if return_scores:
yield tokens, (lengths, scores)
else:
yield tokens, lengths
context_length += 1
counter += 1
if done:
break

@@ -0,0 +1,39 @@
# This script tests CodeGeeX inference with the OneFlow backend.
GPU=$1
PROMPT_FILE=$2
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"
# export CUDA settings
if [ -z "$GPU" ]; then
GPU=0
fi
export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU
if [ -z "$PROMPT_FILE" ]; then
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi
# remove --greedy if using sampling
CMD="python $MAIN_DIR/tests/test_inference_oneflow.py \
--prompt-file $PROMPT_FILE \
--tokenizer-path $TOKENIZER_PATH \
--micro-batch-size 1 \
--out-seq-length 1024 \
--temperature 0.8 \
--top-p 0.95 \
--top-k 0 \
--greedy \
$MODEL_ARGS"
echo "$CMD"
eval "$CMD"

@@ -0,0 +1,204 @@
import os
import copy
import time
import oneflow as torch
import random
import argparse
import numpy as np
from codegeex.oneflow.inference import get_token_stream
from codegeex.oneflow import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
def model_provider(args):
"""Build the model."""
model = CodeGeeXModel(
args.hidden_size,
args.num_layers,
args.num_attention_heads,
args.padded_vocab_size,
args.max_position_embeddings
)
return model
def add_code_generation_args(parser):
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--num-layers",
type=int,
default=39,
)
group.add_argument(
"--hidden-size",
type=int,
default=5120,
)
group.add_argument(
"--num-attention-heads",
type=int,
default=40,
)
group.add_argument(
"--padded-vocab-size",
type=int,
default=52224,
)
group.add_argument(
"--max-position-embeddings",
type=int,
default=2048,
)
group.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature.",
)
group.add_argument(
"--greedy",
action="store_true",
default=False,
help="Use greedy sampling.",
)
group.add_argument(
"--top-p",
type=float,
default=0.0,
help="Top p sampling.",
)
group.add_argument(
"--top-k",
type=int,
default=0,
help="Top k sampling.",
)
group.add_argument(
"--out-seq-length",
type=int,
default=2048,
help="Size of the output generated text.",
)
group.add_argument(
"--prompt-file",
type=str,
default="./test_prompt.txt",
)
group.add_argument(
"--tokenizer-path",
type=str,
default="./tokenizer",
)
group.add_argument(
"--load",
type=str,
)
group.add_argument(
"--state-dict-path",
type=str,
)
group.add_argument(
"--micro-batch-size",
type=int,
default=1,
)
group.add_argument(
"--quantize",
action="store_true",
)
return parser
def main():
parser = argparse.ArgumentParser()
parser = add_code_generation_args(parser)
args, _ = parser.parse_known_args()
print("Loading tokenizer ...")
tokenizer = CodeGeeXTokenizer(
tokenizer_path=args.tokenizer_path,
mode="codegeex-13b")
print("Loading state dict ...")
state_dict = torch.load(args.load, map_location="cpu")
state_dict = state_dict["module"]
print("Building CodeGeeX model ...")
model = model_provider(args)
model.load_state_dict(state_dict)
model.eval()
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="torch")
model.cuda()
torch.cuda.synchronize()
with open(args.prompt_file, "r") as f:
prompt = f.readlines()
prompt = "".join(prompt)
times = {}
out_seq_lengths = [args.out_seq_length]
micro_batch_size = args.micro_batch_size
seq_length = args.max_position_embeddings
for out_seq_length in out_seq_lengths:
print(f"Generating with out_seq_len {out_seq_length}...")
times[out_seq_length] = []
        # Only one prompt is used here; the loop makes it easy to extend to several prompts.
        for prompt in [prompt]:
t0 = time.perf_counter()
tokens = tokenizer.encode_code(prompt)
print(tokens)
print("Current prompt:")
print(prompt)
n_token_prompt = len(tokens)
print("N_token_prompt:", n_token_prompt)
token_stream = get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=args.top_k,
topp=args.top_p,
temperature=args.temperature,
greedy=args.greedy,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
generated_token_numpy = generated_tokens[j].numpy()
if generated_token_numpy[-1] == tokenizer.eos_token_id or len(
generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_token_numpy.tolist()
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
generated_code = "".join(generated_code)
t1 = time.perf_counter()
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
times[out_seq_length].append(t1 - t0)
print("================================= Generated code:")
print(generated_code)
if all(is_finished):
break
print(times)
for out_seq_length in times.keys():
print(out_seq_length, np.mean(times[out_seq_length]))
print("Generation finished.")
if __name__ == "__main__":
main()