Add `pt` to `pdparams` convert script

pull/30/head
saas1600 2 years ago
parent 82d2c020b0
commit c2db09b320

@ -0,0 +1,55 @@
import argparse
import paddle
import torch
linear_layer = [
"mlp.dense_h_to_4h",
"mlp.dense_4h_to_h",
"attention.query",
"attention.key",
"attention.value",
"attention.dense",
]
def WalkDict(x):
for i in x:
if isinstance(x[i], dict):
WalkDict(x[i])
elif isinstance(x[i], torch.Tensor):
print(f"Converting '{i}' from 'torch.Tensor' to 'numpy.ndarray'.")
npy = x[i].cpu().numpy()
if any([f".{layer}.weight" in i for layer in linear_layer]):
print(f"Transposing linear layer weight '{i}'.")
x[i] = npy.T
else:
x[i] = npy
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument(
"--pt",
type=str,
required=True,
help="Path to pt checkpoint."
)
parser.add_argument(
"--pdparams",
type=str,
required=True,
help="Path to pdparams checkpoint."
)
opt = parser.parse_args()
return opt
def main(opt):
state_dict = torch.load(opt.pt)
WalkDict(state_dict)
paddle.save(state_dict, opt.pdparams)
if __name__ == "__main__":
opt = parse_opt()
main(opt)
Loading…
Cancel
Save