mirror of https://github.com/THUDM/CodeGeeX.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
1.2 KiB
Python
56 lines
1.2 KiB
Python
2 years ago
|
import argparse
|
||
|
import paddle
|
||
|
import torch
|
||
|
|
||
|
linear_layer = [
|
||
|
"mlp.dense_h_to_4h",
|
||
|
"mlp.dense_4h_to_h",
|
||
|
"attention.query",
|
||
|
"attention.key",
|
||
|
"attention.value",
|
||
|
"attention.dense",
|
||
|
]
|
||
|
|
||
|
|
||
|
def WalkDict(x):
|
||
|
for i in x:
|
||
|
if isinstance(x[i], dict):
|
||
|
WalkDict(x[i])
|
||
|
elif isinstance(x[i], torch.Tensor):
|
||
|
print(f"Converting '{i}' from 'torch.Tensor' to 'numpy.ndarray'.")
|
||
|
npy = x[i].cpu().numpy()
|
||
|
if any([f".{layer}.weight" in i for layer in linear_layer]):
|
||
|
print(f"Transposing linear layer weight '{i}'.")
|
||
|
x[i] = npy.T
|
||
|
else:
|
||
|
x[i] = npy
|
||
|
|
||
|
|
||
|
def parse_opt():
|
||
|
parser = argparse.ArgumentParser()
|
||
|
parser.add_argument(
|
||
|
"--pt",
|
||
|
type=str,
|
||
|
required=True,
|
||
|
help="Path to pt checkpoint."
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--pdparams",
|
||
|
type=str,
|
||
|
required=True,
|
||
|
help="Path to pdparams checkpoint."
|
||
|
)
|
||
|
opt = parser.parse_args()
|
||
|
return opt
|
||
|
|
||
|
|
||
|
def main(opt):
|
||
|
state_dict = torch.load(opt.pt)
|
||
|
WalkDict(state_dict)
|
||
|
paddle.save(state_dict, opt.pdparams)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
opt = parse_opt()
|
||
|
main(opt)
|