mirror of https://github.com/THUDM/CodeGeeX.git
Add int8 quantization
parent
601b3aa6eb
commit
a2b6a15321
@ -0,0 +1,99 @@
|
||||
import pkg_resources
|
||||
import torch
|
||||
import ctypes
|
||||
|
||||
from typing import List
|
||||
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
||||
|
||||
RESOURCE_PACKAGE_NAME = __name__
|
||||
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, filename: str, function_names: List[str]):
|
||||
filename = filename + ".fatbin"
|
||||
if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
|
||||
raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
|
||||
self.filename = filename
|
||||
self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
|
||||
self._function_names = function_names
|
||||
self._cmodule = LazyKernelCModule(self.code)
|
||||
|
||||
for name in self._function_names:
|
||||
setattr(self, name, KernelFunction(self._cmodule, name))
|
||||
|
||||
|
||||
kernels = Kernel(
|
||||
"quantization",
|
||||
[
|
||||
"int4WeightCompression",
|
||||
"int4WeightExtractionFloat",
|
||||
"int4WeightExtractionHalf",
|
||||
"int8WeightExtractionFloat",
|
||||
"int8WeightExtractionHalf",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
||||
with torch.cuda.device(weight.device):
|
||||
n, m = weight.size(0), weight.size(1)
|
||||
assert m % 2 == 0
|
||||
m = m // 2
|
||||
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
gridDim = (n, 1, 1)
|
||||
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
||||
|
||||
kernels.int4WeightCompression(
|
||||
gridDim,
|
||||
blockDim,
|
||||
0,
|
||||
stream,
|
||||
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
|
||||
if source_bit_width == 8:
|
||||
func = kernels.int8WeightExtractionHalf
|
||||
elif source_bit_width == 4:
|
||||
func = kernels.int4WeightExtractionHalf
|
||||
else:
|
||||
assert False, "Unsupported bit-width"
|
||||
|
||||
with torch.cuda.device(weight.device):
|
||||
n, m = weight.size(0), weight.size(1)
|
||||
out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
gridDim = (n, 1, 1)
|
||||
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
||||
|
||||
func(
|
||||
gridDim,
|
||||
blockDim,
|
||||
0,
|
||||
stream,
|
||||
[
|
||||
ctypes.c_void_p(weight.data_ptr()),
|
||||
ctypes.c_void_p(scale_list.data_ptr()),
|
||||
ctypes.c_void_p(out.data_ptr()),
|
||||
ctypes.c_int32(n),
|
||||
ctypes.c_int32(m),
|
||||
],
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
weight = torch.randn(4, 32).to(torch.int8).cuda()
|
||||
scale = torch.ones(weight.size(0)).to(torch.half).cuda()
|
||||
|
||||
print(weight)
|
||||
b = compress_int4_weight(weight)
|
||||
print(b)
|
||||
|
||||
a = extract_weight_to_half(b, scale, source_bit_width=4)
|
||||
print(a)
|
Binary file not shown.
@ -0,0 +1 @@
|
||||
from .quantize import quantize
|
@ -0,0 +1,139 @@
|
||||
import torch
|
||||
|
||||
from torch.nn.parameter import Parameter
|
||||
from codegeex.kernels import extract_weight_to_half
|
||||
|
||||
|
||||
class W8A16Linear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
||||
ctx.inp_shape = inp.size()
|
||||
ctx.weight_shape = quant_w.size()
|
||||
ctx.weight_bit_width = weight_bit_width
|
||||
out_features = quant_w.size(0)
|
||||
inp = inp.contiguous().view(-1, inp.size(-1))
|
||||
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
|
||||
output = inp.mm(weight.t())
|
||||
ctx.save_for_backward(inp, quant_w, scale_w)
|
||||
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: torch.Tensor):
|
||||
inp, quant_w, scale_w = ctx.saved_tensors
|
||||
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
|
||||
grad_output = grad_output.contiguous().view(-1, weight.size(0))
|
||||
grad_input = grad_output.mm(weight)
|
||||
grad_weight = grad_output.t().mm(inp)
|
||||
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
|
||||
|
||||
|
||||
class QuantizedLinear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
weight_bit_width: int,
|
||||
weight: torch.Tensor = None,
|
||||
bias: torch.Tensor = None,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super(QuantizedLinear, self).__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight_bit_width = weight_bit_width
|
||||
|
||||
if weight is None:
|
||||
self.weight = torch.empty(
|
||||
shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
|
||||
)
|
||||
self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
|
||||
else:
|
||||
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
||||
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
||||
if weight_bit_width == 4:
|
||||
self.weight = compress_int4_weight(self.weight)
|
||||
|
||||
if bias is None:
|
||||
self.register_parameter('bias', None)
|
||||
else:
|
||||
self.bias = bias
|
||||
|
||||
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
||||
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
|
||||
|
||||
def forward(self, input_):
|
||||
# Matrix multiply.
|
||||
output = W8A16Linear.apply(input_, self.weight, self.weight_scale, self.weight_bit_width)
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def quantize(model, weight_bit_width):
|
||||
"""Replace fp16 linear with quantized linear"""
|
||||
|
||||
for i in range(len(model.language_model.transformer.layers) + 1):
|
||||
if i == len(model.language_model.transformer.layers):
|
||||
layer = model.language_model.transformer.topQueryLayer
|
||||
else:
|
||||
layer = model.language_model.transformer.layers[i]
|
||||
|
||||
layer.attention.query = QuantizedLinear(
|
||||
in_features=layer.attention.query.weight.shape[0],
|
||||
out_features=layer.attention.query.weight.shape[1],
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.query.weight.device,
|
||||
)
|
||||
layer.attention.value = QuantizedLinear(
|
||||
in_features=layer.attention.value.weight.shape[0],
|
||||
out_features=layer.attention.value.weight.shape[1],
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.value.weight.device,
|
||||
)
|
||||
layer.attention.key = QuantizedLinear(
|
||||
in_features=layer.attention.key.weight.shape[0],
|
||||
out_features=layer.attention.key.weight.shape[1],
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.key.weight.device,
|
||||
)
|
||||
layer.attention.dense = QuantizedLinear(
|
||||
in_features=layer.attention.dense.weight.shape[0],
|
||||
out_features=layer.attention.dense.weight.shape[1],
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.dense.weight.device,
|
||||
)
|
||||
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
||||
in_features=layer.mlp.dense_h_to_4h.weight.shape[0],
|
||||
out_features=layer.mlp.dense_h_to_4h.weight.shape[1],
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.mlp.dense_h_to_4h.weight.device,
|
||||
)
|
||||
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
||||
in_features=layer.mlp.dense_4h_to_h.weight.shape[0],
|
||||
out_features=layer.mlp.dense_4h_to_h.weight.shape[1],
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.mlp.dense_4h_to_h.weight.device,
|
||||
)
|
||||
|
||||
return model
|
@ -0,0 +1 @@
|
||||
from .tokenizer import CodeGeeXTokenizer
|
@ -0,0 +1,87 @@
|
||||
import torch
|
||||
from typing import *
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.gpt2 import GPT2TokenizerFast
|
||||
|
||||
|
||||
def encode_whitespaces(text, start_extra_id: int, max_len: int):
|
||||
""" Encode whitespaces to extra tokens in GPT-J.
|
||||
|
||||
>>> encode_whitespaces('a\\n b\\n c', 10, 10)
|
||||
'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
|
||||
"""
|
||||
|
||||
def push_acc_space(acc_len: int, text: str):
|
||||
if acc_len == 0:
|
||||
return text
|
||||
if acc_len == 1:
|
||||
return text + ' '
|
||||
assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
|
||||
extra_id = start_extra_id - 2 + acc_len
|
||||
extra_token = f'<|extratoken_{extra_id}|>'
|
||||
return text + extra_token
|
||||
|
||||
acc_len = 0
|
||||
res = ''
|
||||
for ch in text:
|
||||
if ch == ' ':
|
||||
acc_len += 1
|
||||
if acc_len == max_len:
|
||||
res = push_acc_space(acc_len, res)
|
||||
acc_len = 0
|
||||
else:
|
||||
res = push_acc_space(acc_len, res)
|
||||
acc_len = 0
|
||||
res = res + ch
|
||||
|
||||
res = push_acc_space(acc_len, res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
|
||||
""" Decode the whitespace-encoded strings produced by encode_whitespace.
|
||||
|
||||
>>> text = 'a\\n b\\n c'
|
||||
>>> s, l = 10, 10
|
||||
>>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
|
||||
True
|
||||
"""
|
||||
for l in range(2, max_len + 1):
|
||||
token_id = start_extra_id - 2 + l
|
||||
token = f'<|extratoken_{token_id}|>'
|
||||
text = text.replace(token, ' ' * l)
|
||||
return text
|
||||
|
||||
|
||||
class CodeGeeXTokenizer(object):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: GPT2TokenizerFast = None,
|
||||
tokenizer_path: str = "EleutherAI/gpt-j-6B",
|
||||
start_extra_id: int = 10,
|
||||
max_len : int = 10,
|
||||
mode='codegeex-13b',
|
||||
dict_file: str = None,
|
||||
):
|
||||
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
if mode not in ['codegeex-13b']:
|
||||
raise ValueError(f"Invalid mode {mode}, choose from ['codegeex-13b']")
|
||||
self.start_extra_id = start_extra_id
|
||||
self.max_len = max_len
|
||||
self.mode = mode
|
||||
self.eos_token_id = self.tokenizer.eos_token_id
|
||||
|
||||
def encode_code(self, code: str):
|
||||
if self.mode == 'codegeex-13b':
|
||||
code = encode_whitespaces(code, self.start_extra_id, self.max_len)
|
||||
input_ids = self.tokenizer(code, is_split_into_words=False).input_ids
|
||||
|
||||
return input_ids
|
||||
|
||||
def decode_code(self, input_ids):
|
||||
if self.mode == 'codegeex-13b':
|
||||
text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
|
||||
output_code = decode_whitespaces(text, self.start_extra_id, self.max_len)
|
||||
|
||||
return output_code
|
@ -0,0 +1 @@
|
||||
from .codegeex_model import CodeGeeXModel
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,324 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import *
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def get_ltor_masks_and_position_ids(
|
||||
data,
|
||||
eod_token,
|
||||
reset_position_ids,
|
||||
reset_attention_mask,
|
||||
):
|
||||
"""Build masks and position id for left to right model."""
|
||||
|
||||
# Extract batch size and sequence length.
|
||||
micro_batch_size, seq_length = data.size()
|
||||
|
||||
# Attention mask (lower triangular).
|
||||
if reset_attention_mask:
|
||||
att_mask_batch = micro_batch_size
|
||||
else:
|
||||
att_mask_batch = 1
|
||||
attention_mask = torch.tril(
|
||||
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
|
||||
).view(att_mask_batch, 1, seq_length, seq_length)
|
||||
|
||||
# Position ids.
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(data)
|
||||
# We need to clone as the ids will be modifed based on batch index.
|
||||
if reset_position_ids:
|
||||
position_ids = position_ids.clone()
|
||||
|
||||
if reset_position_ids or reset_attention_mask:
|
||||
# Loop through the batches:
|
||||
for b in range(micro_batch_size):
|
||||
|
||||
# Find indecies where EOD token is.
|
||||
eod_index = position_ids[b, data[b] == eod_token]
|
||||
# Detach indecies from positions if going to modify positions.
|
||||
if reset_position_ids:
|
||||
eod_index = eod_index.clone()
|
||||
|
||||
# Loop through EOD indecies:
|
||||
prev_index = 0
|
||||
for j in range(eod_index.size()[0]):
|
||||
i = eod_index[j]
|
||||
# Mask attention loss.
|
||||
if reset_attention_mask:
|
||||
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
|
||||
# Reset positions.
|
||||
if reset_position_ids:
|
||||
position_ids[b, (i + 1) :] -= i + 1 - prev_index
|
||||
prev_index = i + 1
|
||||
|
||||
# Convert attention mask to binary:
|
||||
attention_mask = attention_mask < 0.5
|
||||
|
||||
return attention_mask, position_ids
|
||||
|
||||
|
||||
def get_batch(
|
||||
context_tokens,
|
||||
micro_batch_size,
|
||||
eod_token,
|
||||
reset_position_ids=False,
|
||||
reset_attention_mask=False,
|
||||
):
|
||||
"""Generate batch from context tokens."""
|
||||
tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
|
||||
# Get the attention mask and postition ids.
|
||||
attention_mask, position_ids = get_ltor_masks_and_position_ids(
|
||||
tokens,
|
||||
eod_token,
|
||||
reset_position_ids,
|
||||
reset_attention_mask,
|
||||
)
|
||||
|
||||
return tokens, attention_mask, position_ids
|
||||
|
||||
|
||||
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
|
||||
"""This function has been mostly taken from huggingface conversational
|
||||
ai code at
|
||||
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
|
||||
conversational-ai-with-transfer-learning-2d818ac26313"""
|
||||
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the
|
||||
# last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p > 0.0:
|
||||
# Cconvert to 1D
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token
|
||||
# above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
for i in range(sorted_indices.size(0)):
|
||||
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
|
||||
logits[i][indices_to_remove] = filter_value
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def pad_batch(batch, pad_id, seq_length):
|
||||
context_lengths = []
|
||||
for tokens in batch:
|
||||
context_length = len(tokens)
|
||||
if context_length < seq_length:
|
||||
tokens.extend([pad_id] * (seq_length - context_length))
|
||||
context_lengths.append(context_length)
|
||||
return batch, context_lengths
|
||||
|
||||
|
||||
def forward_step(
|
||||
model,
|
||||
tokens,
|
||||
seq_length,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
layer_past=None,
|
||||
get_key_value=None,
|
||||
prompt_length=None,
|
||||
context_length=None,
|
||||
):
|
||||
# Forward pass through the model.
|
||||
output_tensor = model(
|
||||
tokens,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
layer_past=layer_past,
|
||||
get_key_value=get_key_value,
|
||||
prompt_length=prompt_length,
|
||||
context_length=context_length,
|
||||
)
|
||||
|
||||
if get_key_value:
|
||||
output_tensor, layer_past = output_tensor
|
||||
|
||||
if get_key_value:
|
||||
return output_tensor, layer_past
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def get_token_stream(
|
||||
model,
|
||||
tokenizer,
|
||||
seq_length,
|
||||
out_seq_length,
|
||||
context_tokens,
|
||||
return_scores: bool = False,
|
||||
prompt_length: int = None,
|
||||
micro_batch_size: int = None,
|
||||
bad_ids: List = None,
|
||||
temperature: float = 1.0,
|
||||
topp: float = 1.0,
|
||||
topk: int = 0.0,
|
||||
greedy: bool = False,
|
||||
):
|
||||
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length)
|
||||
|
||||
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
|
||||
context_length_tensor = torch.cuda.LongTensor(context_lengths)
|
||||
context_length = context_length_tensor.min().item()
|
||||
tokens, attention_mask, position_ids = get_batch(
|
||||
context_tokens_tensor,
|
||||
micro_batch_size,
|
||||
tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
batch_token_iterator = sample_sequence_batch(
|
||||
model,
|
||||
tokenizer,
|
||||
context_tokens_tensor,
|
||||
context_length_tensor,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
seq_length=seq_length,
|
||||
out_seq_length=out_seq_length,
|
||||
return_scores=return_scores,
|
||||
prompt_length=prompt_length,
|
||||
bad_ids=bad_ids,
|
||||
temperature=temperature,
|
||||
topp=topp,
|
||||
topk=topk,
|
||||
greedy=greedy,
|
||||
)
|
||||
|
||||
for tokens, lengths in batch_token_iterator:
|
||||
context_length += 1
|
||||
if tokens is not None:
|
||||
yield tokens[:, :context_length], lengths
|
||||
else:
|
||||
yield None, None
|
||||
|
||||
|
||||
def switch(val1, val2, boolean):
|
||||
boolean = boolean.type_as(val1)
|
||||
return (1 - boolean) * val1 + boolean * val2
|
||||
|
||||
|
||||
def sample_sequence_batch(
|
||||
model,
|
||||
tokenizer,
|
||||
context_tokens,
|
||||
context_lengths,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
seq_length,
|
||||
out_seq_length,
|
||||
maxlen=None,
|
||||
return_scores: bool = False,
|
||||
prompt_length: int = None,
|
||||
bad_ids: List = None,
|
||||
temperature: float = 1.0,
|
||||
topp: float = 1.0,
|
||||
topk: int = 0.0,
|
||||
recompute: bool = False,
|
||||
greedy: bool = False,
|
||||
):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
context_length = context_lengths.min().item()
|
||||
eos_id = tokenizer.eos_token_id
|
||||
|
||||
counter = 0
|
||||
org_context_length = context_length
|
||||
|
||||
layer_past = None
|
||||
batch_size = context_tokens.size(0)
|
||||
is_done = torch.zeros([batch_size]).byte().cuda()
|
||||
tokens = context_tokens
|
||||
if maxlen is None:
|
||||
maxlen = seq_length - 1
|
||||
if maxlen > (org_context_length + out_seq_length):
|
||||
maxlen = org_context_length + out_seq_length
|
||||
|
||||
lengths = torch.ones([batch_size]).long().cuda() * maxlen
|
||||
if return_scores:
|
||||
scores = torch.zeros([batch_size]).float().cuda()
|
||||
|
||||
while context_length <= (maxlen):
|
||||
|
||||
if recompute:
|
||||
logits = model(tokens,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
prompt_length=prompt_length,
|
||||
context_length=context_length,
|
||||
)
|
||||
logits = logits[:, context_length - 1, :]
|
||||
else:
|
||||
if counter == 0:
|
||||
tokens2use = tokens[:, :context_length]
|
||||
positions2use = position_ids[:, :context_length]
|
||||
else:
|
||||
tokens2use = tokens[:, context_length - 1].view(
|
||||
batch_size, -1)
|
||||
positions2use = position_ids[:, context_length - 1].view(
|
||||
batch_size, -1)
|
||||
logits, layer_past = model(tokens2use,
|
||||
positions2use,
|
||||
attention_mask,
|
||||
layer_past=layer_past,
|
||||
get_key_value=True,
|
||||
prompt_length=prompt_length,
|
||||
context_length=context_length,
|
||||
)
|
||||
logits = logits[:, -1].view(batch_size, -1).contiguous()
|
||||
|
||||
if bad_ids is not None:
|
||||
for bad_id in bad_ids:
|
||||
logits[:, bad_id] = -10000
|
||||
if greedy:
|
||||
prev = torch.argmax(logits, dim=-1).view(-1)
|
||||
else:
|
||||
logits = logits.float()
|
||||
if return_scores:
|
||||
orig_log_probs = torch.log_softmax(logits, dim=-1)
|
||||
logits /= temperature
|
||||
logits = top_k_logits(logits, top_k=topk, top_p=topp)
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
|
||||
|
||||
started = context_lengths <= context_length
|
||||
|
||||
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
|
||||
|
||||
if not greedy and return_scores:
|
||||
indices = prev.view(-1, 1)
|
||||
new_scores = orig_log_probs.gather(1, indices).view(-1)
|
||||
new_scores = new_scores * started
|
||||
new_scores = new_scores * is_done.bool().logical_not()
|
||||
scores += new_scores
|
||||
|
||||
tokens[:, context_length] = new_tokens
|
||||
done_token = (prev == eos_id).byte() & started.byte()
|
||||
just_finished = (done_token & ~is_done).bool()
|
||||
lengths[just_finished.view(-1)] = context_length
|
||||
is_done = is_done | done_token
|
||||
done = torch.all(is_done)
|
||||
|
||||
if return_scores:
|
||||
yield tokens, (lengths, scores)
|
||||
else:
|
||||
yield tokens, lengths
|
||||
|
||||
context_length += 1
|
||||
counter += 1
|
||||
if done:
|
||||
break
|
@ -0,0 +1,38 @@
|
||||
# This script is used to test the inference of CodeGeeX.
|
||||
|
||||
GPU=$1
|
||||
PROMPT_FILE=$2
|
||||
|
||||
SCRIPT_PATH=$(realpath "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
MAIN_DIR=$(dirname "$SCRIPT_DIR")
|
||||
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
|
||||
|
||||
# import model configuration
|
||||
source "$MAIN_DIR/configs/codegeex_13b.sh"
|
||||
|
||||
# export CUDA settings
|
||||
if [ -z "$GPU" ]; then
|
||||
GPU=0
|
||||
fi
|
||||
|
||||
export CUDA_HOME=/usr/local/cuda-11.1/
|
||||
export CUDA_VISIBLE_DEVICES=$GPU
|
||||
|
||||
if [ -z "$PROMPT_FILE" ]; then
|
||||
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
|
||||
fi
|
||||
|
||||
# remove --greedy if using sampling
|
||||
CMD="python $MAIN_DIR/tests/test_inference_quantized.py \
|
||||
--prompt-file $PROMPT_FILE \
|
||||
--tokenizer-path $TOKENIZER_PATH \
|
||||
--micro-batch-size 1 \
|
||||
--out-seq-length 1024 \
|
||||
--temperature 0.2 \
|
||||
--top-p 0.95 \
|
||||
--top-k 0 \
|
||||
$MODEL_ARGS"
|
||||
|
||||
echo "$CMD"
|
||||
eval "$CMD"
|
@ -0,0 +1,201 @@
|
||||
|
||||
import os
|
||||
import copy
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from codegeex.torch.inference import get_token_stream
|
||||
from codegeex.torch import CodeGeeXModel
|
||||
from codegeex.tokenizer import CodeGeeXTokenizer
|
||||
from codegeex.quantization import quantize
|
||||
|
||||
|
||||
def model_provider(args):
|
||||
"""Build the model."""
|
||||
|
||||
model = CodeGeeXModel(
|
||||
args.hidden_size,
|
||||
args.num_layers,
|
||||
args.num_attention_heads,
|
||||
args.padded_vocab_size,
|
||||
args.max_position_embeddings
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def add_code_generation_args(parser):
|
||||
group = parser.add_argument_group(title="code generation")
|
||||
group.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
default=39,
|
||||
)
|
||||
group.add_argument(
|
||||
"--hidden-size",
|
||||
type=int,
|
||||
default=5120,
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-attention-heads",
|
||||
type=int,
|
||||
default=40,
|
||||
)
|
||||
group.add_argument(
|
||||
"--padded-vocab-size",
|
||||
type=int,
|
||||
default=52224,
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-position-embeddings",
|
||||
type=int,
|
||||
default=2048,
|
||||
)
|
||||
group.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Sampling temperature.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--greedy",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use greedy sampling.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Top p sampling.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Top k sampling.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--out-seq-length",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Size of the output generated text.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--prompt-file",
|
||||
type=str,
|
||||
default="./test_prompt.txt",
|
||||
)
|
||||
group.add_argument(
|
||||
"--tokenizer-path",
|
||||
type=str,
|
||||
default="./tokenizer",
|
||||
)
|
||||
group.add_argument(
|
||||
"--load",
|
||||
type=str,
|
||||
)
|
||||
group.add_argument(
|
||||
"--state-dict-path",
|
||||
type=str,
|
||||
)
|
||||
group.add_argument(
|
||||
"--micro-batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = add_code_generation_args(parser)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
print("Building CodeGeeX model ...")
|
||||
model = model_provider(args)
|
||||
|
||||
print("Loading tokenizer ...")
|
||||
tokenizer = CodeGeeXTokenizer(
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
mode="codegeex-13b")
|
||||
|
||||
print("Loading state dict ...")
|
||||
state_dict = torch.load(args.load, map_location="cpu")
|
||||
state_dict = state_dict["module"]
|
||||
|
||||
print("Building CodeGeeX model ...")
|
||||
model = model_provider(args)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
model.half()
|
||||
model = quantize(model, weight_bit_width=8)
|
||||
model.cuda()
|
||||
|
||||
with open(args.prompt_file, "r") as f:
|
||||
prompt = f.readlines()
|
||||
prompt = "".join(prompt)
|
||||
|
||||
times = {}
|
||||
out_seq_lengths = [args.out_seq_length]
|
||||
micro_batch_size = args.micro_batch_size
|
||||
seq_length = args.max_position_embeddings
|
||||
for out_seq_length in out_seq_lengths:
|
||||
print(f"Generating with out_seq_len {out_seq_length}...")
|
||||
|
||||
times[out_seq_length] = []
|
||||
for prompt in [prompt]:
|
||||
t0 = time.perf_counter()
|
||||
tokens = tokenizer.encode_code(prompt)
|
||||
print(tokens)
|
||||
print("Current prompt:")
|
||||
print(prompt)
|
||||
n_token_prompt = len(tokens)
|
||||
print("N_token_prompt:", n_token_prompt)
|
||||
token_stream = get_token_stream(
|
||||
model,
|
||||
tokenizer,
|
||||
seq_length,
|
||||
out_seq_length,
|
||||
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
|
||||
micro_batch_size=micro_batch_size,
|
||||
topk=args.top_k,
|
||||
topp=args.top_p,
|
||||
temperature=args.temperature,
|
||||
greedy=args.greedy,
|
||||
)
|
||||
is_finished = [False for _ in range(micro_batch_size)]
|
||||
for i, generated in enumerate(token_stream):
|
||||
generated_tokens = generated[0]
|
||||
for j in range(micro_batch_size):
|
||||
if is_finished[j]:
|
||||
continue
|
||||
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(
|
||||
generated_tokens[j]) >= out_seq_length:
|
||||
is_finished[j] = True
|
||||
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
|
||||
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
|
||||
generated_code = "".join(generated_code)
|
||||
t1 = time.perf_counter()
|
||||
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
|
||||
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
|
||||
times[out_seq_length].append(t1 - t0)
|
||||
print("================================= Generated code:")
|
||||
print(generated_code)
|
||||
|
||||
if all(is_finished):
|
||||
break
|
||||
|
||||
print(times)
|
||||
for out_seq_length in times.keys():
|
||||
print(out_seq_length, np.mean(times[out_seq_length]))
|
||||
|
||||
print("Generation finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue