diff --git a/deployment/server_gradio.py b/deployment/server_gradio.py
new file mode 100644
index 0000000..96eb770
--- /dev/null
+++ b/deployment/server_gradio.py
@@ -0,0 +1,198 @@
+import json
+import torch
+import argparse
+import gradio as gr
+
+import codegeex
+from codegeex.torch import CodeGeeXModel
+from codegeex.tokenizer import CodeGeeXTokenizer
+from codegeex.quantization import quantize
+from codegeex.data.data_utils import LANGUAGE_TAG
+from codegeex.megatron.inference import set_random_seed
+
+
+def model_provider(args):
+ """Build the model."""
+
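+    # Model dimensions are taken from the command-line arguments; the
+    # defaults in add_code_generation_args() below match CodeGeeX-13B.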
+    model = CodeGeeXModel(
+        args.hidden_size,
+        args.num_layers,
+        args.num_attention_heads,
+        args.padded_vocab_size,
+        args.max_position_embeddings
+    )
+    return model
+
+
+def add_code_generation_args(parser):
+    group = parser.add_argument_group(title="code generation")
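+    # NOTE: the architecture defaults below correspond to the released
+    # CodeGeeX-13B checkpoint (13B parameters).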
+    group.add_argument(
+        "--num-layers",
+        type=int,
+        default=39,
+    )
+    group.add_argument(
+        "--hidden-size",
+        type=int,
+        default=5120,
+    )
+    group.add_argument(
+        "--num-attention-heads",
+        type=int,
+        default=40,
+    )
+    group.add_argument(
+        "--padded-vocab-size",
+        type=int,
+        default=52224,
+    )
+    group.add_argument(
+        "--max-position-embeddings",
+        type=int,
+        default=2048,
+    )
+    group.add_argument(
+        "--tokenizer-path",
+        type=str,
+        default="./tokenizer",
+    )
+    group.add_argument(
+        "--example-path",
+        type=str,
+        default="./",
+    )
+    group.add_argument(
+        "--load",
+        type=str,
+    )
+    group.add_argument(
+        "--state-dict-path",
+        type=str,
+    )
+    group.add_argument(
+        "--micro-batch-size",
+        type=int,
+        default=1,
+    )
+    group.add_argument(
+        "--quantize",
+        action="store_true",
+    )
+
+    return parser
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser = add_code_generation_args(parser)
+    args, _ = parser.parse_known_args()
+
+ print("Loading tokenizer ...")
+    tokenizer = CodeGeeXTokenizer(
+        tokenizer_path=args.tokenizer_path,
+        mode="codegeex-13b")
+
+ print("Loading state dict ...")
+    state_dict = torch.load(args.load, map_location="cpu")
+    state_dict = state_dict["module"]
+
+ print("Building CodeGeeX model ...")
+ model = model_provider(args)
+ model.load_state_dict(state_dict)
+ model.eval()
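+    # Run inference in FP16; optional 8-bit weight quantization further
+    # reduces GPU memory usage.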
+    model.half()
+    if args.quantize:
+        model = quantize(model, weight_bit_width=8, backend="torch")
+    model.cuda()
+
+    def predict(
+        prompt,
+        lang,
+        seed,
+        out_seq_length,
+        temperature,
+        top_k,
+        top_p,
+    ):
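+        # Fix the RNG seed so sampling is reproducible for a given prompt.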
+        set_random_seed(seed)
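+        # Prepend a language tag (e.g. "# language: Python") so the model
+        # generates code in the selected language.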
+        if lang.lower() in LANGUAGE_TAG:
+            prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt
+
+        generated_code = codegeex.generate(
+            model,
+            tokenizer,
+            prompt,
+            out_seq_length=out_seq_length,
+            seq_length=args.max_position_embeddings,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+            micro_batch_size=args.micro_batch_size,
+            backend="megatron",
+            verbose=True,
+        )
+        return prompt + generated_code
+
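+    # The example file is expected to be JSONL: one JSON object per line,
+    # whose values map onto the (prompt, language) example inputs below.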
+    examples = []
+    with open(args.example_path, "r") as f:
+        for line in f:
+            examples.append(list(json.loads(line).values()))
+
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            """
+
+            """)
+        gr.Markdown(
+            """
+            🏠 Homepage | 📖 Blog | 🪧 DEMO | 🛠 VS Code or Jetbrains Extensions | 💻 Source code | 🤖 Download Model
+            """)
+        gr.Markdown(
+            """
+            We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. CodeGeeX supports 15+ programming languages for both code generation and translation. CodeGeeX is open source; please refer to our [GitHub](https://github.com/THUDM/CodeGeeX) for more details. This is a minimal-functionality DEMO; for other DEMOs like code translation, please visit our [Homepage](https://codegeex.cn). We also offer free [VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex) or [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex) extensions for full functionality.
+            """)
+
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(lines=13, placeholder='Please enter the description or select an example input below.', label='Input')
+                with gr.Row():
+                    gen = gr.Button("Generate")
+                    clr = gr.Button("Clear")
+
+            outputs = gr.Textbox(lines=15, label='Output')
+
+        gr.Markdown(
+            """
+            Generation Parameters
+            """)
+        with gr.Row():
+            with gr.Column():
+                lang = gr.Radio(
+                    choices=["C++", "C", "C#", "Python", "Java", "HTML", "PHP", "JavaScript", "TypeScript", "Go",
+                             "Rust", "SQL", "Kotlin", "R", "Fortran"],
+                    value="Python", label='Programming Language')
+            with gr.Column():
+                seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed')
+                with gr.Row():
+                    out_seq_length = gr.Slider(maximum=1024, value=128, minimum=1, step=1, label='Output Sequence Length')
+                    temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature')
+                with gr.Row():
+                    top_k = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K')
+                    top_p = gr.Slider(maximum=1, value=1.0, minimum=0, label='Top P')
+
+        inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p]
+        gen.click(fn=predict, inputs=inputs, outputs=outputs)
+        clr.click(fn=lambda: gr.update(value=""), inputs=None, outputs=prompt)
+
+        gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang],
+                                  label="Example Inputs (Click to insert an example into the input box)",
+                                  examples_per_page=20)
+
+    demo.launch(server_port=6007)
+
+if __name__ == '__main__':
+    with torch.no_grad():
+        main()
\ No newline at end of file