optimize and support int8

pull/87/head
BBuf 2 years ago
parent e97663b388
commit 63e77fae38

@@ -2,7 +2,7 @@ import math
import oneflow as torch
import oneflow.nn.functional as F
from oneflow.nn.parameter import Parameter
from ..quantization import QuantizedLinear
def fast_gelu(x):
"""Mindspore's fast gelu implementation."""
@@ -13,7 +13,6 @@ def fast_gelu(x):
class MLP(torch.nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
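As a reading aid for the shape flow just described (h -> 4h -> nonlinearity -> h, then dropout), a minimal standalone sketch; the class name and dropout rate are illustrative and not taken from this diff:

import oneflow as torch
import oneflow.nn.functional as F

class TinyMLP(torch.nn.Module):
    def __init__(self, hidden_size, dropout_p=0.1):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):                  # x: [s, b, h]
        x = F.gelu(self.dense_h_to_4h(x))  # [s, b, 4h]
        x = self.dense_4h_to_h(x)          # [s, b, h]
        return self.dropout(x)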
@@ -52,7 +51,6 @@ class MLP(torch.nn.Module):
class SelfAttention(torch.nn.Module):
"""self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
@@ -100,7 +98,7 @@ class SelfAttention(torch.nn.Module):
# Query, Key, and Value
# =====================
if hasattr(torch._C, 'grouped_matmul_bias'):
if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear):
query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([hidden_states, hidden_states, hidden_states],
[self.query.weight, self.key.weight, self.value.weight],
[self.query.bias, self.key.bias, self.value.bias])
@@ -108,49 +106,41 @@ class SelfAttention(torch.nn.Module):
query_layer = self.query(hidden_states)
key_layer = self.key(hidden_states)
value_layer = self.value(hidden_states)
fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2')
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
if fallback:
if hasattr(torch._C, 'fused_codegeex_qkv_reshape'):
query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads)
else:
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
origin_query_layer = query_layer
origin_key_layer = key_layer
origin_value_layer = value_layer
# ==================================
# Adjust key and value for inference
# ==================================
if hasattr(torch._C, 'fused_multi_head_attention_inference'):
if layer_past is not None:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False
).transpose(0, 1)
else:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True
).transpose(0, 1)
else:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
@@ -167,7 +157,7 @@ class SelfAttention(torch.nn.Module):
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.permute(1, 2, 0)) / self.norm_factor
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
@@ -194,11 +184,22 @@ class SelfAttention(torch.nn.Module):
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
attention_mask = ~attention_mask
attention_mask = attention_mask.contiguous()
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
if hasattr(torch._C, 'fused_scale_mask_softmax'):
if self.attention_softmax_in_fp32:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
else:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
else:
attention_probs = self.softmax(attention_scores)
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
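For intuition about the two branches above, a small NumPy sketch (illustration only) showing that filling masked positions with -10000 through the inverted keep-mask matches the fallback that subtracts attention_mask * 10000 before the softmax:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

scores = np.random.randn(2, 4).astype(np.float32)
masked = np.array([[False, False, True, True],
                   [False, True, True, True]])      # True = masked out

# fallback branch: push masked scores far down, then softmax
probs_fallback = softmax(scores - masked * 10000.0)

# fused-style branch: the keep-mask (~masked) selects survivors,
# everything else is filled with -10000.0 before the softmax
probs_fused_style = softmax(np.where(~masked, scores, -10000.0))

assert np.allclose(probs_fallback, probs_fused_style, atol=1e-4)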
# =========================
# Context layer. [sq, b, hp]
@@ -220,7 +221,7 @@ class SelfAttention(torch.nn.Module):
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
@@ -232,10 +233,40 @@ class SelfAttention(torch.nn.Module):
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
else:
if layer_past is not None:
past_key, past_value = layer_past
key_layer, value_layer = torch._C.fused_attention_concat_past_key_value(
past_key=past_key,
past_key_layout="MB(HK)",
past_value=past_value,
past_value_layout="MB(HK)",
key=key_layer,
key_layout="MB(HK)",
value=value_layer,
value_layout="MB(HK)",
key_head_size=self.hidden_size_per_attention_head,
)
if get_key_value:
present = (key_layer, value_layer)
context_layer = torch._C.fused_multi_head_attention_inference_v2(
query=query_layer,
key=key_layer,
value=value_layer,
query_head_size=self.hidden_size_per_attention_head,
causal=True,
causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0],
query_layout="MB(HK)",
key_layout="MB(HK)",
value_layout="MB(HK)",
output_layout="MB(HK)",
)
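The causal_diagonal_offset=key_len - query_len argument presumably shifts the causal diagonal so that each new query can also see every cached key; a rough NumPy sketch of that presumed pattern (an assumption about the kernel's semantics, not taken from OneFlow documentation):

import numpy as np

def presumed_causal_pattern(sq, sk):
    offset = sk - sq                     # cached tokens ahead of the new queries
    i = np.arange(sq)[:, None]           # new query index
    j = np.arange(sk)[None, :]           # key index (cache + new)
    return j <= i + offset               # True = attention allowed

print(presumed_causal_pattern(sq=2, sk=5).astype(int))
# [[1 1 1 1 0]
#  [1 1 1 1 1]]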
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
@@ -247,7 +278,6 @@ class SelfAttention(torch.nn.Module):
class TopQuerySelfAttention(torch.nn.Module):
"""Top query self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
@@ -291,7 +321,7 @@ class TopQuerySelfAttention(torch.nn.Module):
):
# hidden_states: [sq, b, h]
if hasattr(torch._C, 'grouped_matmul_bias'):
if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear):
query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([query_hidden_state, hidden_states, hidden_states],
[self.query.weight, self.key.weight, self.value.weight],
[self.query.bias, self.key.bias, self.value.bias])
@@ -299,49 +329,41 @@ class TopQuerySelfAttention(torch.nn.Module):
query_layer = self.query(query_hidden_state)
key_layer = self.key(hidden_states)
value_layer = self.value(hidden_states)
fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2')
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
if fallback:
if hasattr(torch._C, 'fused_codegeex_qkv_reshape'):
query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads)
else:
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
origin_query_layer = query_layer
origin_key_layer = key_layer
origin_value_layer = value_layer
# ==================================
# Adjust key and value for inference
# ==================================
if hasattr(torch._C, 'fused_multi_head_attention_inference'):
if layer_past is not None:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False
).transpose(0, 1)
else:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True
).transpose(0, 1)
else:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
@@ -386,18 +408,11 @@ class TopQuerySelfAttention(torch.nn.Module):
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
if hasattr(torch._C, 'fused_scale_mask_softmax'):
attention_mask = ~attention_mask
if self.attention_softmax_in_fp32:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
else:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
attention_probs = self.softmax(attention_scores)
# =========================
# Context layer. [sq, b, hp]
@@ -433,9 +448,40 @@ class TopQuerySelfAttention(torch.nn.Module):
(self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
else:
if layer_past is not None:
past_key, past_value = layer_past
key_layer, value_layer = torch._C.fused_attention_concat_past_key_value(
past_key=past_key,
past_key_layout="MB(HK)",
past_value=past_value,
past_value_layout="MB(HK)",
key=key_layer,
key_layout="MB(HK)",
value=value_layer,
value_layout="MB(HK)",
key_head_size=self.hidden_size_per_attention_head,
)
if get_key_value:
present = (key_layer, value_layer)
if hasattr(torch._C, 'fused_multi_head_attention_inference_v2'):
context_layer = torch._C.fused_multi_head_attention_inference_v2(
query=query_layer,
key=key_layer,
value=value_layer,
query_head_size=self.hidden_size_per_attention_head,
causal=True,
causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0],
query_layout="MB(HK)",
key_layout="MB(HK)",
value_layout="MB(HK)",
output_layout="MB(HK)",
)
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
@@ -447,7 +493,6 @@ class TopQuerySelfAttention(torch.nn.Module):
class TransformerLayer(torch.nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
@@ -527,7 +572,6 @@ class TransformerLayer(torch.nn.Module):
class TopQueryLayer(torch.nn.Module):
"""A single top query layer.
Top query layer takes input with size [b, s, h] and returns an
output of the same size.
"""
@@ -728,7 +772,6 @@ class Transformer(torch.nn.Module):
class Embedding(torch.nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
@@ -808,7 +851,6 @@ class Embedding(torch.nn.Module):
class QueryEmbedding(torch.nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
@@ -868,7 +910,6 @@ class QueryEmbedding(torch.nn.Module):
class TransformerLanguageModel(torch.nn.Module):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`

@@ -1 +1,3 @@
from .quantize import quantize
from .quantize import quantize
from .quantize_oneflow import quantize_oneflow
from .quantize_oneflow import QuantizedLinear

@@ -0,0 +1,168 @@
import numpy as np
import oneflow as torch
from oneflow.nn.parameter import Parameter
def _pack_int8_to_int4(x):
np_x = x.numpy()
l = np_x[..., 0::2]
r = np_x[..., 1::2]
l = np.left_shift(l, 4)
if np_x.dtype == np.int8:
    # mask the low nibble's sign-extension bits so they don't clobber the high nibble
    r = np.bitwise_and(r, np.int8(0xF))
packed = torch.tensor(np.bitwise_or(l, r), device=x.device)
return packed
def _quantize(num_bits, symmetric, x, group_dim, group_size, quant_type):
x_float = x.float()
x_reshaped = x_float.reshape(
x.shape[:group_dim]
+ (x.shape[group_dim] // group_size, group_size)
+ x.shape[group_dim + 1 :]
)
if symmetric:
signed_max = float(2 ** (num_bits - 1)) - 1
offset = signed_max if quant_type is torch.uint8 else 0.0
scale_float = (
x_reshaped.abs().max(dim=group_dim + 1, keepdim=True).values / signed_max
)
quantized = (
torch.round(x_reshaped / scale_float + offset)
.reshape(x.shape)
.to(quant_type)
)
if num_bits == 4:
quantized = _pack_int8_to_int4(quantized)
return (quantized, scale_float.squeeze(group_dim + 1).to(x.dtype), None)
else:
unsigned_max = float(2 ** num_bits) - 1
mn = x_reshaped.min(dim=group_dim + 1, keepdim=True).values
mx = x_reshaped.max(dim=group_dim + 1, keepdim=True).values
scale_float = (mx - mn) / unsigned_max
quantized = (
torch.round((x_reshaped - mn) / scale_float).reshape(x.shape).to(torch.uint8)
)
if num_bits == 4:
quantized = _pack_int8_to_int4(quantized)
return (
quantized,
scale_float.squeeze(group_dim + 1).to(x.dtype),
mn.squeeze(group_dim + 1).to(x.dtype),
)
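A quick round-trip check of the symmetric branch above (illustrative shapes, mirroring the per-output-row grouping that QuantizedLinear below uses):

import oneflow as torch

w = torch.randn(4, 8)                     # [out_features, in_features]
q, scale, zero = _quantize(
    num_bits=8, symmetric=True, x=w,
    group_dim=1, group_size=w.shape[1], quant_type=torch.int8,
)
assert zero is None                       # symmetric quantization has no zero point
w_hat = q.float() * scale                 # scale broadcasts over each output row
print((w - w_hat).abs().max())            # small reconstruction error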
class QuantizedLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bit_width: int,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
*args,
**kwargs
):
super(QuantizedLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight_bit_width = weight_bit_width
self.symmetric = True
self.group_dim = 1
self.group_size = in_features
self.weight, self.weight_scale, self.weight_zero = _quantize(
self.weight_bit_width, self.symmetric, weight, self.group_dim, self.group_size, torch.int8
)
if bias is None:
self.register_parameter('bias', None)
else:
self.bias = bias
self.bias = self.bias.to(kwargs["device"])
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
if self.bias is not None:
self.bias = Parameter(self.bias.to(kwargs["device"]), requires_grad=False)
if self.weight_zero is not None:
self.weight_zero = Parameter(self.weight_zero.to(kwargs["device"]), requires_grad=False)
def forward(self, input_):
# Matrix multiply.
output = torch._C.fused_linear_with_groupwise_quantized_weight(input_,
w=self.weight,
w_scale=self.weight_scale,
w_zero=self.weight_zero,
b=self.bias if self.bias is not None else None,
num_bits=self.weight_bit_width,
symmetric=self.symmetric,
group_dim=self.group_dim,
group_size=self.group_size)
return output
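For reference, a dequantize-then-matmul sketch of what the fused symmetric-int8 path is expected to compute; a mathematical reference for sanity checking, not a drop-in replacement for the fused kernel (the function name is illustrative):

import oneflow as torch

def reference_quantized_linear(input_, weight_int8, weight_scale, bias=None):
    # weight row i is approximately weight_int8[i] * weight_scale[i]
    w_hat = weight_int8.float() * weight_scale.float()       # [out, in]
    out = torch.matmul(input_.float(), w_hat.transpose(0, 1))
    if bias is not None:
        out = out + bias.float()
    return out.to(input_.dtype)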
def quantize_oneflow(model, weight_bit_width):
"""Replace fp16 linear with quantized linear"""
for i in range(len(model.language_model.transformer.layers) + 1):
if i == len(model.language_model.transformer.layers):
layer = model.language_model.transformer.topQueryLayer
else:
layer = model.language_model.transformer.layers[i]
layer.attention.query = QuantizedLinear(
in_features=layer.attention.query.in_features,
out_features=layer.attention.query.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.query.weight.device,
)
layer.attention.value = QuantizedLinear(
in_features=layer.attention.value.in_features,
out_features=layer.attention.value.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.value.weight.device,
)
layer.attention.key = QuantizedLinear(
in_features=layer.attention.key.in_features,
out_features=layer.attention.key.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.key.weight.device,
)
layer.attention.dense = QuantizedLinear(
in_features=layer.attention.dense.in_features,
out_features=layer.attention.dense.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.dense.weight.device,
)
layer.mlp.dense_h_to_4h = QuantizedLinear(
in_features=layer.mlp.dense_h_to_4h.in_features,
out_features=layer.mlp.dense_h_to_4h.out_features,
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_h_to_4h.weight.device,
)
layer.mlp.dense_4h_to_h = QuantizedLinear(
in_features=layer.mlp.dense_4h_to_h.in_features,
out_features=layer.mlp.dense_4h_to_h.out_features,
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_4h_to_h.weight.device,
)
return model

@@ -0,0 +1,39 @@
# This script tests int8-quantized OneFlow inference of CodeGeeX.
GPU=$1
PROMPT_FILE=$2
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"
# export CUDA settings
if [ -z "$GPU" ]; then
GPU=1
fi
export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU
if [ -z "$PROMPT_FILE" ]; then
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi
# the command below uses sampling (temperature/top-p/top-k); pass --greedy instead for greedy decoding
CMD="python $MAIN_DIR/tests/test_inference_oneflow.py \
--prompt-file $PROMPT_FILE \
--tokenizer-path $TOKENIZER_PATH \
--micro-batch-size 1 \
--out-seq-length 1024 \
--temperature 0.2 \
--top-p 0.95 \
--top-k 0 \
--quantize \
$MODEL_ARGS"
echo "$CMD"
eval "$CMD"

@@ -10,8 +10,9 @@ import numpy as np
from codegeex.oneflow.inference import get_token_stream
from codegeex.oneflow import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
from codegeex.quantization import quantize_oneflow
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1"
def model_provider(args):
"""Build the model."""
@@ -135,7 +136,7 @@ def main():
model.eval()
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="torch")
model = quantize_oneflow(model, weight_bit_width=8)
model.cuda()
torch.cuda.synchronize()
with open(args.prompt_file, "r") as f:
