"""Convert a PyTorch ``pt`` checkpoint into a Paddle ``pdparams`` checkpoint.

Usage::

    python pt_to_pdparams.py --pt model.pt --pdparams model.pdparams

Every ``torch.Tensor`` found in the (possibly nested) checkpoint dict is
replaced by a ``numpy.ndarray``; weights of the layers listed in
``linear_layer`` are additionally transposed.
"""
import argparse
import paddle
import torch

# Name fragments of the layers whose ``weight`` gets transposed on conversion.
# NOTE(review): presumably needed because paddle.nn.Linear stores its weight
# as (in_features, out_features), the transpose of torch.nn.Linear -- confirm.
linear_layer = [
    "mlp.dense_h_to_4h",
    "mlp.dense_4h_to_h",
    "attention.query",
    "attention.key",
    "attention.value",
    "attention.dense",
]


def WalkDict(x):
    """Recursively convert every tensor in ``x`` to a numpy array, in place.

    Nested dicts are walked recursively. Each ``torch.Tensor`` value is
    replaced by its numpy equivalent; if its key contains
    ``".<layer>.weight"`` for any ``<layer>`` in ``linear_layer``, the array
    is transposed. Values of any other type are left untouched.

    Args:
        x: Dict to convert. It is mutated; nothing is returned.
    """
    # Hoisted so the marker strings are built once per call, not once per key.
    weight_markers = [f".{layer}.weight" for layer in linear_layer]
    for i in x:
        value = x[i]
        if isinstance(value, dict):
            WalkDict(value)
        elif isinstance(value, torch.Tensor):
            print(f"Converting '{i}' from 'torch.Tensor' to 'numpy.ndarray'.")
            # detach() first: numpy() raises on tensors that require grad
            # (e.g. nn.Parameter objects stored directly in the checkpoint).
            npy = value.detach().cpu().numpy()
            # Only string keys can name a linear layer; the isinstance guard
            # avoids a TypeError from the ``in`` test on e.g. integer keys.
            if isinstance(i, str) and any(m in i for m in weight_markers):
                print(f"Transposing linear layer weight '{i}'.")
                x[i] = npy.T
            else:
                x[i] = npy


def parse_opt():
    """Parse the command line.

    Returns:
        argparse.Namespace with ``pt`` (input checkpoint path) and
        ``pdparams`` (output checkpoint path); both are required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pt",
        type=str,
        required=True,
        help="Path to pt checkpoint."
    )
    parser.add_argument(
        "--pdparams",
        type=str,
        required=True,
        help="Path to pdparams checkpoint."
    )
    opt = parser.parse_args()
    return opt


def main(opt):
    """Load ``opt.pt``, convert its tensors, and save to ``opt.pdparams``.

    Args:
        opt: Namespace with ``pt`` and ``pdparams`` path attributes, as
            returned by ``parse_opt``.
    """
    # SECURITY: torch.load unpickles the file; only run this on checkpoints
    # from a trusted source.
    # map_location="cpu" lets checkpoints saved from GPU tensors load on a
    # machine without CUDA; every tensor is moved to CPU anyway.
    state_dict = torch.load(opt.pt, map_location="cpu")
    WalkDict(state_dict)
    paddle.save(state_dict, opt.pdparams)


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)