You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

51 lines
2.7 KiB
Python

2 years ago
import os
import sys
import torch
import random
import argparse
import numpy as np
def _pop_required(params, key):
    """Remove and return ``params[key]``, raising a descriptive KeyError if absent.

    Using ``pop(key, None)`` instead would let a missing tensor surface later
    as an opaque TypeError from ``torch.cat``.
    """
    try:
        return params.pop(key)
    except KeyError:
        raise KeyError(
            f"checkpoint is missing '{key}' "
            "(already converted, or a different layer layout?)"
        ) from None


def _fuse_layer_qkv(params, layer_idx):
    """Replace one layer's separate query/key/value tensors with fused ones.

    ``layers.{i}.attention.{query,key,value}.{weight,bias}`` are concatenated
    along dim 0 (in query, key, value order) into
    ``layers.{i}.attention.query_key_value.{weight,bias}``.
    """
    prefix = f'layers.{layer_idx}.attention'
    for suffix in ('weight', 'bias'):
        parts = [
            _pop_required(params, f'{prefix}.{proj}.{suffix}')
            for proj in ('query', 'key', 'value')
        ]
        params[f'{prefix}.query_key_value.{suffix}'] = torch.cat(parts, dim=0)


def _fuse_top_query_kv(params):
    """Replace topQueryLayer's separate key/value tensors with fused ones.

    ``topQueryLayer.attention.{key,value}.{weight,bias}`` are concatenated
    along dim 0 (key first) into
    ``topQueryLayer.attention.key_value.{weight,bias}``. Any topQueryLayer
    query tensors are not touched.
    """
    prefix = 'topQueryLayer.attention'
    for suffix in ('weight', 'bias'):
        parts = [
            _pop_required(params, f'{prefix}.{proj}.{suffix}')
            for proj in ('key', 'value')
        ]
        params[f'{prefix}.key_value.{suffix}'] = torch.cat(parts, dim=0)


def main():
    """Convert a checkpoint's separate Q/K/V attention tensors into fused ones.

    Loads the checkpoint at ``--load-path``, rewrites the parameter dict at
    ``sd['module']['language_model']['transformer']`` in place, and saves the
    whole checkpoint to ``--save-path``:

    * for each of the first ``--num-layers`` regular layers, the separate
      query/key/value weight and bias tensors become a single
      ``layers.{i}.attention.query_key_value.{weight,bias}`` pair;
    * the ``topQueryLayer`` key/value tensors become a single
      ``topQueryLayer.attention.key_value.{weight,bias}`` pair.

    Unrecognized command-line arguments are ignored (``parse_known_args``).

    Raises:
        KeyError: if any expected per-projection tensor is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--load-path",
                        type=str,
                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt")
    parser.add_argument("--save-path",
                        type=str,
                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt")
    parser.add_argument("--num-layers",
                        type=int,
                        default=39,
                        help="number of regular transformer layers to fuse "
                             "(topQueryLayer is handled separately)")
    args, _ = parser.parse_known_args()
    if args.num_layers < 0:
        parser.error("--num-layers must be non-negative")

    print("Loading state dict ...")
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True; checkpoints that
    # pickle non-tensor objects may need weights_only=False -- confirm.
    sd = torch.load(args.load_path, map_location="cpu")
    transformer = sd['module']['language_model']['transformer']

    for layer_idx in range(args.num_layers):
        _fuse_layer_qkv(transformer, layer_idx)
    _fuse_top_query_kv(transformer)

    torch.save(sd, args.save_path)
# Script entry point: run the conversion only when executed directly, not on import.
if '__main__' == __name__:
    main()