mirror of https://github.com/THUDM/CodeGeeX.git
Refactor
parent 965abd81b4
commit 76c550d876
"""Merge the per-projection attention weights in a CodeGeeX checkpoint into
fused tensors: query/key/value -> query_key_value for the regular transformer
layers, and key/value -> key_value for the top query layer."""
import argparse

import torch
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--load-path",
                        type=str,
                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt")
    parser.add_argument("--save-path",
                        type=str,
                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt")

    # parse_known_args ignores any extra command-line flags instead of erroring.
    args, _ = parser.parse_known_args()
    state_dict_path = args.load_path
    print("Loading state dict ...")
    sd = torch.load(state_dict_path, map_location="cpu")
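
    # Checkpoint layout (as used throughout): the transformer parameters live
    # under sd['module']['language_model']['transformer'], keyed by
    # Megatron-style names such as 'layers.<i>.attention.query.weight'.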
    transformer = sd['module']['language_model']['transformer']

    # Fuse the separate query/key/value projections of each of the 39 regular
    # transformer layers into a single query_key_value parameter.
    for i in range(39):
        query_weight = transformer.pop(f'layers.{i}.attention.query.weight', None)
        query_bias = transformer.pop(f'layers.{i}.attention.query.bias', None)
        key_weight = transformer.pop(f'layers.{i}.attention.key.weight', None)
        key_bias = transformer.pop(f'layers.{i}.attention.key.bias', None)
        value_weight = transformer.pop(f'layers.{i}.attention.value.weight', None)
        value_bias = transformer.pop(f'layers.{i}.attention.value.bias', None)
        qkv_weight = torch.cat([query_weight, key_weight, value_weight], dim=0)
        qkv_bias = torch.cat([query_bias, key_bias, value_bias])
        transformer[f'layers.{i}.attention.query_key_value.weight'] = qkv_weight
        transformer[f'layers.{i}.attention.query_key_value.bias'] = qkv_bias

    # The top query layer keeps its query projection separate, so only its
    # key/value projections are fused into a single key_value parameter.
    tq_key_weight = transformer.pop('topQueryLayer.attention.key.weight', None)
    tq_key_bias = transformer.pop('topQueryLayer.attention.key.bias', None)
    tq_value_weight = transformer.pop('topQueryLayer.attention.value.weight', None)
    tq_value_bias = transformer.pop('topQueryLayer.attention.value.bias', None)
    tq_kv_weight = torch.cat([tq_key_weight, tq_value_weight], dim=0)
    tq_kv_bias = torch.cat([tq_key_bias, tq_value_bias])
    transformer['topQueryLayer.attention.key_value.weight'] = tq_kv_weight
    transformer['topQueryLayer.attention.key_value.bias'] = tq_kv_bias
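
    # Optional sanity check (a minimal added sketch; layer 0 is an arbitrary
    # example): each fused weight should split back into three equal-shaped
    # projections along dim 0.
    q_chk, k_chk, v_chk = torch.chunk(
        transformer['layers.0.attention.query_key_value.weight'], 3, dim=0)
    assert q_chk.shape == k_chk.shape == v_chk.shape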
    # Write the converted checkpoint back out.
    save_ckpt_path = args.save_path
    torch.save(sd, save_ckpt_path)


if __name__ == '__main__':
    main()
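
A minimal sketch for sanity-checking the converted checkpoint after the script
has run, assuming it was written to the default --save-path above (the 39
regular layers plus one top query layer mirror the loop in the script):

import torch

sd = torch.load("/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt", map_location="cpu")
transformer = sd['module']['language_model']['transformer']

# Every regular layer should now expose only the fused parameter.
for i in range(39):
    assert f'layers.{i}.attention.query_key_value.weight' in transformer
    assert f'layers.{i}.attention.query.weight' not in transformer

# The top query layer should expose the fused key_value parameter.
assert 'topQueryLayer.attention.key_value.weight' in transformer
assert 'topQueryLayer.attention.key.weight' not in transformer
print("Converted checkpoint looks consistent.")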