"""A minimal Gradio web demo for CodeGeeX-13B code generation."""

import json
import argparse

import torch
import gradio as gr

import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
from codegeex.data.data_utils import LANGUAGE_TAG
from codegeex.megatron.inference import set_random_seed

def model_provider(args):
    """Build the model."""
    model = CodeGeeXModel(
        args.hidden_size,
        args.num_layers,
        args.num_attention_heads,
        args.padded_vocab_size,
        args.max_position_embeddings,
    )
    return model

def add_code_generation_args(parser):
    group = parser.add_argument_group(title="code generation")
    # The defaults below match the CodeGeeX-13B configuration.
    group.add_argument("--num-layers", type=int, default=39)
    group.add_argument("--hidden-size", type=int, default=5120)
    group.add_argument("--num-attention-heads", type=int, default=40)
    group.add_argument("--padded-vocab-size", type=int, default=52224)
    group.add_argument("--max-position-embeddings", type=int, default=2048)
    group.add_argument("--tokenizer-path", type=str, default="./tokenizer")
    group.add_argument("--example-path", type=str, default="./")
    group.add_argument("--load", type=str)
    group.add_argument("--state-dict-path", type=str)
    group.add_argument("--micro-batch-size", type=int, default=1)
    group.add_argument("--quantize", action="store_true")
    return parser

def main():
    parser = argparse.ArgumentParser()
    parser = add_code_generation_args(parser)
    args, _ = parser.parse_known_args()

    print("Loading tokenizer ...")
    tokenizer = CodeGeeXTokenizer(
        tokenizer_path=args.tokenizer_path,
        mode="codegeex-13b",
    )

    print("Loading state dict ...")
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider(args)
    model.load_state_dict(state_dict)
    model.eval()
    model.half()
    if args.quantize:
        # Optional int8 weight quantization to reduce GPU memory usage.
        model = quantize(model, weight_bit_width=8, backend="torch")
    model.cuda()

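    # Rough weight-memory arithmetic (an estimate, not a figure from the
    # repo): 13e9 params x 2 bytes (fp16) is about 26 GB of GPU memory,
    # versus x 1 byte (int8 via --quantize) at about 13 GB.
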
    def predict(
        prompt,
        lang,
        seed,
        out_seq_length,
        temperature,
        top_k,
        top_p,
    ):
        set_random_seed(seed)
        # Prepend the language tag (e.g. "# language: Python") so the model
        # knows which language to generate.
        if lang.lower() in LANGUAGE_TAG:
            prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt

        generated_code = codegeex.generate(
            model,
            tokenizer,
            prompt,
            out_seq_length=out_seq_length,
            seq_length=args.max_position_embeddings,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            micro_batch_size=args.micro_batch_size,
            backend="megatron",
            verbose=True,
        )
        return prompt + generated_code

    # Load example prompts from the JSONL file at --example-path
    # (one JSON object per line).
    examples = []
    with open(args.example_path, "r") as f:
        for line in f:
            examples.append(list(json.loads(line).values()))
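    # The file read above is expected to look like this (a hypothetical
    # example; each line is one JSON object whose values fill the
    # [prompt, lang] inputs in order):
    #   {"prompt": "# Write a bubble sort function\n", "lang": "Python"}
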
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            <img src="https://raw.githubusercontent.com/THUDM/CodeGeeX/main/resources/logo/codegeex_logo.png">
            """)
        gr.Markdown(
            """
            <p align="center">
            🏠 <a href="https://codegeex.cn" target="_blank">Homepage</a> | 📖 <a href="http://keg.cs.tsinghua.edu.cn/codegeex/" target="_blank">Blog</a> | 🪧 <a href="https://codegeex.cn/playground" target="_blank">DEMO</a> | 🛠 <a href="https://marketplace.visualstudio.com/items?itemName=aminer.codegeex" target="_blank">VS Code</a> or <a href="https://plugins.jetbrains.com/plugin/20587-codegeex" target="_blank">Jetbrains</a> Extensions | 💻 <a href="https://github.com/THUDM/CodeGeeX" target="_blank">Source code</a> | 🤖 <a href="https://models.aminer.cn/codegeex/download/request" target="_blank">Download Model</a>
            </p>
            """)
        gr.Markdown(
            """
            We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. CodeGeeX supports 15+ programming languages for both code generation and translation. CodeGeeX is open source; please refer to our [GitHub](https://github.com/THUDM/CodeGeeX) for more details. This is a minimal demo; for other demos such as code translation, please visit our [Homepage](https://codegeex.cn). We also offer free [VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex) and [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex) extensions with full functionality.
            """)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=13, placeholder='Please enter a description or select an example input below.', label='Input')
                with gr.Row():
                    gen = gr.Button("Generate")
                    clr = gr.Button("Clear")

            outputs = gr.Textbox(lines=15, label='Output')

        gr.Markdown(
            """
            Generation Parameters
            """)
        with gr.Row():
            with gr.Column():
                lang = gr.Radio(
                    choices=["C++", "C", "C#", "Python", "Java", "HTML", "PHP", "JavaScript",
                             "TypeScript", "Go", "Rust", "SQL", "Kotlin", "R", "Fortran"],
                    value="Python", label='Programming Language')
            with gr.Column():
                seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed')
                with gr.Row():
                    out_seq_length = gr.Slider(maximum=1024, value=128, minimum=1, step=1, label='Output Sequence Length')
                    temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature')
                with gr.Row():
                    top_k = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K')
                    top_p = gr.Slider(maximum=1, value=1.0, minimum=0, label='Top P')

        inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p]
        gen.click(fn=predict, inputs=inputs, outputs=outputs)
        # The clear button simply resets the input box.
        clr.click(fn=lambda: gr.update(value=""), inputs=None, outputs=prompt)

        gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang],
                                  label="Example Inputs (Click to insert an example into the input box)",
                                  examples_per_page=20)

    demo.launch(server_port=6007)

if __name__ == "__main__":
    with torch.no_grad():
        main()
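
# A sample local invocation (a sketch: the script, checkpoint, and example
# file names below are placeholders for your own local copies, not files
# shipped with this repo):
#
#   python demo.py \
#       --load ./codegeex_13b.pt \
#       --tokenizer-path ./tokenizer \
#       --example-path ./examples.jsonl \
#       --quantize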