mirror of https://github.com/THUDM/CodeGeeX.git
Add `pt` to `pdparams` convert script
parent
82d2c020b0
commit
c2db09b320
@ -0,0 +1,55 @@
|
||||
import argparse
|
||||
import paddle
|
||||
import torch
|
||||
|
||||
linear_layer = [
|
||||
"mlp.dense_h_to_4h",
|
||||
"mlp.dense_4h_to_h",
|
||||
"attention.query",
|
||||
"attention.key",
|
||||
"attention.value",
|
||||
"attention.dense",
|
||||
]
|
||||
|
||||
|
||||
def WalkDict(x):
|
||||
for i in x:
|
||||
if isinstance(x[i], dict):
|
||||
WalkDict(x[i])
|
||||
elif isinstance(x[i], torch.Tensor):
|
||||
print(f"Converting '{i}' from 'torch.Tensor' to 'numpy.ndarray'.")
|
||||
npy = x[i].cpu().numpy()
|
||||
if any([f".{layer}.weight" in i for layer in linear_layer]):
|
||||
print(f"Transposing linear layer weight '{i}'.")
|
||||
x[i] = npy.T
|
||||
else:
|
||||
x[i] = npy
|
||||
|
||||
|
||||
def parse_opt():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pt",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pt checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdparams",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pdparams checkpoint."
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
return opt
|
||||
|
||||
|
||||
def main(opt):
|
||||
state_dict = torch.load(opt.pt)
|
||||
WalkDict(state_dict)
|
||||
paddle.save(state_dict, opt.pdparams)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
opt = parse_opt()
|
||||
main(opt)
|
Loading…
Reference in New Issue