pull/69/head
Stanislas0 2 years ago
parent 1051b384a2
commit 4198b21b9e

@@ -415,6 +415,10 @@ def _add_network_size_args(parser):
         help="Disable BERT binary head.",
         dest="bert_binary_head",
     )
+    group.add_argument(
+        "--compress",
+        action="store_true",
+    )

     return parser

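Review note: the new --compress option has no help string, so --help will list it without any description. A possible wording (mine, not part of this commit), based on how the flag is consumed in ParallelTransformerLayer further down:

group.add_argument(
    "--compress",
    action="store_true",
    help="Shrink each transformer MLP from a 4x to a 2x hidden-size expansion.",
)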

@@ -61,14 +61,19 @@ class ParallelMLP(MegatronModule):
     applied.
     """

-    def __init__(self, init_method, output_layer_init_method):
+    def __init__(
+        self,
+        init_method,
+        output_layer_init_method,
+        scale: int = 4,
+    ):
         super(ParallelMLP, self).__init__()
         args = get_args()

         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             args.hidden_size,
-            4 * args.hidden_size,
+            scale * args.hidden_size,
             gather_output=False,
             init_method=init_method,
             # skip_bias_add=True,
@@ -78,7 +83,7 @@ class ParallelMLP(MegatronModule):

         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4 * args.hidden_size,
+            scale * args.hidden_size,
             args.hidden_size,
             input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
             init_method=output_layer_init_method,
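The two hunks above change the MLP's intermediate width from a hard-coded 4 * hidden_size to scale * hidden_size. A rough standalone estimate of what that does to the per-layer MLP parameter count (plain Python, ignoring tensor-model parallelism; the hidden size below is hypothetical, not taken from this diff):

def mlp_params(hidden_size: int, scale: int) -> int:
    # dense_h_to_4h: hidden_size -> scale * hidden_size, weights plus bias
    up = hidden_size * scale * hidden_size + scale * hidden_size
    # dense_4h_to_h: scale * hidden_size -> hidden_size, weights plus bias
    down = scale * hidden_size * hidden_size + hidden_size
    return up + down

h = 4096  # hypothetical hidden size, for illustration only
print(mlp_params(h, 4))  # default 4x expansion
print(mlp_params(h, 2))  # with --compress: roughly half the MLP parameters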
@@ -264,7 +269,7 @@ class ParallelSelfAttention(MegatronModule):
         if self.attention_softmax_in_fp32:
             attention_probs = self.softmax(attention_scores.float()).half()
         else:
-            attention_probs = self.softmax(attention_scores)
+            attention_probs = self.softmax(attention_scores.half())

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
@@ -485,7 +490,7 @@ class ParallelTopQuerySelfAttention(MegatronModule):
         if self.attention_softmax_in_fp32:
             attention_probs = self.softmax(attention_scores.float()).half()
         else:
-            attention_probs = self.softmax(attention_scores)
+            attention_probs = self.softmax(attention_scores.half())

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
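Both attention variants get the same change: when attention_softmax_in_fp32 is off, the scores are now cast to fp16 explicitly before the softmax instead of being passed through unchanged. A minimal standalone sketch of the two branches in plain PyTorch (not the Megatron modules; fp16 softmax is normally run on GPU tensors, and older PyTorch builds may reject it on CPU):

import torch

softmax = torch.nn.Softmax(dim=-1)
device = "cuda" if torch.cuda.is_available() else "cpu"
attention_scores = torch.randn(2, 4, 8, 8, device=device)  # dummy [batch, heads, seq, seq] scores

attention_softmax_in_fp32 = False
if attention_softmax_in_fp32:
    # upcast for numerical stability, then cast the probabilities back to fp16
    attention_probs = softmax(attention_scores.float()).half()
else:
    # behaviour after this commit: force fp16 before the softmax
    attention_probs = softmax(attention_scores.half())
print(attention_probs.dtype)  # torch.float16 on either branch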
@@ -607,7 +612,8 @@ class ParallelTransformerLayer(MegatronModule):
             self.ln_fp16 = False
         # MLP
         self.mlp = ParallelMLP(init_method,
-                               output_layer_init_method)
+                               output_layer_init_method,
+                               scale=2 if args.compress else 4)

     def forward(
         self,
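End to end, the new flag simply halves the transformer MLP width: --compress selects scale=2 instead of the default 4 when the layer builds its ParallelMLP. A toy reproduction of that wiring with plain argparse (the hidden size is hypothetical, and the real code reads these values through Megatron's get_args()):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--compress", action="store_true")
args = parser.parse_args(["--compress"])  # simulate passing the flag

hidden_size = 4096  # hypothetical, for illustration only
scale = 2 if args.compress else 4
print(scale * hidden_size)  # intermediate MLP width: 8192 here, 16384 without --compress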

@@ -65,12 +65,6 @@ except ImportError:
 from filelock import FileLock
 import pathlib

-try:
-    import bmcook
-    from bmcook import Config
-except ImportError:
-    print("bmcook not imported.")
-    bmcook = None


 def print_datetime(string):
@@ -80,11 +74,6 @@ def print_datetime(string):
     print_rank_0("[" + string + "] datetime: {} ".format(time_str))

-
-def compress_setup(args, model, optimizer):
-    teacher = get_model(args)
-    cook_config = ConfigParser(args.cook_config)
-    CPMAntTrainer.set_compression(cook_config, model, optimizer, teacher=teacher, remove_ckptblock=False, target_linear=Linear)

 def pretrain(
     train_valid_test_dataset_provider,
     model_provider,
