@@ -61,14 +61,19 @@ class ParallelMLP(MegatronModule):
     applied.
     """
 
-    def __init__(self, init_method, output_layer_init_method):
+    def __init__(
+        self,
+        init_method,
+        output_layer_init_method,
+        scale: int = 4,
+    ):
         super(ParallelMLP, self).__init__()
         args = get_args()
 
         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             args.hidden_size,
-            4 * args.hidden_size,
+            scale * args.hidden_size,
             gather_output=False,
             init_method=init_method,
             # skip_bias_add=True,
@@ -78,7 +83,7 @@ class ParallelMLP(MegatronModule):
 
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4 * args.hidden_size,
+            scale * args.hidden_size,
             args.hidden_size,
             input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
             init_method=output_layer_init_method,
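
Taken together, the two ParallelMLP hunks replace the hard-coded 4*h intermediate width with a configurable `scale * h` on both the column-parallel up-projection and the row-parallel down-projection. Below is a minimal single-process sketch of that shape arithmetic, using plain `torch.nn.Linear` as a stand-in for `mpu.ColumnParallelLinear` / `mpu.RowParallelLinear` (the stand-ins and the GELU activation are illustrative assumptions, not the actual tensor-parallel layers):

```python
# Rough single-GPU sketch of the shape change above; nn.Linear stands in for
# the mpu parallel layers (hypothetical stand-ins, the real layers shard
# their weights across tensor-parallel ranks).
import torch
import torch.nn.functional as F
from torch import nn


class MLPSketch(nn.Module):
    def __init__(self, hidden_size: int, scale: int = 4):
        super().__init__()
        # Up-projection: h -> scale * h (was hard-coded to 4 * h).
        self.dense_h_to_4h = nn.Linear(hidden_size, scale * hidden_size)
        # Down-projection: scale * h -> h.
        self.dense_4h_to_h = nn.Linear(scale * hidden_size, hidden_size)

    def forward(self, x):
        return self.dense_4h_to_h(F.gelu(self.dense_h_to_4h(x)))


mlp = MLPSketch(hidden_size=1024, scale=2)   # "compressed" variant
out = mlp(torch.randn(8, 16, 1024))          # output shape preserved: (8, 16, 1024)
```

With `scale=2` the MLP's two weight matrices shrink from h x 4h and 4h x h to h x 2h and 2h x h, roughly halving its parameter count while the block's input/output shapes stay unchanged.
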
@@ -264,7 +269,7 @@ class ParallelSelfAttention(MegatronModule):
         if self.attention_softmax_in_fp32:
             attention_probs = self.softmax(attention_scores.float()).half()
         else:
-            attention_probs = self.softmax(attention_scores)
+            attention_probs = self.softmax(attention_scores.half())
 
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
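
The attention change above only touches the non-fp32 branch: instead of running softmax on `attention_scores` in whatever dtype they arrive in, the scores are now explicitly cast to half precision first, so both branches emit fp16 probabilities. A tiny sketch of that dtype behavior, assuming a CUDA device and a plain `torch.nn.Softmax(dim=-1)` standing in for `self.softmax` (the real module may be a fused/scaled variant):

```python
# Dtype check for the two softmax branches; assumes a CUDA device is available
# and that self.softmax behaves like a plain torch.nn.Softmax(dim=-1).
import torch

softmax = torch.nn.Softmax(dim=-1)
scores = torch.randn(2, 4, 4, device="cuda", dtype=torch.float32)

probs_fp32_path = softmax(scores.float()).half()  # attention_softmax_in_fp32 branch
probs_fp16_path = softmax(scores.half())          # new else-branch: cast to half first
print(probs_fp32_path.dtype, probs_fp16_path.dtype)  # torch.float16 torch.float16
```
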
@@ -485,7 +490,7 @@ class ParallelTopQuerySelfAttention(MegatronModule):
         if self.attention_softmax_in_fp32:
             attention_probs = self.softmax(attention_scores.float()).half()
         else:
-            attention_probs = self.softmax(attention_scores)
+            attention_probs = self.softmax(attention_scores.half())
 
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
@@ -607,7 +612,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.ln_fp16 = False
         # MLP
         self.mlp = ParallelMLP(init_method,
-                               output_layer_init_method)
+                               output_layer_init_method,
+                               scale=2 if args.compress else 4)
 
     def forward(
         self,
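
Finally, the ParallelTransformerLayer hunk wires a compression switch into the MLP width via `scale=2 if args.compress else 4`. A hypothetical sketch of how such a flag could flow from the command line to the FFN size (the `--compress` argument and the defaults here are illustrative assumptions, not part of this diff):

```python
# Hypothetical CLI wiring mirroring `scale=2 if args.compress else 4`;
# the flag name and defaults are assumptions for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--hidden-size", type=int, default=1024)
parser.add_argument("--compress", action="store_true",
                    help="halve the MLP intermediate width (2h instead of 4h)")
args = parser.parse_args([])  # e.g. parser.parse_args(["--compress"])

scale = 2 if args.compress else 4
print(f"FFN intermediate size: {scale * args.hidden_size}")  # 4096 default, 2048 with --compress
```
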