Add int8 quantization

pull/27/head
Stanislas0 2 years ago
parent 601b3aa6eb
commit a2b6a15321

@ -0,0 +1,99 @@
import pkg_resources
import torch
import ctypes
from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
RESOURCE_PACKAGE_NAME = __name__
class Kernel:
def __init__(self, filename: str, function_names: List[str]):
filename = filename + ".fatbin"
if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
self.filename = filename
self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
self._function_names = function_names
self._cmodule = LazyKernelCModule(self.code)
for name in self._function_names:
setattr(self, name, KernelFunction(self._cmodule, name))
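# Note: the Kernel wrapper above loads a pre-compiled CUDA fatbin shipped as a
# package resource and lazily exposes each listed kernel name as a callable
# attribute (via cpm_kernels' LazyKernelCModule / KernelFunction helpers).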
kernels = Kernel(
"quantization",
[
"int4WeightCompression",
"int4WeightExtractionFloat",
"int4WeightExtractionHalf",
"int8WeightExtractionFloat",
"int8WeightExtractionHalf",
],
)
def compress_int4_weight(weight: torch.Tensor): # (n, m)
with torch.cuda.device(weight.device):
n, m = weight.size(0), weight.size(1)
assert m % 2 == 0
m = m // 2
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
stream = torch.cuda.current_stream()
gridDim = (n, 1, 1)
blockDim = (min(round_up(m, 32), 1024), 1, 1)
kernels.int4WeightCompression(
gridDim,
blockDim,
0,
stream,
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
)
return out
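# Shape contract of compress_int4_weight: `weight` is an (n, m) int8 tensor whose
# entries are assumed to already fit in signed 4 bits, m must be even, and each
# pair of values is packed into a single byte, giving an (n, m // 2) int8 result.
# The exact nibble ordering is defined by the int4WeightCompression CUDA kernel.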
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
if source_bit_width == 8:
func = kernels.int8WeightExtractionHalf
elif source_bit_width == 4:
func = kernels.int4WeightExtractionHalf
else:
assert False, "Unsupported bit-width"
with torch.cuda.device(weight.device):
n, m = weight.size(0), weight.size(1)
out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
stream = torch.cuda.current_stream()
gridDim = (n, 1, 1)
blockDim = (min(round_up(m, 32), 1024), 1, 1)
func(
gridDim,
blockDim,
0,
stream,
[
ctypes.c_void_p(weight.data_ptr()),
ctypes.c_void_p(scale_list.data_ptr()),
ctypes.c_void_p(out.data_ptr()),
ctypes.c_int32(n),
ctypes.c_int32(m),
],
)
return out
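# Dequantization contract of extract_weight_to_half: each quantized row is scaled
# back to fp16 as roughly out[i, j] = weight_q[i, j] * scale_list[i], and the
# result has shape (n, m * 8 // source_bit_width), so packed int4 input expands
# to twice as many columns as the stored representation.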
if __name__ == "__main__":
weight = torch.randn(4, 32).to(torch.int8).cuda()
scale = torch.ones(weight.size(0)).to(torch.half).cuda()
print(weight)
b = compress_int4_weight(weight)
print(b)
a = extract_weight_to_half(b, scale, source_bit_width=4)
print(a)

@ -0,0 +1 @@
from .quantize import quantize

@ -0,0 +1,139 @@
import torch
from torch.nn.parameter import Parameter
from codegeex.kernels import compress_int4_weight, extract_weight_to_half
class W8A16Linear(torch.autograd.Function):
@staticmethod
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
ctx.inp_shape = inp.size()
ctx.weight_shape = quant_w.size()
ctx.weight_bit_width = weight_bit_width
out_features = quant_w.size(0)
inp = inp.contiguous().view(-1, inp.size(-1))
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
output = inp.mm(weight.t())
ctx.save_for_backward(inp, quant_w, scale_w)
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
inp, quant_w, scale_w = ctx.saved_tensors
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
grad_output = grad_output.contiguous().view(-1, weight.size(0))
grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(inp)
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
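# W8A16Linear stores the weight in int8 (optionally bit-packed int4) and
# dequantizes it to fp16 on the fly in both forward and backward, trading a
# little extra compute for a much smaller weight footprint. backward returns
# gradients for the input and for the weight slot (computed against the
# dequantized fp16 weight); the scale and bit width receive None.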
class QuantizedLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bit_width: int,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
*args,
**kwargs
):
super(QuantizedLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight_bit_width = weight_bit_width
if weight is None:
            self.weight = torch.empty(
                out_features, in_features * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(out_features, dtype=kwargs["params_dtype"], device=kwargs["device"])
else:
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
if weight_bit_width == 4:
self.weight = compress_int4_weight(self.weight)
if bias is None:
self.register_parameter('bias', None)
else:
self.bias = bias
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
def forward(self, input_):
# Matrix multiply.
output = W8A16Linear.apply(input_, self.weight, self.weight_scale, self.weight_bit_width)
if self.bias is not None:
output = output + self.bias
return output
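# QuantizedLinear uses per-row symmetric absmax quantization:
#   scale_i   = max_j |W[i, j]| / (2 ** (weight_bit_width - 1) - 1)
#   W_q[i, j] = round(W[i, j] / scale_i)   (stored as int8, bit-packed for int4)
# so the fp16 weight is recovered as W[i, j] ~= W_q[i, j] * scale_i at matmul time.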
def quantize(model, weight_bit_width):
"""Replace fp16 linear with quantized linear"""
for i in range(len(model.language_model.transformer.layers) + 1):
if i == len(model.language_model.transformer.layers):
layer = model.language_model.transformer.topQueryLayer
else:
layer = model.language_model.transformer.layers[i]
layer.attention.query = QuantizedLinear(
            in_features=layer.attention.query.weight.shape[1],
            out_features=layer.attention.query.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.query.weight.device,
)
layer.attention.value = QuantizedLinear(
            in_features=layer.attention.value.weight.shape[1],
            out_features=layer.attention.value.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.value.weight.device,
)
layer.attention.key = QuantizedLinear(
            in_features=layer.attention.key.weight.shape[1],
            out_features=layer.attention.key.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.key.weight.device,
)
layer.attention.dense = QuantizedLinear(
            in_features=layer.attention.dense.weight.shape[1],
            out_features=layer.attention.dense.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.dense.weight.device,
)
layer.mlp.dense_h_to_4h = QuantizedLinear(
            in_features=layer.mlp.dense_h_to_4h.weight.shape[1],
            out_features=layer.mlp.dense_h_to_4h.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_h_to_4h.weight.device,
)
layer.mlp.dense_4h_to_h = QuantizedLinear(
            in_features=layer.mlp.dense_4h_to_h.weight.shape[1],
            out_features=layer.mlp.dense_4h_to_h.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_4h_to_h.weight.device,
)
return model
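# Typical usage, mirroring tests/test_inference_quantized.py added in this commit
# (quantize() rewrites the attention and MLP linears of an already-built fp16
# model in place and returns it):
#
#   model.load_state_dict(state_dict)
#   model.eval()
#   model.half()
#   model = quantize(model, weight_bit_width=8)
#   model.cuda()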

@ -0,0 +1 @@
from .tokenizer import CodeGeeXTokenizer

@ -0,0 +1,87 @@
import torch
from typing import *
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
def encode_whitespaces(text, start_extra_id: int, max_len: int):
""" Encode whitespaces to extra tokens in GPT-J.
    >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
"""
def push_acc_space(acc_len: int, text: str):
if acc_len == 0:
return text
if acc_len == 1:
return text + ' '
        assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
extra_id = start_extra_id - 2 + acc_len
extra_token = f'<|extratoken_{extra_id}|>'
return text + extra_token
acc_len = 0
res = ''
for ch in text:
if ch == ' ':
acc_len += 1
if acc_len == max_len:
res = push_acc_space(acc_len, res)
acc_len = 0
else:
res = push_acc_space(acc_len, res)
acc_len = 0
res = res + ch
res = push_acc_space(acc_len, res)
return res
def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
""" Decode the whitespace-encoded strings produced by encode_whitespace.
    >>> text = 'a\\n  b\\n   c'
>>> s, l = 10, 10
>>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
True
"""
for l in range(2, max_len + 1):
token_id = start_extra_id - 2 + l
token = f'<|extratoken_{token_id}|>'
text = text.replace(token, ' ' * l)
return text
class CodeGeeXTokenizer(object):
def __init__(
self,
tokenizer: GPT2TokenizerFast = None,
tokenizer_path: str = "EleutherAI/gpt-j-6B",
start_extra_id: int = 10,
        max_len: int = 10,
mode='codegeex-13b',
dict_file: str = None,
):
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_path)
if mode not in ['codegeex-13b']:
raise ValueError(f"Invalid mode {mode}, choose from ['codegeex-13b']")
self.start_extra_id = start_extra_id
self.max_len = max_len
self.mode = mode
self.eos_token_id = self.tokenizer.eos_token_id
def encode_code(self, code: str):
if self.mode == 'codegeex-13b':
code = encode_whitespaces(code, self.start_extra_id, self.max_len)
input_ids = self.tokenizer(code, is_split_into_words=False).input_ids
return input_ids
def decode_code(self, input_ids):
if self.mode == 'codegeex-13b':
text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
output_code = decode_whitespaces(text, self.start_extra_id, self.max_len)
return output_code
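# Minimal round-trip sketch (assumes the GPT-J tokenizer can be fetched from
# "EleutherAI/gpt-j-6B"; the snippet only uses names defined in this file):
#
#   tokenizer = CodeGeeXTokenizer(tokenizer_path="EleutherAI/gpt-j-6B", mode="codegeex-13b")
#   ids = tokenizer.encode_code("def add(a, b):\n    return a + b")
#   print(tokenizer.decode_code(ids))  # should print the original snippet back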

@ -0,0 +1 @@
from .codegeex_model import CodeGeeXModel

File diff suppressed because it is too large.

@ -0,0 +1,324 @@
import copy
import json
import os
import time
from typing import *
import torch
import torch.nn.functional as F
from dataclasses import dataclass
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()
            # Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, position_ids
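# After the `< 0.5` conversion the returned attention_mask is boolean, with True
# marking positions that must not be attended to (everything above the causal
# diagonal, plus cross-document positions when reset_attention_mask is set), and
# position_ids restart after each EOD token when reset_position_ids is set.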
def get_batch(
context_tokens,
micro_batch_size,
eod_token,
reset_position_ids=False,
reset_attention_mask=False,
):
"""Generate batch from context tokens."""
tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
attention_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_token,
reset_position_ids,
reset_attention_mask,
)
return tokens, attention_mask, position_ids
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
        # Convert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
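# Example: with top_k=2 only the two largest logits per row survive and the rest
# are set to -inf; with top_p=0.9 the smallest set of tokens whose cumulative
# softmax probability exceeds 0.9 is kept (the top-1 token is always kept).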
def pad_batch(batch, pad_id, seq_length):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
def forward_step(
model,
tokens,
seq_length,
position_ids,
attention_mask,
layer_past=None,
get_key_value=None,
prompt_length=None,
context_length=None,
):
# Forward pass through the model.
output_tensor = model(
tokens,
position_ids,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length,
)
if get_key_value:
output_tensor, layer_past = output_tensor
if get_key_value:
return output_tensor, layer_past
return output_tensor
def get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
context_tokens,
return_scores: bool = False,
prompt_length: int = None,
micro_batch_size: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
    topk: int = 0,
greedy: bool = False,
):
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(
context_tokens_tensor,
micro_batch_size,
tokenizer.eos_token_id,
)
batch_token_iterator = sample_sequence_batch(
model,
tokenizer,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
seq_length=seq_length,
out_seq_length=out_seq_length,
return_scores=return_scores,
prompt_length=prompt_length,
bad_ids=bad_ids,
temperature=temperature,
topp=topp,
topk=topk,
greedy=greedy,
)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
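# switch() is an element-wise select: where `boolean` is true it takes val2 (the
# newly sampled token), otherwise it keeps val1 (the existing prompt token), so
# prompt positions that have not started generating yet are never overwritten.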
def sample_sequence_batch(
model,
tokenizer,
context_tokens,
context_lengths,
attention_mask,
position_ids,
seq_length,
out_seq_length,
maxlen=None,
return_scores: bool = False,
prompt_length: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
    topk: int = 0,
recompute: bool = False,
greedy: bool = False,
):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.eos_token_id
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = seq_length - 1
if maxlen > (org_context_length + out_seq_length):
maxlen = org_context_length + out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
if return_scores:
scores = torch.zeros([batch_size]).float().cuda()
        while context_length <= maxlen:
if recompute:
logits = model(tokens,
position_ids,
attention_mask,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, context_length - 1, :]
else:
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if bad_ids is not None:
for bad_id in bad_ids:
logits[:, bad_id] = -10000
if greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
if return_scores:
orig_log_probs = torch.log_softmax(logits, dim=-1)
logits /= temperature
logits = top_k_logits(logits, top_k=topk, top_p=topp)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
if not greedy and return_scores:
indices = prev.view(-1, 1)
new_scores = orig_log_probs.gather(1, indices).view(-1)
new_scores = new_scores * started
new_scores = new_scores * is_done.bool().logical_not()
scores += new_scores
tokens[:, context_length] = new_tokens
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
if return_scores:
yield tokens, (lengths, scores)
else:
yield tokens, lengths
context_length += 1
counter += 1
if done:
break

@ -0,0 +1,38 @@
# This script is used to test the inference of CodeGeeX.
GPU=$1
PROMPT_FILE=$2
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"
# export CUDA settings
if [ -z "$GPU" ]; then
GPU=0
fi
export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU
if [ -z "$PROMPT_FILE" ]; then
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi
# add --greedy to the command below to use greedy decoding instead of sampling
CMD="python $MAIN_DIR/tests/test_inference_quantized.py \
--prompt-file $PROMPT_FILE \
--tokenizer-path $TOKENIZER_PATH \
--micro-batch-size 1 \
--out-seq-length 1024 \
--temperature 0.2 \
--top-p 0.95 \
--top-k 0 \
$MODEL_ARGS"
echo "$CMD"
eval "$CMD"

@ -0,0 +1,201 @@
import os
import copy
import time
import torch
import random
import argparse
import numpy as np
from codegeex.torch.inference import get_token_stream
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
def model_provider(args):
"""Build the model."""
model = CodeGeeXModel(
args.hidden_size,
args.num_layers,
args.num_attention_heads,
args.padded_vocab_size,
args.max_position_embeddings
)
return model
def add_code_generation_args(parser):
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--num-layers",
type=int,
default=39,
)
group.add_argument(
"--hidden-size",
type=int,
default=5120,
)
group.add_argument(
"--num-attention-heads",
type=int,
default=40,
)
group.add_argument(
"--padded-vocab-size",
type=int,
default=52224,
)
group.add_argument(
"--max-position-embeddings",
type=int,
default=2048,
)
group.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature.",
)
group.add_argument(
"--greedy",
action="store_true",
default=False,
help="Use greedy sampling.",
)
group.add_argument(
"--top-p",
type=float,
default=0.0,
help="Top p sampling.",
)
group.add_argument(
"--top-k",
type=int,
default=0,
help="Top k sampling.",
)
group.add_argument(
"--out-seq-length",
type=int,
default=2048,
help="Size of the output generated text.",
)
group.add_argument(
"--prompt-file",
type=str,
default="./test_prompt.txt",
)
group.add_argument(
"--tokenizer-path",
type=str,
default="./tokenizer",
)
group.add_argument(
"--load",
type=str,
)
group.add_argument(
"--state-dict-path",
type=str,
)
group.add_argument(
"--micro-batch-size",
type=int,
default=1,
)
return parser
def main():
parser = argparse.ArgumentParser()
parser = add_code_generation_args(parser)
args, _ = parser.parse_known_args()
print("Building CodeGeeX model ...")
model = model_provider(args)
print("Loading tokenizer ...")
tokenizer = CodeGeeXTokenizer(
tokenizer_path=args.tokenizer_path,
mode="codegeex-13b")
print("Loading state dict ...")
state_dict = torch.load(args.load, map_location="cpu")
state_dict = state_dict["module"]
print("Building CodeGeeX model ...")
model = model_provider(args)
model.load_state_dict(state_dict)
model.eval()
model.half()
model = quantize(model, weight_bit_width=8)
model.cuda()
with open(args.prompt_file, "r") as f:
prompt = f.readlines()
prompt = "".join(prompt)
times = {}
out_seq_lengths = [args.out_seq_length]
micro_batch_size = args.micro_batch_size
seq_length = args.max_position_embeddings
for out_seq_length in out_seq_lengths:
print(f"Generating with out_seq_len {out_seq_length}...")
times[out_seq_length] = []
for prompt in [prompt]:
t0 = time.perf_counter()
tokens = tokenizer.encode_code(prompt)
print(tokens)
print("Current prompt:")
print(prompt)
n_token_prompt = len(tokens)
print("N_token_prompt:", n_token_prompt)
token_stream = get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=args.top_k,
topp=args.top_p,
temperature=args.temperature,
greedy=args.greedy,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(
generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
generated_code = "".join(generated_code)
t1 = time.perf_counter()
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
times[out_seq_length].append(t1 - t0)
print("================================= Generated code:")
print(generated_code)
if all(is_finished):
break
print(times)
for out_seq_length in times.keys():
print(out_seq_length, np.mean(times[out_seq_length]))
print("Generation finished.")
if __name__ == "__main__":
main()