diff --git a/codegeex/megatron/convert_ckpt_parallel.py b/codegeex/megatron/convert_ckpt_parallel.py
new file mode 100644
index 0000000..c6adcdc
--- /dev/null
+++ b/codegeex/megatron/convert_ckpt_parallel.py
@@ -0,0 +1,145 @@
+"""Get model parallel partitions."""
+
+import os
+import re
+import random
+import sys
+
+import numpy as np
+import torch
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+                                             os.path.pardir)))
+
+from codegeex.megatron import get_args
+from codegeex.megatron.model import CodeGeeXModel
+from codegeex.megatron.initialize import initialize_megatron
+from codegeex.megatron.checkpointing import ensure_directory_exists
+
+
+def get_change_ckpt_args(parser):
+    """Provide extra arguments required for merging."""
+    group = parser.add_argument_group(title='Mindspore to megatron')
+    group.add_argument(
+        '--load-ckpt-path',
+        type=str,
+        required=True,
+        help='path to load ".pt" checkpoint.',
+    )
+    group.add_argument(
+        '--save-ckpt-path',
+        type=str,
+        required=True,
+        help='dir to save converted checkpoints.',
+    )
+    group.add_argument(
+        '--target-tensor-model-parallel-size',
+        type=int,
+        default=2,
+        help='target tensor model parallel size',
+    )
+
+    return parser
+
+
+def get_element_from_dict_by_path(d, path):
+    """
+    Get element from dictionary by path. If element is not present, recursively add empty dictionaries.
+    Args:
+        d (dict): the dictionary to get the element from
+        path (list): the path to the element which is delimited by "."
+    """
+    path = path.split(".")
+    for k in path:
+        if k not in d:
+            d[k] = {}
+        d = d[k]
+    return d
+
+
+def main():
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
+
+    initialize_megatron(
+        extra_args_provider=get_change_ckpt_args,
+        args_defaults={
+            "tokenizer_type": "GPT2BPETokenizer",
+            "no_load_rng": True,
+            "no_load_optim": True,
+        },
+    )
+
+    args = get_args()
+    print(f"Load ckpt from {args.load_ckpt_path}...")
+    state_dict = torch.load(args.load_ckpt_path, map_location="cpu")
+
+    print(f"Splitting ckpt into {args.target_tensor_model_parallel_size} parts...")
+    output_state_dict = []
+    for i in range(args.target_tensor_model_parallel_size):
+        output_state_dict.append({})
+
+    print("Converting Embedding layers...")
+    word_embeddings = state_dict['module']['language_model']['embedding']['word_embeddings']['weight']
+    position_embeddings = state_dict['module']['language_model']['embedding']['position_embeddings']['weight']
+    out_word_embeddings = torch.chunk(word_embeddings, args.target_tensor_model_parallel_size, dim=0)
+
+    for i in range(args.target_tensor_model_parallel_size):
+        pos_emb_dict = get_element_from_dict_by_path(
+            output_state_dict[i], "module.language_model.embedding.position_embeddings"
+        )
+        pos_emb_dict["weight"] = position_embeddings
+
+        word_emb_dict = get_element_from_dict_by_path(
+            output_state_dict[i], "module.language_model.embedding.word_embeddings"
+        )
+        word_emb_dict["weight"] = out_word_embeddings[i]
+
+    print("Converting QueryEmbedding layers...")
+    query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
+    out_query_embeddings = torch.chunk(query_embeddings, args.target_tensor_model_parallel_size, dim=0)
+
+    for i in range(args.target_tensor_model_parallel_size):
+        query_emb_dict = get_element_from_dict_by_path(
+            output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
+        )
+        query_emb_dict["weight"] = out_query_embeddings[i]
+
print("Converting Transformer layers...") + for layer_name in state_dict['module']['language_model']['transformer'].keys(): + params = state_dict['module']['language_model']['transformer'][layer_name] + if "layernorm" in layer_name: + pass + elif "attention" in layer_name and "weight" in layer_name: + if "dense" in layer_name: + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1) + else: + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0) + elif "weight" in layer_name and "dense" in layer_name: + if "h_to_4h" in layer_name: + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0) + else: + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1) + elif "bias" in layer_name: + if "dense" not in layer_name or "mlp" in layer_name: + if "4h_to_h" in layer_name: + pass + else: + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0) + + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer") + if type(params) is tuple: + params_dict[layer_name] = params[i] + else: + params_dict[layer_name] = params + + os.makedirs(args.save_ckpt_path, exist_ok=True) + for rank in range(args.target_tensor_model_parallel_size): + save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt") + torch.save(output_state_dict[rank], save_ckpt_path) + print(f"Converted checkpoint saved in {save_ckpt_path}.") + + +if __name__ == '__main__': + main() diff --git a/scripts/test_inference_quantized.sh b/scripts/test_inference_quantized.sh index 2f83b24..c672e82 100644 --- a/scripts/test_inference_quantized.sh +++ b/scripts/test_inference_quantized.sh @@ -24,7 +24,7 @@ if [ -z "$PROMPT_FILE" ]; then fi # remove --greedy if using sampling -CMD="python $MAIN_DIR/tests/test_inference_quantized.py \ +CMD="python $MAIN_DIR/tests/test_inference.py \ --prompt-file $PROMPT_FILE \ --tokenizer-path $TOKENIZER_PATH \ --micro-batch-size 1 \ @@ -32,6 +32,7 @@ CMD="python $MAIN_DIR/tests/test_inference_quantized.py \ --temperature 0.2 \ --top-p 0.95 \ --top-k 0 \ + --quantize \ $MODEL_ARGS" echo "$CMD" diff --git a/tests/test_inference_quantized.py b/tests/test_inference_quantized.py deleted file mode 100644 index ad3be1d..0000000 --- a/tests/test_inference_quantized.py +++ /dev/null @@ -1,201 +0,0 @@ - -import os -import copy -import time -import torch -import random -import argparse -import numpy as np - -from codegeex.torch.inference import get_token_stream -from codegeex.torch import CodeGeeXModel -from codegeex.tokenizer import CodeGeeXTokenizer -from codegeex.quantization import quantize - - -def model_provider(args): - """Build the model.""" - - model = CodeGeeXModel( - args.hidden_size, - args.num_layers, - args.num_attention_heads, - args.padded_vocab_size, - args.max_position_embeddings - ) - - return model - - -def add_code_generation_args(parser): - group = parser.add_argument_group(title="code generation") - group.add_argument( - "--num-layers", - type=int, - default=39, - ) - group.add_argument( - "--hidden-size", - type=int, - default=5120, - ) - group.add_argument( - "--num-attention-heads", - type=int, - default=40, - ) - group.add_argument( - "--padded-vocab-size", - type=int, - default=52224, - ) - group.add_argument( - "--max-position-embeddings", - type=int, - default=2048, - ) - group.add_argument( - "--temperature", - type=float, - default=1.0, - help="Sampling temperature.", 
-    )
-    group.add_argument(
-        "--greedy",
-        action="store_true",
-        default=False,
-        help="Use greedy sampling.",
-    )
-    group.add_argument(
-        "--top-p",
-        type=float,
-        default=0.0,
-        help="Top p sampling.",
-    )
-    group.add_argument(
-        "--top-k",
-        type=int,
-        default=0,
-        help="Top k sampling.",
-    )
-    group.add_argument(
-        "--out-seq-length",
-        type=int,
-        default=2048,
-        help="Size of the output generated text.",
-    )
-    group.add_argument(
-        "--prompt-file",
-        type=str,
-        default="./test_prompt.txt",
-    )
-    group.add_argument(
-        "--tokenizer-path",
-        type=str,
-        default="./tokenizer",
-    )
-    group.add_argument(
-        "--load",
-        type=str,
-    )
-    group.add_argument(
-        "--state-dict-path",
-        type=str,
-    )
-    group.add_argument(
-        "--micro-batch-size",
-        type=int,
-        default=1,
-    )
-
-    return parser
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser = add_code_generation_args(parser)
-    args, _ = parser.parse_known_args()
-
-    print("Building CodeGeeX model ...")
-    model = model_provider(args)
-
-    print("Loading tokenizer ...")
-    tokenizer = CodeGeeXTokenizer(
-        tokenizer_path=args.tokenizer_path,
-        mode="codegeex-13b")
-
-    print("Loading state dict ...")
-    state_dict = torch.load(args.load, map_location="cpu")
-    state_dict = state_dict["module"]
-
-    print("Building CodeGeeX model ...")
-    model = model_provider(args)
-    model.load_state_dict(state_dict)
-    model.eval()
-    model.half()
-    model = quantize(model, weight_bit_width=8)
-    model.cuda()
-
-    with open(args.prompt_file, "r") as f:
-        prompt = f.readlines()
-        prompt = "".join(prompt)
-
-    times = {}
-    out_seq_lengths = [args.out_seq_length]
-    micro_batch_size = args.micro_batch_size
-    seq_length = args.max_position_embeddings
-    for out_seq_length in out_seq_lengths:
-        print(f"Generating with out_seq_len {out_seq_length}...")
-
-        times[out_seq_length] = []
-        for prompt in [prompt]:
-            t0 = time.perf_counter()
-            tokens = tokenizer.encode_code(prompt)
-            print(tokens)
-            print("Current prompt:")
-            print(prompt)
-            n_token_prompt = len(tokens)
-            print("N_token_prompt:", n_token_prompt)
-            token_stream = get_token_stream(
-                model,
-                tokenizer,
-                seq_length,
-                out_seq_length,
-                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
-                micro_batch_size=micro_batch_size,
-                topk=args.top_k,
-                topp=args.top_p,
-                temperature=args.temperature,
-                greedy=args.greedy,
-            )
-            is_finished = [False for _ in range(micro_batch_size)]
-            for i, generated in enumerate(token_stream):
-                generated_tokens = generated[0]
-                for j in range(micro_batch_size):
-                    if is_finished[j]:
-                        continue
-                    if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(
-                            generated_tokens[j]) >= out_seq_length:
-                        is_finished[j] = True
-                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
-                        generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
-                        generated_code = "".join(generated_code)
-                        t1 = time.perf_counter()
-                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
-                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
-                        times[out_seq_length].append(t1 - t0)
-                        print("================================= Generated code:")
-                        print(generated_code)
-
-                if all(is_finished):
-                    break
-
-    print(times)
-    for out_seq_length in times.keys():
-        print(out_seq_length, np.mean(times[out_seq_length]))
-
-    print("Generation finished.")
-
-
-if __name__ == "__main__":
-    main()
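For reference, a minimal standalone sketch of the splitting rule that the new convert_ckpt_parallel.py applies: column-parallel weights (QKV projection, dense_h_to_4h) are chunked along dim 0, row-parallel weights (attention output dense, dense_4h_to_h) along dim 1, and layernorm parameters are replicated on every rank. The layer names and tensor shapes below are illustrative placeholders, not the real CodeGeeX-13B dimensions or state-dict keys.

```python
import torch

# Illustrative sizes only; CodeGeeX-13B uses hidden_size=5120, 40 heads, etc.
hidden, ffn, tp = 8, 32, 2

layers = {
    "attention.query_key_value.weight": torch.randn(3 * hidden, hidden),  # column-parallel -> dim 0
    "attention.dense.weight": torch.randn(hidden, hidden),                # row-parallel    -> dim 1
    "mlp.dense_h_to_4h.weight": torch.randn(ffn, hidden),                 # column-parallel -> dim 0
    "mlp.dense_4h_to_h.weight": torch.randn(hidden, ffn),                 # row-parallel    -> dim 1
    "input_layernorm.weight": torch.ones(hidden),                         # replicated
}

ranks = [{} for _ in range(tp)]
for name, param in layers.items():
    if "layernorm" in name:
        chunks = None  # keep a full copy on every rank
    elif "attention" in name and "weight" in name:
        chunks = torch.chunk(param, tp, dim=1 if "dense" in name else 0)
    elif "weight" in name and "dense" in name:
        chunks = torch.chunk(param, tp, dim=0 if "h_to_4h" in name else 1)
    else:
        chunks = None
    for r in range(tp):
        ranks[r][name] = chunks[r] if chunks is not None else param

for r in range(tp):
    print(r, {k: tuple(v.shape) for k, v in ranks[r].items()})
```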
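The deleted standalone test also documents the quantization call order that tests/test_inference.py presumably reproduces behind the new --quantize flag: load weights, switch to eval/half precision, quantize, then move to GPU. A condensed sketch, with the checkpoint path as a placeholder and the model dimensions taken from the removed defaults:

```python
import torch
from codegeex.torch import CodeGeeXModel
from codegeex.quantization import quantize

# Dimensions mirror the removed test's defaults: hidden_size, num_layers,
# num_attention_heads, padded_vocab_size, max_position_embeddings.
model = CodeGeeXModel(5120, 39, 40, 52224, 2048)
state_dict = torch.load("/path/to/codegeex_13b.pt", map_location="cpu")["module"]  # placeholder path
model.load_state_dict(state_dict)
model.eval()
model.half()
model = quantize(model, weight_bit_width=8)  # int8 weight quantization, as in the removed test
model.cuda()
```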