add valid forward

pull/69/head
Stanislas0 2 years ago
parent 843a946f41
commit df2b2a95a0

@ -78,6 +78,7 @@ def pretrain(
train_valid_test_dataset_provider,
model_provider,
forward_step_func,
valid_forward_step_func=None,
extra_args_provider=None,
args_defaults={},
):
@ -176,6 +177,7 @@ def pretrain(
if args.do_train and args.train_iters > 0:
iteration = train(
forward_step_func,
valid_forward_step_func,
model,
optimizer,
lr_scheduler,
@ -189,11 +191,11 @@ def pretrain(
if args.co_evaluation:
for key, value in valid_data_iterator.items():
evaluate_and_print_results(
prefix, forward_step_func, value, model, iteration, False, tag=key
prefix, valid_forward_step_func, value, model, iteration, False, tag=key
)
else:
evaluate_and_print_results(
prefix, forward_step_func, valid_data_iterator, model, iteration, False
prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False
)
if args.save and iteration != 0:
@ -879,6 +881,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
def train(
forward_step_func,
valid_forward_step_func,
model,
optimizer,
lr_scheduler,
@ -976,11 +979,15 @@ def train(
if args.co_evaluation:
for key, value in valid_data_iterator.items():
evaluate_and_print_results(
prefix, forward_step_func, value, model, iteration, False, tag=key
prefix, valid_forward_step_func, value, model, iteration, False, tag=key
)
else:
if args.gold:
evaluate_and_print_results_gold(
prefix, forward_step_func, valid_data_iterator, model, iteration, False
)
evaluate_and_print_results(
prefix, forward_step_func, valid_data_iterator, model, iteration, False
prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False
)
# Checkpointing
@ -1183,16 +1190,6 @@ def evaluate_and_print_results_gold(
total_loss_dict[key].item(),
iteration,
)
# writer.add_scalar(
# f"lm-loss-validation/{display_key} validation vs samples",
# total_loss_dict[key].item(),
# args.consumed_train_samples,
# )
# writer.add_scalar(
# f"lm-loss-validation/{display_key} validation vs tokens",
# total_loss_dict[key].item(),
# args.consumed_train_tokens,
# )
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar(
f"lm-loss-validation/{display_key} validation ppl", ppl, iteration

Loading…
Cancel
Save