# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
layer_norm
"""
from copy import deepcopy

import impl.dynamic as dyn
import impl.layer_norm_unify as layer_norm_unify
import te.lang.cce as tbe
from impl.common_util import constant
from impl.layer_norm_tik import if_tik_support
from impl.layer_norm_tik import layer_normalize
from impl.util import util_select_op_base
from impl.util.platform_adapter import tbe_context
from impl.util.util_select_op_base import SplitInput
from impl.util.util_select_op_base import SplitOutput
from impl.util.util_select_op_base import get_op_cal_info
from tbe.common.platform import SOC_VERSION
from tbe.common.platform.platform_info import get_soc_spec
from tbe.dsl.compute.layer_norm_cube import LayerNormCube
from te import platform as tbe_platform
from te import tvm
from te.utils import para_check
from te.utils import shape_util
from te.utils.error_manager import error_manager_vector


# 'pylint: disable = unused-argument
# 'pylint: disable=too-many-arguments,too-many-locals
def get_op_support_info(input_x, input_gamma, input_beta,
                        output_y, output_mean, output_variance,
                        begin_norm_axis, begin_params_axis,
                        epsilon=1e-12, kernel_name="layer_norm",
                        impl_mode="high_performance"):
    """
    get_op_support_info
    """
    format_x = input_x.get("format").upper()
    shape_x = input_x.get("shape")
    ori_shape_x = input_x.get("ori_shape")
    begin_norm_axis = shape_util.axis_check(len(shape_x), begin_norm_axis)
    begin_params_axis = shape_util.axis_check(len(shape_x), begin_params_axis)
    axis_split_matrix = []

    if format_x in ("ND", "NCHW", "NHWC", "NC1HWC0"):
        if begin_params_axis == 0:
            for i in range(begin_norm_axis):
                split_0 = [SplitInput([0, [i], [-1], [-1]], [1, [i], [-1], [-1]], [2, [i], [-1], [-1]]),
                           SplitOutput([0, [i]], [1, [i]], [2, [i]])]
                axis_split_matrix.append(split_0)
        else:
            if begin_norm_axis <= begin_params_axis:
                for i in range(begin_norm_axis):
                    split_0 = [SplitInput([0, [i], [-1], [-1]]),
                               SplitOutput([0, [i]], [1, [i]], [2, [i]])]
                    axis_split_matrix.append(split_0)
            else:
                for i in range(begin_params_axis):
                    split_0 = [SplitInput([0, [i], [-1], [-1]]),
                               SplitOutput([0, [i]], [1, [i]], [2, [i]])]
                    axis_split_matrix.append(split_0)

    elif format_x == "FRACTAL_NZ":
        index_list = tuple(index for index, _ in enumerate(ori_shape_x))
        start_axis = min(begin_norm_axis, begin_params_axis)

        no_split_axis = index_list[start_axis:]
        no_split_axis = to_frac_z_axis(ori_shape_x, no_split_axis)
        for i in range(len(shape_x)):
            if i not in no_split_axis:
                split_0 = [SplitInput([0, [i], [-1], [-1]]),
                           SplitOutput([0, [i]], [1, [i]], [2, [i]])]
                axis_split_matrix.append(split_0)

    else:
        axis_split_matrix = None
    axis_reduce_list = None
    op_cal_info_in_json = get_op_cal_info(axis_split_matrix, axis_reduce_list, 0, 0)
    return op_cal_info_in_json
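

# Editor's note (not part of the original source): the split entries above are read
# here purely from how they are used in this file, not from framework documentation.
# A row such as
#     [SplitInput([0, [i], [-1], [-1]]), SplitOutput([0, [i]], [1, [i]], [2, [i]])]
# is interpreted as: input 0 may be cut along axis ``i`` (the two ``[-1]`` fields
# meaning no head/tail overlap), and the three outputs are cut along the same axis.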


# 'pylint: disable=locally-disabled,too-many-arguments,unused-argument
# 'pylint: disable=too-many-locals,too-many-statements,too-many-branches
def _division_sixteen(shape, begin_norm_axis):
    """
    division_sixteen
    """
    if len(shape) < 2:
        if shape[-1] == 0:
            error_detail = "value of shape_x is illegal"
            error_manager_vector.raise_err_input_shape_invalid("layer_norm", "input_x",
                                                               error_detail)
        return False

    if shape[-1] == 0 or shape[-2] == 0:
        error_detail = "value of shape_x is illegal"
        error_manager_vector.raise_err_input_shape_invalid("layer_norm", "input_x",
                                                           error_detail)

    is_reduce_last = begin_norm_axis in (-1, len(shape) - 1)
    # if shape[-2] % constant.C0_SIZE == 0:
    #     if shape[-1] % constant.C0_SIZE == 0 or (shape[-1] % constant.C0_SIZE != 0 and is_reduce_last):
    #         return True
    return False


# 'pylint: disable=too-many-statements,too-many-branches
def op_select_format(input_x, input_gamma, input_beta,
                     output_y, output_mean, output_variance,
                     begin_norm_axis, begin_params_axis,
                     kernel_name="layer_norm"):
    """
    select format dynamically
    """
    shape_x = input_x.get("ori_shape")
    shape_x = shape_util.scalar2tensor_one(shape_x)
    shape_gamma = input_gamma.get("ori_shape")
    shape_gamma = shape_util.scalar2tensor_one(shape_gamma)

    if begin_params_axis == 0:
        if len(shape_gamma) >= 2 or (not _division_sixteen(shape_x, begin_norm_axis)):
            input0 = util_select_op_base.gen_param(
                classify="input0", name="x",
                datatype="float16,float16,float16,float16,float,float,float,float",
                format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            input1 = util_select_op_base.gen_param(
                classify="input1", name="gamma",
                datatype="float16,float16,float16,float16,float,float,float,float",
                format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            input2 = util_select_op_base.gen_param(
                classify="input2", name="beta",
                datatype="float16,float16,float16,float16,float,float,float,float",
                format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            output0 = util_select_op_base.gen_param(
                classify="output0", name="y",
                datatype="float16,float16,float16,float16,float,float,float,float",
                format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            output1 = util_select_op_base.gen_param(
                classify="output1", name="mean",
                datatype="float16,float16,float16,float16,float,float,float,float",
                format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            output2 = util_select_op_base.gen_param(
                classify="output2", name="variance",
                datatype="float16,float16,float16,float16,float,float,float,float",
                format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")
        else:
            input0 = util_select_op_base.gen_param(
                classify="input0", name="x",
                datatype="float16,float,float16,float16,float16,float16,float,float,float,float",
                format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            input1 = util_select_op_base.gen_param(
                classify="input1", name="gamma",
                datatype="float16,float,float16,float16,float16,float16,float,float,float,float",
                format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            input2 = util_select_op_base.gen_param(
                classify="input2", name="beta",
                datatype="float16,float,float16,float16,float16,float16,float,float,float,float",
                format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            output0 = util_select_op_base.gen_param(
                classify="output0", name="y",
                datatype="float16,float,float16,float16,float16,float16,float,float,float,float",
                format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            output1 = util_select_op_base.gen_param(
                classify="output1", name="mean",
                datatype="float16,float,float16,float16,float16,float16,float,float,float,float",
                format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")

            output2 = util_select_op_base.gen_param(
                classify="output2", name="variance",
                datatype="float16,float,float16,float16,float16,float16,float,float,float,float",
                format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND")
    else:
        if len(shape_gamma) >= 2 or (not _division_sixteen(shape_x, begin_norm_axis)):
            input0 = util_select_op_base.gen_param(
                classify="input0", name="x",
                datatype="float16,float16,float16,float,float,float",
                format="NCHW,NHWC,ND,NCHW,NHWC,ND")

            input1 = util_select_op_base.gen_param(
                classify="input1", name="gamma",
                datatype="float16,float16,float16,float,float,float",
                format="NCHW,NHWC,ND,NCHW,NHWC,ND")

            input2 = util_select_op_base.gen_param(
                classify="input2", name="beta",
                datatype="float16,float16,float16,float,float,float",
                format="NCHW,NHWC,ND,NCHW,NHWC,ND")

            output0 = util_select_op_base.gen_param(
                classify="output0", name="y",
                datatype="float16,float16,float16,float,float,float",
                format="NCHW,NHWC,ND,NCHW,NHWC,ND")

            output1 = util_select_op_base.gen_param(
                classify="output1", name="mean",
                datatype="float16,float16,float16,float,float,float",
                format="NCHW,NHWC,ND,NCHW,NHWC,ND")

            output2 = util_select_op_base.gen_param(
                classify="output2", name="variance",
                datatype="float16,float16,float16,float,float,float",
                format="NCHW,NHWC,ND,NCHW,NHWC,ND")
        else:
            input0 = util_select_op_base.gen_param(
                classify="input0", name="x",
                datatype="float16,float,float16,float16,float16,float,float,float",
                format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND,NCHW,NHWC,ND")

            input1 = util_select_op_base.gen_param(
                classify="input1", name="gamma",
                datatype="float16,float,float16,float16,float16,float,float,float",
                format="ND,ND,NCHW,NHWC,ND,NCHW,NHWC,ND")

            input2 = util_select_op_base.gen_param(
                classify="input2", name="beta",
                datatype="float16,float,float16,float16,float16,float,float,float",
                format="ND,ND,NCHW,NHWC,ND,NCHW,NHWC,ND")

            output0 = util_select_op_base.gen_param(
                classify="output0", name="y",
                datatype="float16,float,float16,float16,float16,float,float,float",
                format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND,NCHW,NHWC,ND")

            output1 = util_select_op_base.gen_param(
                classify="output1", name="mean",
                datatype="float16,float,float16,float16,float16,float,float,float",
                format="ND,ND,NCHW,NHWC,ND,NCHW,NHWC,ND")

            output2 = util_select_op_base.gen_param(
                classify="output2", name="variance",
                datatype="float16,float,float16,float16,float16,float,float,float",
                format="ND,ND,NCHW,NHWC,ND,NCHW,NHWC,ND")

    param_list = [input0, input1, input2, output0, output1, output2]
    param_dynamic_in_json = util_select_op_base.get_dynamic_param_in_json(param_list)
    return param_dynamic_in_json


def to_frac_z_axis(ori_shape, ori_axis):
    """
    map the axes of the original shape to the corresponding axes of the FRACTAL_NZ shape

    Parameters
    ----------
    ori_shape: list or tuple
        original shape of input
    ori_axis: list or tuple
        original axis of original shape to operate

    Returns
    -------
    output: list
        axis of the fractal Nz shape
    """

    frac_z_axis = list(ori_axis)
    shape_len = len(ori_shape)
    axis_count = len(frac_z_axis)
    axis_negative_1 = shape_len - 1
    axis_negative_2 = shape_len - 2
    for i in range(axis_count):
        axis_index = (frac_z_axis[i] + shape_len) % shape_len
        if axis_index == axis_negative_1:
            if frac_z_axis[i] > shape_len - 2:
                frac_z_axis[i] = axis_index - 1
                frac_z_axis.append(axis_index + 1)
            else:
                frac_z_axis[i] = axis_index - 1
                frac_z_axis.append(axis_index + 2)
        elif axis_index == axis_negative_2:
            frac_z_axis[i] = axis_index + 1
            frac_z_axis.append(axis_index + 2)
        else:
            frac_z_axis[i] = axis_index
    return frac_z_axis
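

# Editor's note (illustration only, not from the original source): a quick trace of the
# mapping for a 3-D input, e.g. ori_shape = (8, 1024, 768) whose FRACTAL_NZ layout is the
# five-dimensional shape (8, 48, 64, 16, 16):
#     to_frac_z_axis([8, 1024, 768], (2,))   ->  [1, 3]
#     to_frac_z_axis([8, 1024, 768], (-1,))  ->  [1, 4]
# Each original axis that hits the last dimension is replaced by two axes of the
# FRACTAL_NZ shape, which is why the returned list is longer than the input list.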


def _broadcast_nz(tensor, shape):
    """
    broadcast_nz
    """
    broadcast_axes = []
    src_shape = shape_util.shape_to_list(tensor.shape)
    for i, _ in enumerate(shape):
        if shape[i] != src_shape[i]:
            broadcast_axes.append(i)
    if len(broadcast_axes) == 2 and \
            broadcast_axes[1] - broadcast_axes[0] != 1 and \
            broadcast_axes[1] + 1 == len(shape):
        temp_shape = src_shape[:-1] + [shape[-1]]
        tensor = tbe.broadcast(tensor, temp_shape)
    tensor = tbe.broadcast(tensor, shape)
    return tensor


def _check_vector_to_cube(dtype, ori_shape_x, shape_x, begin_norm_axis, impl_mode):
    """
    judge whether the cube implementation can be used to handle the reduce-sum;
    only the following case is supported on Ascend910 and Ascend710:
    ori_shape: ((batch), m, 1024(768)), "shape": ((batch), 64(48), m//16, 16, 16), "dtype": fp16
    """

    def _check_shape_and_dtype():
        if dtype != "float16":
            return False
        if len(ori_shape_x) not in (2, 3) or ori_shape_x[-1] not in (1024, 768, 96, 384, 192, 128, 512, 256):
            return False
        if len(shape_x) not in (4, 5) or shape_x[-4] not in (64, 48, 6, 12, 24, 16, 32):
            return False
        if "Ascend910" not in get_soc_spec(SOC_VERSION) and "Ascend710" not in get_soc_spec(SOC_VERSION):
            return False
        if begin_norm_axis != (len(ori_shape_x) - 1):
            return False
        return True

    return impl_mode == "high_performance" and _check_shape_and_dtype()
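

# Editor's note (illustration only): one configuration that would satisfy the check above
# is dtype "float16", ori_shape_x = (8, 512, 1024), shape_x = (8, 64, 32, 16, 16),
# begin_norm_axis = 2 and impl_mode = "high_performance" on an Ascend910 SoC.  Any
# mismatch (e.g. float32 input, or a last dimension outside the listed sizes) keeps
# flag_vector2cube False and the vector implementation below is used instead.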


# 'pylint: disable=too-many-locals,too-many-statements,too-many-branches
def nz_non_aligned(input_x, input_gamma, input_beta,
                   output_y, output_mean, output_variance,
                   begin_norm_axis, begin_params_axis,
                   ori_shape, epsilon, kernel_name="layer_norm",
                   impl_mode="high_performance"):
    """
    DSL description of the layernorm operator's mathematical calculation process
    for the non-aligned scene
    """
    shape_x = shape_util.shape_to_list(input_x.shape)
    dtype = input_x.dtype.lower()
    cast_dtype = "float16"
    if dtype == "float16" and \
            ((tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") and
              impl_mode == "high_performance") or impl_mode == "high_precision"):
        cast_dtype = "float32"
        input_x = tbe.cast_to(input_x, "float32")
        input_gamma = tbe.cast_to(input_gamma, "float32")
        input_beta = tbe.cast_to(input_beta, "float32")
    else:
        input_x = tbe.vadds(input_x, 0)

    # Calculate the scaling ratio of the average
    reduce_elts = 1.0
    index_list = tuple(index for index, _ in enumerate(ori_shape))
    reduce_axis = index_list[begin_norm_axis:]
    for i in reduce_axis:
        reduce_elts *= ori_shape[i]
    reduce_axis = to_frac_z_axis(ori_shape, reduce_axis)
    mean_cof = reduce_elts ** (-1)

    # DSL description of the mean calculation process
    with tvm.tag_scope("tail_block_pretreatment"):
        lambda_func = lambda *indice: tvm.const(0, input_x.dtype)
        temp = tvm.compute(input_x.shape, lambda_func, name="tail_block_pretreatment")

    input_x = tbe.vadd(input_x, temp)
    mean_muls = tbe.vmuls(input_x, mean_cof)
    mean = tbe.sum(mean_muls, axis=reduce_axis, keepdims=True)

    mean_square = tbe.vmul(mean, mean)
    x_square = tbe.vmul(input_x, input_x)
    x_square = tbe.vmuls(x_square, mean_cof)
    x_square_mean = tbe.sum(x_square, axis=reduce_axis, keepdims=True)
    variance = tbe.vsub(x_square_mean, mean_square)

    # DSL description of the normalize calculation process
    mean_normalize_broadcast = _broadcast_nz(mean, shape_x)
    normalize_sub = tbe.vsub(input_x, mean_normalize_broadcast)
    epsilon = tvm.const(epsilon, dtype=cast_dtype)

    normalize_add = tbe.vadds(variance, epsilon)
    normalize_log = tbe.vlog(normalize_add)
    normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype))
    normalize_exp = tbe.vexp(normalize_log_mul)
    variance_normalize_broadcast = _broadcast_nz(normalize_exp, shape_x)
    normalize_mul = tbe.vmul(normalize_sub, variance_normalize_broadcast)

    # DSL description of the scale and translate calculation process
    gamma_broadcast = _broadcast_nz(input_gamma, shape_x)
    beta_broadcast = _broadcast_nz(input_beta, shape_x)
    scale_mul = tbe.vmul(gamma_broadcast, normalize_mul)
    res = tbe.vadd(scale_mul, beta_broadcast)

    if dtype == "float16" and \
            ((tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") and
              impl_mode == "high_performance") or impl_mode == "high_precision"):
        mean = tbe.cast_to(mean, "float16")
        variance = tbe.cast_to(variance, "float16")
        res = tbe.cast_to(res, "float16")

    return mean, variance, res


# 'pylint: disable=too-many-statements,too-many-branches
def layer_norm_compute_nz(input_x, input_gamma, input_beta,
                          output_y, output_mean, output_variance,
                          begin_norm_axis, begin_params_axis,
                          ori_shape, epsilon, kernel_name="layer_norm",
                          impl_mode="high_performance"):
    """
    DSL description of the layernorm operator's mathematical calculation process

    Parameters
    ----------
    input_x: TVM tensor
        the placeholder of x input data
    input_gamma: TVM tensor
        the placeholder of gamma input data
    input_beta: TVM tensor
        the placeholder of beta input data
    output_data: dict
        shape and dtype of output
    begin_norm_axis: int
        The first normalization dimension: normalization will be
        performed along dimensions `begin_norm_axis : rank(inputs)`
    begin_params_axis: int
        The first parameter (beta, gamma) dimension: scale
        and centering parameters will have dimensions
        `begin_params_axis : rank(inputs)` and will be broadcast with the
        normalized inputs accordingly.
    epsilon: float
        a small positive value added to the variance to avoid dividing by zero
    kernel_name: str
        cce kernel name, default value is "cce_layernorm"

    Returns
    -------
    res_tuple: tuple
        (mean, variance, result)
    """
    shape_x = shape_util.shape_to_list(input_x.shape)
    dtype = input_x.dtype.lower()
    cast_dtype, cast_fp16_dtype = "float16", "float16"
    cast_dtype_precision = dtype
    if dtype == "float16" and \
            ((tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") and
              impl_mode == "high_performance") or impl_mode == "high_precision"):
        cast_dtype = "float32"
        cast_dtype_precision = "float32"
        input_x = tbe.cast_to(input_x, "float32")
        input_gamma = tbe.cast_to(input_gamma, "float32")
        input_beta = tbe.cast_to(input_beta, "float32")

    # Calculate the scaling ratio of the average
    reduce_elts = 1.0
    index_list = tuple(index for index, _ in enumerate(ori_shape))
    reduce_axis = index_list[begin_norm_axis:]
    for i in reduce_axis:
        reduce_elts *= ori_shape[i]
    reduce_axis = to_frac_z_axis(ori_shape, reduce_axis)
    mean_cof = reduce_elts ** (-1)
    if impl_mode != "keep_fp16":
        # DSL description of the mean calculation process
        mean_muls = tbe.vmuls(input_x, mean_cof)
        mean = tbe.sum(mean_muls, axis=reduce_axis, keepdims=True)
        # DSL description of the variance calculation process
        mean_variance_broadcast = _broadcast_nz(mean, shape_x)
        variance_sub = tbe.vsub(input_x, mean_variance_broadcast)
        variance_mul = tbe.vmul(variance_sub, variance_sub)
        variance_muls = tbe.vmuls(variance_mul, mean_cof)
        variance = tbe.sum(variance_muls, axis=reduce_axis, keepdims=True)
    else:
        # DSL description of the mean calculation process
        x_sum = tbe.sum(input_x, axis=reduce_axis, keepdims=True)
        mean = tbe.vmuls(x_sum, mean_cof)
        # DSL description of the variance calculation process
        mean_variance_broadcast = _broadcast_nz(mean, shape_x)
        variance_sub = tbe.vsub(input_x, mean_variance_broadcast)
        variance_mul = tbe.vmul(variance_sub, variance_sub)
        variance_sum = tbe.sum(variance_mul, axis=reduce_axis, keepdims=True)
        variance = tbe.vmuls(variance_sum, mean_cof)

    # DSL description of the normalize calculation process
    if impl_mode == "high_performance":
        mean_normalize_broadcast = _broadcast_nz(mean, shape_x)
        normalize_sub = tbe.vsub(input_x, mean_normalize_broadcast)
        epsilon = tvm.const(epsilon, dtype=cast_dtype)
        variance_normalize_broadcast = _broadcast_nz(variance, shape_x)
        normalize_add = tbe.vadds(variance_normalize_broadcast, epsilon)
        normalize_log = tbe.vlog(normalize_add)
        normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype))
        normalize_exp = tbe.vexp(normalize_log_mul)
        normalize_mul = tbe.vmul(normalize_sub, normalize_exp)
    elif impl_mode == "high_precision":
        tensor_one = tbe.broadcast(tvm.const(1, cast_dtype_precision), shape_x)
        mean_normalize_broadcast = _broadcast_nz(mean, shape_x)
        normalize_sub = tbe.vsub(input_x, mean_normalize_broadcast)
        variance_normalize_broadcast = _broadcast_nz(variance, shape_x)
        epsilon = tvm.const(epsilon, dtype=cast_dtype_precision)
        normalize_add = tbe.vadds(variance_normalize_broadcast, epsilon)
        normalize_sqrt = tbe.vsqrt(normalize_add, 0)
        normalize_rsqrt = tbe.vdiv(tensor_one, normalize_sqrt)
        normalize_mul = tbe.vmul(normalize_sub, normalize_rsqrt)
    else:
        epsilon = tvm.const(epsilon, dtype=cast_fp16_dtype)
        normalize_add = tbe.vadds(variance, epsilon)
        normalize_log = tbe.vlog(normalize_add)
        normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_fp16_dtype))
        normalize_exp = tbe.vexp(normalize_log_mul)
        variance_normalize_broadcast = _broadcast_nz(normalize_exp, shape_x)
        normalize_mul = tbe.vmul(variance_sub, variance_normalize_broadcast)

    # DSL description of the scale and translate calculation process
    if begin_params_axis == 0:
        scale_mul = tbe.vmul(input_gamma, normalize_mul)
        res = tbe.vadd(scale_mul, input_beta)
    else:
        gamma_broadcast = _broadcast_nz(input_gamma, shape_x)
        beta_broadcast = _broadcast_nz(input_beta, shape_x)
        scale_mul = tbe.vmul(gamma_broadcast, normalize_mul)
        res = tbe.vadd(scale_mul, beta_broadcast)

    if dtype == "float16" and \
            ((tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") and
              impl_mode == "high_performance") or impl_mode == "high_precision"):
        mean = tbe.cast_to(mean, "float16")
        variance = tbe.cast_to(variance, "float16")
        res = tbe.cast_to(res, "float16")

    return mean, variance, res


# 'pylint: disable=too-many-statements,too-many-branches
@tbe_platform.fusion_manager.fusion_manager.register("layer_norm")
def layer_norm_compute(input_x, input_gamma, input_beta,
                       output_y, output_mean, output_variance,
                       begin_norm_axis, begin_params_axis,
                       epsilon, kernel_name="layer_norm",
                       impl_mode="high_performance"):
    """
    DSL description of the layernorm operator's mathematical calculation process

    Parameters
    ----------
    input_x: TVM tensor
        the placeholder of x input data
    input_gamma: TVM tensor
        the placeholder of gamma input data
    input_beta: TVM tensor
        the placeholder of beta input data
    output_data: dict
        shape and dtype of output
    begin_norm_axis: int
        The first normalization dimension: normalization will be
        performed along dimensions `begin_norm_axis : rank(inputs)`
    begin_params_axis: int
        The first parameter (beta, gamma) dimension: scale
        and centering parameters will have dimensions
        `begin_params_axis : rank(inputs)` and will be broadcast with the
        normalized inputs accordingly.
    epsilon: float
        a small positive value added to the variance to avoid dividing by zero
    kernel_name: str
        cce kernel name, default value is "cce_layernorm"

    Returns
    -------
    res_tuple: tuple
        (mean, variance, result)
    """
    shape_x = shape_util.shape_to_list(input_x.shape)
    dtype = input_x.dtype.lower()
    cast_dtype, cast_fp16_dtype = "float16", "float16"
    cast_dtype_precision = dtype
    if dtype == "float16" and \
            ((tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") and
              impl_mode == "high_performance") or impl_mode == "high_precision"):
        cast_dtype = "float32"
        cast_dtype_precision = "float32"
        input_x = tbe.cast_to(input_x, "float32")
        input_gamma = tbe.cast_to(input_gamma, "float32")
        input_beta = tbe.cast_to(input_beta, "float32")

    # Calculate the scaling ratio of the average
    index_list = tuple(index for index, _ in enumerate(shape_x))
    reduce_axis = index_list[begin_norm_axis:]

    reduce_elts = 1.0
    for i in reduce_axis:
        reduce_elts *= shape_x[i]
    mean_cof = reduce_elts ** (-1)

    if impl_mode != "keep_fp16":
        # DSL description of the mean calculation process
        mean_muls = tbe.vmuls(input_x, mean_cof)
        mean = tbe.sum(mean_muls, axis=reduce_axis, keepdims=True)

        # DSL description of the variance calculation process
        mean_variance_broadcast = tbe.broadcast(mean, shape_x)
        variance_sub = tbe.vsub(input_x, mean_variance_broadcast)
        variance_mul = tbe.vmul(variance_sub, variance_sub)
        variance_muls = tbe.vmuls(variance_mul, mean_cof)
        variance = tbe.sum(variance_muls, axis=reduce_axis, keepdims=True)
    else:
        # DSL description of the mean calculation process
        x_sum = tbe.sum(input_x, axis=reduce_axis, keepdims=True)
        mean = tbe.vmuls(x_sum, mean_cof)
        # DSL description of the variance calculation process
        mean_variance_broadcast = tbe.broadcast(mean, shape_x)
        variance_sub = tbe.vsub(input_x, mean_variance_broadcast)
        variance_mul = tbe.vmul(variance_sub, variance_sub)
        variance_sum = tbe.sum(variance_mul, axis=reduce_axis, keepdims=True)
        variance = tbe.vmuls(variance_sum, mean_cof)

    # DSL description of the normalize calculation process
    if impl_mode == "high_performance":
        mean_normalize_broadcast = tbe.broadcast(mean, shape_x)
        normalize_sub = tbe.vsub(input_x, mean_normalize_broadcast)
        epsilon = tvm.const(epsilon, dtype=cast_dtype)
        variance_normalize_broadcast = tbe.broadcast(variance, shape_x)
        normalize_add = tbe.vadds(variance_normalize_broadcast, epsilon)
        normalize_log = tbe.vlog(normalize_add)
        normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype))
        normalize_exp = tbe.vexp(normalize_log_mul)
        normalize_mul = tbe.vmul(normalize_sub, normalize_exp)
    elif impl_mode == "high_precision":
        tensor_one = tbe.broadcast(tvm.const(1, cast_dtype_precision), shape_x)
        mean_normalize_broadcast = tbe.broadcast(mean, shape_x)
        normalize_sub = tbe.vsub(input_x, mean_normalize_broadcast)
        variance_normalize_broadcast = tbe.broadcast(variance, shape_x)
        epsilon = tvm.const(epsilon, dtype=cast_dtype_precision)
        normalize_add = tbe.vadds(variance_normalize_broadcast, epsilon)
        normalize_sqrt = tbe.vsqrt(normalize_add, 0)
        normalize_rsqrt = tbe.vdiv(tensor_one, normalize_sqrt)
        normalize_mul = tbe.vmul(normalize_sub, normalize_rsqrt)
    else:
        epsilon = tvm.const(epsilon, dtype=cast_fp16_dtype)
        normalize_add = tbe.vadds(variance, epsilon)
        normalize_log = tbe.vlog(normalize_add)
        normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_fp16_dtype))
        normalize_exp = tbe.vexp(normalize_log_mul)
        variance_normalize_broadcast = tbe.broadcast(normalize_exp, shape_x)
        normalize_mul = tbe.vmul(variance_sub, variance_normalize_broadcast)

    # DSL description of the scale and translate calculation process
    if begin_params_axis == 0:
        scale_mul = tbe.vmul(input_gamma, normalize_mul)
        res = tbe.vadd(scale_mul, input_beta)
    else:
        gamma_broadcast = tbe.broadcast(input_gamma, shape_x)
        beta_broadcast = tbe.broadcast(input_beta, shape_x)
        scale_mul = tbe.vmul(gamma_broadcast, normalize_mul)
        res = tbe.vadd(scale_mul, beta_broadcast)

    if dtype == "float16" and \
            ((tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") and
              impl_mode == "high_performance") or impl_mode == "high_precision"):
        mean = tbe.cast_to(mean, "float16")
        variance = tbe.cast_to(variance, "float16")
        res = tbe.cast_to(res, "float16")

    return mean, variance, res
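

# Editor's note: the helper below is an illustrative NumPy sketch of the arithmetic that
# the compute functions above express in TBE DSL (mean, biased variance, rsqrt
# normalization, then gamma/beta).  It is not called anywhere in this file and is not
# part of the original operator; it only serves as a readable reference.
def _layer_norm_numpy_reference(x, gamma, beta, begin_norm_axis, epsilon=1e-12):
    """Reference-only NumPy version of the layer_norm math (editor's sketch)."""
    import numpy as np

    # reduce over the trailing dimensions, matching `begin_norm_axis : rank(inputs)`
    reduce_axis = tuple(range(begin_norm_axis, x.ndim))
    mean = np.mean(x, axis=reduce_axis, keepdims=True)
    variance = np.mean(np.power(x - mean, 2), axis=reduce_axis, keepdims=True)
    normalized = (x - mean) / np.sqrt(variance + epsilon)
    # gamma/beta are assumed to have shape x.shape[begin_params_axis:], so NumPy
    # broadcasting applies them over the leading dimensions
    return mean, variance, gamma * normalized + beta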


def is_support_nz_non_aligned(ori_shape_x, begin_params_axis, impl_mode):
    """
    is_support_nz_non_aligned
    """
    if ori_shape_x[-1] % constant.C0_SIZE != 0:
        if begin_params_axis != 0:
            return True

    return False


@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT,
                            para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT,
                            para_check.REQUIRED_OUTPUT, para_check.REQUIRED_OUTPUT,
                            para_check.REQUIRED_ATTR_INT, para_check.REQUIRED_ATTR_INT,
                            para_check.OPTION_ATTR_FLOAT, para_check.KERNEL_NAME,
                            para_check.OPTION_ATTR_STR)
def layer_norm(input_x, input_gamma, input_beta,
               output_y, output_mean, output_variance,
               begin_norm_axis, begin_params_axis,
               epsilon=1e-12, kernel_name="layer_norm",
               impl_mode="high_performance"):
    """
    layernorm operator interface implementation
    calculating: x, gamma, beta
        mean = np.mean(x, reduce_axis, keepdims=True)
        variance = np.mean(np.power((x - mean), 2), reduce_axis, keepdims=True)
        result = gamma * ((x - mean) / np.sqrt(variance + epsilon)) + beta

    Parameters
    ----------
    input_x : dict
        shape and dtype of input x, only support float16, float32
    input_gamma: dict
        shape and dtype of input gamma, only support float16, float32
    input_beta: dict
        shape and dtype of input beta, only support float16, float32
    output_y: dict
        shape and dtype of output, only support float16, float32
    begin_norm_axis: int
        The first normalization dimension: normalization will be
        performed along dimensions `begin_norm_axis : rank(inputs)`
    begin_params_axis: int
        The first parameter (beta, gamma) dimension: scale
        and centering parameters will have dimensions
        `begin_params_axis : rank(inputs)` and will be broadcast with the
        normalized inputs accordingly.
    epsilon: float
        a small positive value added to the variance to avoid dividing by zero
    kernel_name: str
        cce kernel name, default value is "layernorm"

    Returns
    -------
    None
    """
    shape_x = list(input_x.get("shape"))
    input_gamma_shape = input_gamma.get("shape")
    input_beta_shape = input_beta.get("shape")
    ori_shape_x = list(input_x.get("ori_shape"))
    input_format = input_x.get("format").upper()
    input_gamma_format = input_gamma.get("format").upper()
    input_beta_format = input_beta.get("format").upper()

    para_check.check_shape(input_gamma_shape, param_name="input_gamma")
    para_check.check_shape(input_beta_shape, param_name="input_beta")
    para_check.check_shape(shape_x, param_name="input_x")

    check_list = ("float16", "float32")
    dtype = input_x.get("dtype").lower()
    dtype_gamma = input_gamma.get("dtype").lower()
    dtype_beta = input_gamma.get("dtype").lower()
    para_check.check_dtype(dtype, check_list, param_name="input_x")
    para_check.check_dtype(dtype_gamma, check_list, param_name="input_gamma")
    para_check.check_dtype(dtype_beta, check_list, param_name="input_gamma")

    shape_gamma = list(input_gamma.get("shape"))
    shape_beta = list(input_beta.get("shape"))

    flag_vector2cube = False
    tik_support = if_tik_support(input_x, input_gamma, input_beta, output_y, output_mean,
                                 output_variance, begin_norm_axis, begin_params_axis, epsilon)
    if tik_support:
        layer_normalize(input_x, input_gamma, input_beta,
                        output_y, output_mean, output_variance,
                        begin_norm_axis, begin_params_axis,
                        epsilon, kernel_name)
    else:
        if input_format == "FRACTAL_NZ":
            begin_norm_axis = shape_util.axis_check(len(ori_shape_x), begin_norm_axis)
            begin_params_axis = shape_util.axis_check(len(ori_shape_x), begin_params_axis)

            flag_vector2cube = _check_vector_to_cube(dtype, ori_shape_x, shape_x, begin_norm_axis, impl_mode)
            if input_gamma_format == "FRACTAL_NZ" or input_beta_format == "FRACTAL_NZ":
                error_detail = "gamma and beta do not support Nz in bert"
                error_manager_vector.raise_err_two_input_format_invalid(kernel_name, "input_gamma",
                                                                        "input_beta", error_detail)
            if shape_gamma != shape_beta:
                error_detail = "gamma and beta's shape must be same."
                error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "input_gamma",
                                                                       "input_beta", error_detail)
            if ori_shape_x[begin_params_axis:] != shape_gamma:
                error_detail = "x or gamma or begin_params_axis is wrong."
                error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "x",
                                                                       "input_gamma", error_detail)
            if len(shape_gamma) > 1:
                error_detail = "shape of gamma or beta only support 1D in bert"
                error_manager_vector.raise_err_input_shape_invalid(kernel_name, "input_gamma", error_detail)

            # make shape_x, shape_gamma, shape_beta the same dim in the vector case
            if not flag_vector2cube:
                if begin_params_axis != 0:
                    for i in range(begin_params_axis):
                        shape_gamma.insert(i, 1)
                shape_gamma[-2] = shape_x[-4]
                shape_gamma[-1] = 1
                shape_gamma.append(1)
                shape_gamma.append(shape_x[-1])
                if begin_params_axis > len(ori_shape_x) - 2:
                    shape_x[-3:] = [shape_x[-3] * shape_x[-2], shape_x[-1]]
                    shape_gamma[-3:] = [shape_gamma[-3] * shape_gamma[-2], shape_gamma[-1]]
                shape_beta = shape_gamma
        else:
            begin_norm_axis = shape_util.axis_check(len(shape_x), begin_norm_axis)
            begin_params_axis = shape_util.axis_check(len(shape_x), begin_params_axis)

            if shape_gamma != shape_beta:
                error_detail = "gamma and beta's shape must be same."
                error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "input_gamma",
                                                                       "input_beta", error_detail)
            no_need_fix_gamma = False
            no_need_fix_beta = False
            if shape_x[begin_params_axis:] != shape_gamma:
                if len(shape_x) == len(shape_gamma):
                    no_need_fix_gamma = True
                else:
                    error_detail = "x or gamma or begin_params_axis is wrong."
                    error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "x",
                                                                           "input_gamma", error_detail)
            if shape_x[begin_params_axis:] != shape_beta:
                if len(shape_x) == len(shape_beta):
                    no_need_fix_beta = True
                else:
                    error_detail = "x or gamma or begin_params_axis is wrong."
                    error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "x",
                                                                           "input_beta", error_detail)
            # make shape_x, shape_gamma, shape_beta the same dim
            if begin_params_axis != 0 and not no_need_fix_gamma:
                for i in range(begin_params_axis):
                    shape_gamma.insert(i, 1)
            if begin_params_axis != 0 and not no_need_fix_beta:
                for i in range(begin_params_axis):
                    shape_beta.insert(i, 1)

        attr = {"ori_shape": ori_shape_x}
        data_x = tvm.placeholder(shape_x, name="x", dtype=dtype, attrs=attr)
        data_gamma = tvm.placeholder(shape_gamma, name="gamma", dtype=dtype)
        data_beta = tvm.placeholder(shape_beta, name="beta", dtype=dtype)

        if input_format == "FRACTAL_NZ":
            dyn_input_x = deepcopy(input_x)
            dyn_input_x["shape"] = shape_x
            if flag_vector2cube:
                layer_norm_cube = LayerNormCube({"ori_shape": ori_shape_x,
                                                 "epsilon": epsilon})
                mean, variance, res = \
                    layer_norm_cube.layer_norm_cube_compute(data_x, data_gamma, data_beta)
            elif is_support_nz_non_aligned(ori_shape_x, begin_params_axis, impl_mode):
                mean, variance, res = \
                    nz_non_aligned(data_x, data_gamma, data_beta,
                                   output_y, output_mean, output_variance,
                                   begin_norm_axis, begin_params_axis,
                                   ori_shape_x, epsilon, kernel_name, impl_mode)
            elif layer_norm_unify.is_special_cases(dyn_input_x, input_gamma, input_beta, begin_norm_axis, impl_mode):
                __dynamic_template_api(input_x, input_gamma, input_beta, output_y, output_mean, output_variance,
                                       begin_norm_axis, begin_params_axis, epsilon, kernel_name, impl_mode)
                return

            else:
                mean, variance, res = \
                    layer_norm_compute_nz(data_x, data_gamma, data_beta,
                                          output_y, output_mean, output_variance,
                                          begin_norm_axis, begin_params_axis,
                                          ori_shape_x, epsilon, kernel_name, impl_mode)
        else:
            if layer_norm_unify.is_special_cases(input_x, input_gamma, input_beta, begin_norm_axis, impl_mode):
                __dynamic_template_api(input_x, input_gamma, input_beta, output_y, output_mean, output_variance,
                                       begin_norm_axis, begin_params_axis, epsilon, kernel_name, impl_mode)
                return
            else:
                mean, variance, res = \
                    layer_norm_compute(data_x, data_gamma, data_beta,
                                       output_y, output_mean,
                                       output_variance,
                                       begin_norm_axis, begin_params_axis,
                                       epsilon, kernel_name, impl_mode)

        with tvm.target.cce():
            sch = tbe.auto_schedule([res, mean, variance])

        config = {"print_ir": False,
                  "name": kernel_name,
                  "tensor_list": [data_x, data_gamma,
                                  data_beta, res, mean, variance]}

        tbe.cce_build_code(sch, config)


def __dynamic_template_api(input_x, input_gamma, input_beta, output_y, output_mean, output_variance,
                           begin_norm_axis, begin_params_axis, epsilon, kernel_name, impl_mode):
    # When all axes are reduced, the reduce axis is non-aligned, or the reduced MTE data is less than
    # one block, etc., these single-core cases are transferred to the dynamic template to use multi-core.
    input_x, input_gamma, input_beta = layer_norm_unify.set_range(input_x, input_gamma, input_beta)
    context_ops = tbe_context.op_context.get_context()
    if context_ops is not None:
        context_ops.set_op_mode("static")
        context_ops.add_addition("is_static", True)
        dyn.layer_norm(input_x, input_gamma, input_beta,
                       output_y, output_mean, output_variance,
                       begin_norm_axis, begin_params_axis,
                       epsilon, kernel_name, impl_mode)
    else:
        with tbe_context.op_context.OpContext("static"):
            tbe_context.op_context.get_context().add_addition("is_static", True)
            dyn.layer_norm(input_x, input_gamma, input_beta,
                           output_y, output_mean, output_variance,
                           begin_norm_axis, begin_params_axis,
                           epsilon, kernel_name, impl_mode)