diff --git a/README.md b/README.md
index 900c0a1..70a959f 100644
--- a/README.md
+++ b/README.md
@@ -1,37 +1,71 @@
- 🏠 Homepage | 📖 Blog | 🪧 DEMO | 🛠 VS Code Extension | 📃 Paper(Coming soon!) | 🌐 中文
+ 🏠 Homepage | 📖 Blog | 🪧 DEMO | 🤖 Download Model | 📃 Paper(Coming soon!) |
+
+
+ 🛠 VS Code Extension | 🌐 中文
+
+
+
+
+
-# CodeGeeX: A Multilingual Code Generative Model
-We introduce CodeGeeX, a large-scale multilingual code generative model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. As of **June 22**, 2022, CodeGeeX has been trained on more than 850 billion tokens on a cluster of 1,536 [Ascend 910 AI Processors](https://e.huawei.com/en/products/servers/ascend). CodeGeeX has several unique features:
+# CodeGeeX: A Multilingual Code Generation Model
+
+We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. As of **June 22**, 2022, CodeGeeX has been trained on more than 850 billion tokens on a cluster of 1,536 [Ascend 910 AI Processors](https://e.huawei.com/en/products/servers/ascend). CodeGeeX has several unique features:
* **Multilingual Code Generation**: CodeGeeX has good performance for generating executable programs in several mainstream programming languages, including Python, C++, Java, JavaScript, Go, etc. [DEMO](https://models.aminer.cn/codegeex)
* **Crosslingual Code Translation**: CodeGeeX supports the translation of code snippets between different languages. Simply by one click, CodeGeeX can transform a program into any expected language with a high accuracy. [DEMO](https://models.aminer.cn/codegeex/codeTranslator)
* **Customizable Programming Assistant**: CodeGeeX is available in the VS Code extension marketplace **for free**. It supports code completion, explanation, summarization and more, which empower users with a better coding experience. [VS Code Extension](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex)
-* **Open-Source and Cross-Platform**: All codes and model weights will be made publicly available for research purposes. We have also been working on the adaptation to other GPU platforms, which will be ready soon.
+* **Open-Source and Cross-Platform**: All codes and model weights are publicly available for research purposes. CodeGeeX supports both Ascend and NVIDIA platforms. It supports inference in a single Ascend 910, NVIDIA V100 or A100. [Apply Model Weights](https://models.aminer.cn/codegeex/download/request)
-**HumanEval-X for Realistic Multilingual Benchmarking.** To help standardize the evaluation of multilingual code generation and translation, we develop and release the **HumanEval-X** Benchmark. HumanEval-X is a new multilingual benchmark that contains **820 human-crafted** coding problems in **5** programming languages (Python, C++, Java, JavaScript, and Go), each of these problems is associated with tests and solutions. [Usage](codegeex/benchmark/README.md)
+**HumanEval-X for Realistic Multilingual Benchmarking.** To help standardize the evaluation of multilingual code generation and translation, we develop and release the **HumanEval-X** Benchmark. HumanEval-X is a new multilingual benchmark that contains **820 human-crafted** coding problems in **5** programming languages (Python, C++, Java, JavaScript, and Go), each of these problems is associated with tests and solutions. [Usage](codegeex/benchmark/README.md) [🤗 Available in HuggingFace](https://huggingface.co/datasets/THUDM/humaneval-x)
CodeGeeX achieves the highest average performance compared with other open-sourced multilingual baselines.
+## News
+
+* **2022-09-30**: We release the cross-platform source code and models weghts for both Ascend and NVIDIA platforms.
+
## Getting Started
+CodeGeeX is initially implemented in Mindspore and trained Ascend 910 AI Processors. We provide a torch-compatible version based on [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) to facilitate usage on GPU platforms.
### Installation
-Download and install ``codegeex`` via:
+Python 3.7+ / CUDA 11+ / PyTorch 1.10+ / DeepSpeed 0.6+ are required. Install ``codegeex`` package via:
```bash
git clone git@github.com:THUDM/CodeGeeX.git
cd CodeGeeX
pip install -e .
```
+### Model Weights
+
+Apply and download model weights through this [link](https://models.aminer.cn/codegeex/download/request). You'll receive by mail ```urls.txt``` that contains temporary download links. We recommend you to use [aria2](https://aria2.github.io/) to download it via the following command (Please make sure you have enough disk space to download the checkpoint (~26GB)):
+```bash
+aria2c -x 16 -s 16 -j 4 --continue=true -i urls.txt
+```
+Run the following command to get the full model weights:
+```bash
+cat codegeex_13b.tar.gz.part.* > codegeex_13b.tar
+tar xvf codegeex_13b.tar.gz
+```
+
### Inference on GPUs
-CodeGeeX is initially implemented in Mindspore and trained on Ascend 910 AI Processors. In addition to the support on the Ascend platform, we have been also working on adapting the model to other GPU platforms, which will be ready soon.
+Have a try on generating the first program with CodeGeeX. First, specify the path of the model weights in ``configs/codegeex_13b.sh``. Second, write the prompt (natural language description or code snippet) into a file, e.g., ``tests/test_prompt.txt``, then run the following script:
+```bash
+bash ./scripts/test_inference.sh ./tests/test_prompt.txt
+```
+
+### VS Code Extension Guidance
+
+Based on CodeGeeX, we also develop a free VS Code extention, search "codegeex" in Marketplace or install it [here](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex). Detailed instructions can be found in
+[CodeGeeX Extension Guidance](vscode-extension/README.md).
## CodeGeeX: Architecture, Code Corpus, and Implementation
@@ -46,7 +80,7 @@ CodeGeeX is initially implemented in Mindspore and trained on Ascend 910 AI Proc
**Training**: We implement CodeGeeX in [Mindspore 1.7](https://www.mindspore.cn/) and train it on 1,536 Ascend 910 AI Processor (32GB). The model weights are under FP16 format, except that we use FP32 for layer-norm and softmax for higher precision and stability. The entire model consumes about 27GB of memory. To increase the training efficiency, we adopt an 8-way model parallel training together with 192-way data parallel training, with ZeRO-2 optimizer enabled. The micro-batch size is 16 and the global batch size reaches 3,072. Moreover, we adopt techniques to further boost the training efficiency including the element-wise operator fusion, fast gelu activation, matrix multiplication dimension optimization, etc. The entire training process takes nearly two months, spanning from April 18 to June 22, 2022, during which 850B tokens were passed for training, i.e., 5+ epochs.
## HumanEval-X: A new benchmark for Multilingual Program Synthesis
-To better evaluate the multilingual ability of code generative models, we propose a new benchmark HumanEval-X. While previous works evaluate multilingual program synthesis under semantic similarity (e.g., [CodeBLEU](https://arxiv.org/abs/2009.10297)) which is often misleading, HumanEval-X evaluates the functional correctness of the generated programs. HumanEval-X consists of 820 high-quality human-crafted data samples (each with test cases) in Python, C++, Java, JavaScript, and Go, and can be used for various tasks.
+To better evaluate the multilingual ability of code generation models, we propose a new benchmark HumanEval-X. While previous works evaluate multilingual program synthesis under semantic similarity (e.g., [CodeBLEU](https://arxiv.org/abs/2009.10297)) which is often misleading, HumanEval-X evaluates the functional correctness of the generated programs. HumanEval-X consists of 820 high-quality human-crafted data samples (each with test cases) in Python, C++, Java, JavaScript, and Go, and can be used for various tasks.
@@ -60,7 +94,7 @@ In HumanEval-X, every sample in each language contains declaration, docstring, a
Left: the detailed pass@k (k=1,10,100) performance on code generation task for five languages in HumanEval-X. Right: the average performance of all languages of each model. CodeGeeX achieves the highest average performance compared with InCoder-6.7B, CodeGen-Multi-6B and CodeGen-Multi-16B.
-We compare CodeGeeX with two other open-sourced code generative models, [InCoder](https://github.com/dpfried/incoder) (from Meta) and [CodeGen](https://github.com/salesforce/CodeGen) (from Salesforce). Specifically, InCoder-6.7B, CodeGen-Multi-6B and CodeGen-Multi-16B are considered. CodeGeeX significantly outperforms models with smaller scales (by 7.5%~16.3%) and is competitive with CodeGen-Multi-16B with a larger scale (average performance 54.76% vs. 54.39%). CodeGeeX achieves the best average performance across languages. We further investigate the effect of distributing sampling budgets to different languages. Using a simple heuristic that distributes budgets weighted by the training data distribution, CodeGeeX achieves a higher pass rate than any single language (indicated by the red dotted circle).
+We compare CodeGeeX with two other open-sourced code generation models, [InCoder](https://github.com/dpfried/incoder) (from Meta) and [CodeGen](https://github.com/salesforce/CodeGen) (from Salesforce). Specifically, InCoder-6.7B, CodeGen-Multi-6B and CodeGen-Multi-16B are considered. CodeGeeX significantly outperforms models with smaller scales (by 7.5%~16.3%) and is competitive with CodeGen-Multi-16B with a larger scale (average performance 54.76% vs. 54.39%). CodeGeeX achieves the best average performance across languages.
### Crosslingual Code Translation
diff --git a/README_zh.md b/README_zh.md
index 17ad6f8..30fb02a 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,16 +1,25 @@
- 🏠 主页 | 📖 博客 | 🪧 示例 | 🛠 VS Code插件 | 📃 论文(即将推出!)| 🌐 English
+ 🏠 主页 | 📖 博客 | 🪧 示例 | 🤖 模型下载 | 📃 论文(即将推出!)
+
+
+ 🛠 VS Code插件 | 📒 API申请 | ⌨️ 加入开发者群 | 🌐 English
+
+
+
+
+
+
# CodeGeeX: 多语言代码生成模型
CodeGeeX是一个具有130亿参数的多编程语言代码生成预训练模型。CodeGeeX采用华为MindSpore框架实现,在鹏城实验室“鹏城云脑II”中的192个节点(共1536个国产[昇腾910 AI处理器](https://e.huawei.com/cn/products/servers/ascend))上训练而成。截至2022年6月22日,CodeGeeX历时两个月在20多种编程语言的代码语料库(>8500亿Token)上预训练得到。CodeGeeX有以下特点:
* **高精度代码生成**:支持生成Python、C++、Java、JavaScript和Go等多种主流编程语言的代码,在HumanEval-X代码生成任务上取得47%~60%求解率,较其他开源基线模型有更佳的平均性能。[代码生成示例](https://models.aminer.cn/codegeex/zh-CN)
* **跨语言代码翻译**:支持代码片段在不同编程语言间进行自动翻译转换,翻译结果正确率高,在HumanEval-X代码翻译任务上超越了其它基线模型。[代码翻译示例](https://models.aminer.cn/codegeex/zh-CN/codeTranslator)
* **自动编程插件**:CodeGeeX插件现已上架VSCode插件市场(完全免费),用户可以通过其强大的少样本生成能力,自定义代码生成风格和能力,更好辅助代码编写。[插件下载](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex)
-* **模型跨平台开源**: 所有代码和模型权重将会开源,用作研究用途。我们正在适配除昇腾外的其它平台,并将在短期内开源。
+* **模型跨平台开源**: 所有代码和模型权重开源开放,用作研究用途。CodeGeeX同时支持昇腾和英伟达平台,可在单张昇腾910或英伟达V100/A100上实现推理。[申请模型权重](https://models.aminer.cn/codegeex/download/request)
**全新多编程语言评测基准HumanEval-X**:HumanEval-X是第一个支持功能正确性评测的多语言、多任务的基准,包含820个人工编写的高质量代码生成题目、测试用例与参考答案,覆盖5种编程语言(Python、C++、Java、JavaScript、Go),支持代码生成与代码翻译能力的评测。[如何使用](codegeex/benchmark/README_zh.md)
@@ -18,20 +27,44 @@ CodeGeeX是一个具有130亿参数的多编程语言代码生成预训练模型
在HumanEval-X代码生成任务上,与其它开源基线模型相比,CodeGeeX取得了最佳的平均性能。
+## 新闻
+
+* **2022-09-30**: 我们开源了跨平台代码和模型权重,同时支持昇腾和英伟达平台。
## 使用指南
+CodeGeeX最初使用Mindspore框架实现,并在昇腾910AI芯片上进行训练。为适配更多平台,我们将其转换到[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)框架,支持Pytorch+GPU环境。
### 安装
-通过以下命令安装 ``codegeex``:
+需要Python 3.7+ / CUDA 11+ / PyTorch 1.10+ / DeepSpeed 0.6+,通过以下命令安装 ``codegeex``:
```bash
git clone git@github.com:THUDM/CodeGeeX.git
cd CodeGeeX
pip install -e .
```
-### 在GPU上进行推理
+### 模型权重
+
+通过[该链接](https://models.aminer.cn/codegeex/download/request)申请权重,您将收到一个包含临时下载链接文件```urls.txt```的邮件。推荐使用[aria2](https://aria2.github.io/)通过以下命令快速下载(请保证有足够的硬盘空间存放权重(~26GB)):
+```bash
+aria2c -x 16 -s 16 -j 4 --continue=true -i urls.txt
+```
+使用以下命令合并得到完整的权重:
+```bash
+cat codegeex_13b.tar.gz.part.* > codegeex_13b.tar
+tar xvf codegeex_13b.tar.gz
+```
+
+### 用GPU进行推理
+
+尝试使用CodeGeeX模型生成第一个程序吧!首先,在配置文件``configs/codegeex_13b.sh``中写明存放权重的路径。其次,将提示(可以是任意描述或代码片段)写入文件``tests/test_prompt.txt``,运行以下脚本即可开始推理(需指定GPU序号):
+```bash
+bash ./scripts/test_inference.sh ./tests/test_prompt.txt
+```
+
+### VS Code插件使用指南
+
+基于CodeGeeX,我们开发了一款免费的VS Code插件,在应用市场搜索“codegeex”或通过[该链接](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex)安装。详细的使用指南在[CodeGeeX插件使用指南](vscode-extension/README_zh.md).
-CodeGeeX最初使用Mindspore框架实现,并在昇腾910AI芯片上进行训练。为了支持更多的平台,我们正在迁移代码和模型,并将在近期内开源。
## CodeGeeX: 多语言代码生成模型
@@ -65,7 +98,7 @@ HumanEval-X中每个语言的样本,包含了声明、描述和解答,它们
左侧: HumanEval-X中五种语言具体的pass@k(k=1,10,100)性能。右侧: 模型在所有语言上的平均性能。CodeGeeX的平均表现优于InCoder-6.7B和CodeGen-Multi-6B/16B。
-我们将CodeGeeX与另外两个开源代码生成模型进行比较,分别为Meta的[InCoder](https://github.com/dpfried/incoder)与Salesforce的[CodeGen](https://github.com/salesforce/CodeGen),选取InCoder-6.7B、CodeGen-Multi-6B 与 CodeGen-Multi-16B。CodeGeeX能获得最佳的平均性能,显著超越了参数量更小的模型(7.5%~16.3%的提升),与参数量更大的模型CodeGen-Multi-16B表现相当(平均性能 54.76% vs. 54.39%)。除此之外,我们还探索了将生成次数分配给不同语言的效果,在按照训练数据比例分配生成次数时,CodeGeeX的正确率比其在任一单语言下的正确率都更高(如红色虚线圈所示)。
+我们将CodeGeeX与另外两个开源代码生成模型进行比较,分别为Meta的[InCoder](https://github.com/dpfried/incoder)与Salesforce的[CodeGen](https://github.com/salesforce/CodeGen),选取InCoder-6.7B、CodeGen-Multi-6B 与 CodeGen-Multi-16B。CodeGeeX能获得最佳的平均性能,显著超越了参数量更小的模型(7.5%~16.3%的提升),与参数量更大的模型CodeGen-Multi-16B表现相当(平均性能 54.76% vs. 54.39%)。
### 跨语言代码翻译
@@ -123,3 +156,5 @@ HumanEval-X中每个语言的样本,包含了声明、描述和解答,它们
[唐杰](http://keg.cs.tsinghua.edu.cn/jietang/)(清华大学知识工程实验室 & 北京智源人工智能研究院)
+
+如果遇到问题或有任何建议,欢迎通过邮件与我们联系[codegeex@aminer.cn](mailto:codegeex@aminer.cn).
\ No newline at end of file
diff --git a/api/README_zh.md b/api/README_zh.md
new file mode 100644
index 0000000..5414654
--- /dev/null
+++ b/api/README_zh.md
@@ -0,0 +1,57 @@
+
+
+# 创建CodeGeeX API
+
+使用[天启 · API开放平台](https://tianqi.aminer.cn/open/)申请CodeGeeX API:
+
+
+
+点击首页中的天启平台体验入口:
+
+点击API应用:
+
+输入任意名称,创建API应用。创建后会得到API Key/Secret,用于调用API:
+
+
+在API信息中,可以查看代码生成/代码翻译的请求地址和使用文档:
+
+
+根据文档中的描述使用API,参考文件``api/generation_example.py``:
+
+```python
+# encoding:utf-8
+
+import requests
+import json
+
+'''
+Code Generation
+'''
+API_KEY = "" # Get from Tianqi console. 从控制台获取
+API_SECRET = "" # Get from Tianqi console. 从控制台获取
+PROMPT = "from typing import List\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n"
+NUMBER = 3
+LANG = "Python"
+request_url = "https://tianqi.aminer.cn/api/v2/"
+api = 'multilingual_code_generate'
+
+# Request is in json format. 指定请求参数格式为json
+headers = {'Content-Type': 'application/json'}
+request_url = request_url + api
+data = {
+ "apikey": API_KEY,
+ "apisecret": API_SECRET,
+ "prompt":PROMPT,
+ "n":NUMBER,
+ "lang":LANG
+}
+
+def main():
+ response = requests.post(request_url, headers=headers, data=json.dumps(data))
+ if response:
+ print(response.json())
+
+if __name__ == '__main__':
+ main()
+```
+
diff --git a/api/generation_example.py b/api/generation_example.py
new file mode 100644
index 0000000..8b8e6b7
--- /dev/null
+++ b/api/generation_example.py
@@ -0,0 +1,34 @@
+# encoding:utf-8
+
+import requests
+import json
+
+'''
+Code Generation
+'''
+API_KEY = "" # Get from Tianqi console. 从控制台获取
+API_SECRET = "" # Get from Tianqi console. 从控制台获取
+PROMPT = "from typing import List\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n"
+NUMBER = 3
+LANG = "Python"
+request_url = "https://tianqi.aminer.cn/api/v2/"
+api = 'multilingual_code_generate'
+
+# Request is in json format. 指定请求参数格式为json
+headers = {'Content-Type': 'application/json'}
+request_url = request_url + api
+data = {
+ "apikey": API_KEY,
+ "apisecret": API_SECRET,
+ "prompt":PROMPT,
+ "n":NUMBER,
+ "lang":LANG
+}
+
+def main():
+ response = requests.post(request_url, headers=headers, data=json.dumps(data))
+ if response:
+ print(response.json())
+
+if __name__ == '__main__':
+ main()
diff --git a/codegeex/benchmark/README.md b/codegeex/benchmark/README.md
index 5c4dc67..ec0391c 100644
--- a/codegeex/benchmark/README.md
+++ b/codegeex/benchmark/README.md
@@ -2,7 +2,7 @@
🌐 中文
-HumanEval-X is a new benchmark for better evaluating the multilingual ability of code generative models. While previous works evaluate multilingual program synthesis under semantic similarity (e.g., [CodeBLEU](https://arxiv.org/abs/2009.10297)) which is often misleading, HumanEval-X evaluates the functional correctness of the generated programs. HumanEval-X consists of 820 high-quality human-crafted data samples (each with test cases) in Python, C++, Java, JavaScript, and Go, and can be used for various tasks.
+HumanEval-X is a new benchmark for better evaluating the multilingual ability of code generation models. While previous works evaluate multilingual program synthesis under semantic similarity (e.g., [CodeBLEU](https://arxiv.org/abs/2009.10297)) which is often misleading, HumanEval-X evaluates the functional correctness of the generated programs. HumanEval-X consists of 820 high-quality human-crafted data samples (each with test cases) in Python, C++, Java, JavaScript, and Go, and can be used for various tasks.
diff --git a/codegeex/benchmark/humaneval-x/budget_distribution/evaluate_budget_distribution.py b/codegeex/benchmark/humaneval-x/budget_distribution/evaluate_budget_distribution.py
deleted file mode 100644
index 25e5ec8..0000000
--- a/codegeex/benchmark/humaneval-x/budget_distribution/evaluate_budget_distribution.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# This file is for evaluating the budget distribution method
-import json
-import numpy as np
-
-w = json.load(open("solve_rate_final.jsonl", 'r'))
-
-
-def build_chart():
- fa = np.ones((201, 201))
-
- for i in range(1, 201):
- for j in range(201):
- fa[j, i] = fa[j, i - 1] * (201 - j - i) / (201 - i)
-
- return fa
-
-
-languages = ['cpp', 'go', 'java', 'python', 'js']
-models = ['codegeex', 'codegen16b', 'codegen6b', 'incoder']
-fa = build_chart()
-
-
-def compute(l, dist):
- budgets = []
- alldists = []
- for i in range(2, 41):
- budgets.append(i * 5)
- alldists.append(distribute(dist, i * 5))
- # print(alldists)
- sums = np.zeros(39)
- sumdists = np.zeros(39)
- sumop = np.zeros((39, 5))
- summax = np.zeros(39)
- for i in range(164):
- currents = np.ones(39)
- currentdists = np.ones(39)
- currentops = np.ones((39, 5))
-
- for w in range(5):
- num = int(l[w][i])
- for j in range(39):
- currents[j] = currents[j] * fa[j + 2, num]
-
- currentdists[j] = currentdists[j] * fa[alldists[j][w], num]
-
- currentops[j, w] = fa[(j + 2) * 5, num]
-
- sums = sums + (1 - currents)
- sumdists = sumdists + (1 - currentdists)
- sumop = sumop + (1 - currentops)
- summax = summax + (1 - np.min(currentops, axis=1))
-
- sumop = np.max(sumop, axis=1)
- return sums / 164, sumdists / 164, sumop / 164, summax / 164
-
-
-def distribute(distribution, budget):
- sum = np.sum(distribution)
- di = np.array(distribution) / sum * budget
- dis = []
- diff = []
- for i in range(len(di)):
- dis.append(int(di[i]))
- diff.append(dis[i] - di[i])
- # overflow assignment
- need = np.sum(dis) - budget
- while need > 0:
- g = np.argmax(diff)
- dis[g] -= 1
- diff[g] -= 1
- need -= 1
- while need < 0:
- g = np.argmin(diff)
- dis[g] += 1
- diff[g] += 1
- need += 1
- return dis
-
-
-names = []
-for i in range(39):
- names.append(str((i + 2) * 5) + " uniform")
-for i in range(39):
- names.append(str((i + 2) * 5) + " weighted")
-for i in range(39):
- names.append(str((i + 2) * 5) + " best")
-for i in range(39):
- names.append(str((i + 2) * 5) + " max")
-
-out = open("solution_output.txt", 'w')
-for model in models:
- if 'codegeex' in model:
- dist = [33, 6, 20, 32, 9]
- if 'codegen' in model:
- dist = [38, 8, 29, 17, 8]
- if 'incoder' in model:
- dist = [12, 4, 5, 45, 34]
- avi_list = {}
- for pp in w:
- if (np.sum(w[pp]) > 1500):
- if model in pp:
- for l in languages:
- if l in pp.replace('javascript', 'js'):
- if l in avi_list:
- avi_list[l].append(pp)
- else:
- avi_list[l] = [pp]
- # print(avi_list)
- maxsums = np.zeros(len(names))
- maxsumscomb = np.zeros((len(names), 5))
- current_marker = [0, 0, 0, 0, 0]
- while current_marker[0] < len(avi_list[languages[0]]):
- aclist = []
- for i in range(5):
- aclist.append(w[avi_list[languages[i]][current_marker[i]]])
- sums, sumdists, sumop, summax = compute(aclist, dist)
- things = np.concatenate((sums, sumdists, sumop, summax))
- for i in range(len(names)):
- if (things[i] > maxsums[i]):
- # print(names[i],things[i],current_marker)
- maxsums[i] = things[i]
- maxsumscomb[i] = current_marker
-
- current_marker[-1] += 1
- p = 4
- while (current_marker[p] >= len(avi_list[languages[p]]) and p > 0):
- current_marker[p] = 0
- current_marker[p - 1] += 1
- p -= 1
-
- print(model)
- print(model, file=out)
- for i in range(len(names)):
- print(names[i], maxsums[i], maxsumscomb[i])
- print(names[i], maxsums[i], file=out)
- # use the best of mix100 for further purposes
- for i in range(5):
- print(languages[i], avi_list[languages[i]][int(maxsumscomb[2, i])])
-out.close()
diff --git a/codegeex/benchmark/humaneval-x/budget_distribution/extract_solverate.py b/codegeex/benchmark/humaneval-x/budget_distribution/extract_solverate.py
deleted file mode 100644
index 6078461..0000000
--- a/codegeex/benchmark/humaneval-x/budget_distribution/extract_solverate.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# This file is for gathering the solve rates from generated files
-import json
-import os
-
-import numpy as np
-
-language = ['cpp', 'java', 'js', 'python', 'go']
-repo = ""
-
-all_reps = os.listdir(repo)
-
-# choose the ones
-all_passes = {}
-assignment = [33, 6, 20, 32, 9]
-assignment = [38, 8, 29, 17, 8]
-assignment = [12, 4, 5, 45, 34]
-for folder in all_reps:
- if not ("." in folder):
- q = os.listdir(repo + '/' + folder)
- for f in q:
- if 'result' in f and not ('example' in f):
- passed = np.zeros(164)
- all_p = 0
- fi = open(repo + '/' + folder + '/' + f, 'r')
- t = fi.readlines()
- for l in t:
- if len(l.strip()) == 0:
- continue
- qq = json.loads(l)
- if qq['passed'] == True:
- id = int(qq['task_id'].split('/')[1])
- passed[id] += 1
- all_p += 1
- all_passes[f] = list(passed)
- print(f, all_p)
-
-json.dump(all_passes, open('solve_rate_final.jsonl', 'w'))
diff --git a/codegeex/megatron/__init__.py b/codegeex/megatron/__init__.py
new file mode 100644
index 0000000..b068896
--- /dev/null
+++ b/codegeex/megatron/__init__.py
@@ -0,0 +1,47 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from .global_vars import get_args
+from .global_vars import get_current_global_batch_size
+from .global_vars import get_num_microbatches
+from .global_vars import update_num_microbatches
+from .global_vars import get_tokenizer
+from .global_vars import get_tensorboard_writer
+from .global_vars import get_adlr_autoresume
+from .global_vars import get_timers
+from .initialize import initialize_megatron
+
+
+def print_rank_0(message):
+ """If distributed is initialized, print only on rank 0."""
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == 0:
+ print(message, flush=True)
+ else:
+ print(message, flush=True)
+
+
+def is_last_rank():
+ return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)
+
+
+def print_rank_last(message):
+ """If distributed is initialized, print only on last rank."""
+ if torch.distributed.is_initialized():
+ if is_last_rank():
+ print(message, flush=True)
+ else:
+ print(message, flush=True)
diff --git a/codegeex/megatron/arguments.py b/codegeex/megatron/arguments.py
new file mode 100644
index 0000000..e63b7be
--- /dev/null
+++ b/codegeex/megatron/arguments.py
@@ -0,0 +1,1528 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron arguments."""
+
+import argparse
+import os
+
+import torch
+import deepspeed
+
+
+def parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False):
+ """Parse all arguments."""
+ parser = argparse.ArgumentParser(
+ description="Megatron-LM Arguments", allow_abbrev=False
+ )
+
+ # Standard arguments.
+ parser = _add_network_size_args(parser)
+ parser = _add_regularization_args(parser)
+ parser = _add_training_args(parser)
+ parser = _add_initialization_args(parser)
+ parser = _add_learning_rate_args(parser)
+ parser = _add_checkpointing_args(parser)
+ parser = _add_mixed_precision_args(parser)
+ parser = _add_distributed_args(parser)
+ parser = _add_validation_args(parser)
+ parser = _add_data_args(parser)
+ parser = _add_autoresume_args(parser)
+ parser = _add_biencoder_args(parser)
+ parser = _add_vit_args(parser)
+ parser = _add_logging_args(parser)
+ parser = _add_zero_args(parser)
+ parser = _add_memoryopt_args(parser)
+ parser = _add_activation_checkpoint_args(parser)
+ parser = _add_inference_args(parser)
+
+ # Custom arguments.
+ if extra_args_provider is not None:
+ parser = extra_args_provider(parser)
+
+ parser = deepspeed.add_config_arguments(parser)
+
+ # Parse.
+ if ignore_unknown_args:
+ args, _ = parser.parse_known_args()
+ else:
+ args = parser.parse_args()
+
+ # helper argument to set deepspeed pipeline parallel or not
+ args.ds_pipeline_enabled = not args.no_pipeline_parallel
+
+ # Distributed args.
+ args.rank = int(os.getenv("RANK", "0"))
+ args.world_size = int(os.getenv("WORLD_SIZE", "1"))
+ # Tensor model parallel size.
+ args.tensor_model_parallel_size = min(
+ args.tensor_model_parallel_size, args.world_size
+ )
+ assert (
+ args.world_size % args.tensor_model_parallel_size == 0
+ ), "world size" " ({}) is not divisible by tensor model parallel size ({})".format(
+ args.world_size, args.tensor_model_parallel_size
+ )
+ # Pipeline model parallel size.
+ args.pipeline_model_parallel_size = min(
+ args.pipeline_model_parallel_size,
+ (args.world_size // args.tensor_model_parallel_size),
+ )
+ # Checks.
+ if args.no_pipeline_parallel:
+ assert (
+ args.pipeline_model_parallel_size == 1
+ ), "pipeline_model_parallel_size must be 1 if pipeline parallel is disabled"
+ model_parallel_size = (
+ args.pipeline_model_parallel_size * args.tensor_model_parallel_size
+ )
+ assert args.world_size % model_parallel_size == 0, (
+ "world size is not"
+ " divisible by tensor parallel size ({}) times pipeline parallel "
+ "size ({})".format(
+ args.world_size,
+ args.tensor_model_parallel_size,
+ args.pipeline_model_parallel_size,
+ )
+ )
+ args.data_parallel_size = args.world_size // model_parallel_size
+ if args.rank == 0:
+ print(
+ "using world size: {}, data-parallel-size: {}, "
+ "tensor-model-parallel size: {}, "
+ "pipeline-model-parallel size: {} ".format(
+ args.world_size,
+ args.data_parallel_size,
+ args.tensor_model_parallel_size,
+ args.pipeline_model_parallel_size,
+ ),
+ flush=True,
+ )
+
+ # Deprecated arguments
+ assert args.batch_size is None, (
+ "--batch-size argument is no longer " "valid, use --micro-batch-size instead"
+ )
+ del args.batch_size
+ assert args.warmup is None, (
+ "--warmup argument is no longer valid, use " "--lr-warmup-fraction instead"
+ )
+ del args.warmup
+ assert args.model_parallel_size is None, (
+ "--model-parallel-size is no "
+ "longer valid, use --tensor-model-parallel-size instead"
+ )
+ del args.model_parallel_size
+
+ # Set input defaults.
+ for key in defaults:
+ # For default to be valid, it should not be provided in the
+ # arguments that are passed to the program. We check this by
+ # ensuring the arg is set to None.
+ if getattr(args, key) is not None:
+ if args.force_default:
+ print(
+ "WARNING: overriding arguments for {key}:{v2} \
+ with default {key}:{v}".format(
+ key=key, v=defaults[key], v2=getattr(args, key)
+ ),
+ flush=True,
+ )
+ setattr(args, key, defaults[key])
+ else:
+ if args.rank == 0:
+ print(
+ "WARNING: overriding default arguments for {key}:{v} \
+ with {key}:{v2}".format(
+ key=key, v=defaults[key], v2=getattr(args, key)
+ ),
+ flush=True,
+ )
+ else:
+ setattr(args, key, defaults[key])
+
+ # Batch size.
+ assert args.micro_batch_size is not None
+ assert args.micro_batch_size > 0
+ if args.global_batch_size is None:
+ args.global_batch_size = args.micro_batch_size * args.data_parallel_size
+ if args.rank == 0:
+ print(
+ "setting global batch size to {}".format(args.global_batch_size),
+ flush=True,
+ )
+ assert args.global_batch_size > 0
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ assert args.pipeline_model_parallel_size > 2, (
+ "pipeline-model-parallel size should be greater than 2 with "
+ "interleaved schedule"
+ )
+ assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, (
+ "number of layers is not divisible by number of layers per virtual "
+ "pipeline stage"
+ )
+ args.virtual_pipeline_model_parallel_size = (
+ args.num_layers // args.pipeline_model_parallel_size
+ ) // args.num_layers_per_virtual_pipeline_stage
+ else:
+ args.virtual_pipeline_model_parallel_size = None
+
+ # Parameters dtype.
+ args.params_dtype = torch.float
+ if args.fp16:
+ assert not args.bf16
+ args.params_dtype = torch.half
+ if args.bf16:
+ assert not args.fp16
+ args.params_dtype = torch.bfloat16
+ # bfloat16 requires gradient accumulation and all-reduce to
+ # be done in fp32.
+ if not args.accumulate_allreduce_grads_in_fp32:
+ args.accumulate_allreduce_grads_in_fp32 = True
+ if args.rank == 0:
+ print(
+ "accumulate and all-reduce gradients in fp32 for "
+ "bfloat16 data type.",
+ flush=True,
+ )
+
+ if args.rank == 0:
+ print("using {} for parameters ...".format(args.params_dtype), flush=True)
+
+ # If we do accumulation and all-reduces in fp32, we need to have
+ # local DDP and we should set the use-contiguous-buffers-in-ddp.
+ if args.accumulate_allreduce_grads_in_fp32:
+ assert args.DDP_impl == "local"
+ args.use_contiguous_buffers_in_ddp = True
+
+ if args.dataloader_type is None:
+ args.dataloader_type = "single"
+
+ # Consumed tokens.
+ args.consumed_train_samples = 0
+ args.consumed_valid_samples = 0
+ args.consumed_train_tokens = 0
+
+ # Iteration-based training.
+ if args.train_iters:
+ # If we use iteration-based training, make sure the
+ # sample-based options are off.
+ assert args.train_samples is None, "expected iteration-based training"
+ assert (
+ args.lr_decay_samples is None
+ ), "expected iteration-based learning rate decay"
+ assert (
+ args.lr_warmup_samples == 0
+ ), "expected iteration-based learning rate warmup"
+ assert (
+ args.rampup_batch_size is None
+ ), "expected no batch-size rampup for iteration-based training"
+ if args.lr_warmup_fraction is not None:
+ assert (
+ args.lr_warmup_iters == 0
+ ), "can only specify one of lr-warmup-fraction and lr-warmup-iters"
+
+ # Sample-based training.
+ if args.train_samples:
+ # If we use sample-based training, make sure the
+ # iteration-based options are off.
+ assert args.train_iters is None, "expected sample-based training"
+ assert args.lr_decay_iters is None, "expected sample-based learning rate decay"
+ assert args.lr_warmup_iters == 0, "expected sample-based learnig rate warmup"
+ if args.lr_warmup_fraction is not None:
+ assert args.lr_warmup_samples == 0, (
+ "can only specify one of lr-warmup-fraction " "and lr-warmup-samples"
+ )
+
+ # Check required arguments.
+ required_args = [
+ "num_layers",
+ "hidden_size",
+ "num_attention_heads",
+ "max_position_embeddings",
+ ]
+ for req_arg in required_args:
+ _check_arg_is_not_none(args, req_arg)
+
+ # args.learned_position_embeddings = args.learned_position_embeddings > 0
+
+ # Checks.
+ if args.ffn_hidden_size is None:
+ args.ffn_hidden_size = 4 * args.hidden_size
+
+ if args.kv_channels is None:
+ assert args.hidden_size % args.num_attention_heads == 0
+ args.kv_channels = args.hidden_size // args.num_attention_heads
+
+ if args.seq_length is not None:
+ assert args.encoder_seq_length is None
+ args.encoder_seq_length = args.seq_length
+ else:
+ assert args.encoder_seq_length is not None
+ args.seq_length = args.encoder_seq_length
+
+ if args.seq_length is not None:
+ assert args.max_position_embeddings >= args.seq_length
+ if args.decoder_seq_length is not None:
+ assert args.max_position_embeddings >= args.decoder_seq_length
+ if args.lr is not None:
+ assert args.min_lr <= args.lr
+ if args.save is not None:
+ assert args.save_interval is not None
+ # Mixed precision checks.
+ if args.fp16_lm_cross_entropy:
+ assert args.fp16, "lm cross entropy in fp16 only support in fp16 mode."
+ if args.fp32_residual_connection:
+ assert (
+ args.fp16 or args.bf16
+ ), "residual connection in fp32 only supported when using fp16 or bf16."
+ # Activation checkpointing.
+ if args.distribute_checkpointed_activations:
+ assert args.checkpoint_activations, (
+ "for distribute-checkpointed-activations to work you "
+ "need to enable checkpoint-activations"
+ )
+
+ _print_args(args)
+ return args
+
+
+def _print_args(args):
+ """Print arguments."""
+ if args.rank == 0:
+ print("------------------------ arguments ------------------------", flush=True)
+ str_list = []
+ for arg in vars(args):
+ dots = "." * (48 - len(arg))
+ str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg)))
+ for arg in sorted(str_list, key=lambda x: x.lower()):
+ print(arg, flush=True)
+ print("-------------------- end of arguments ---------------------", flush=True)
+
+
+def _check_arg_is_not_none(args, arg):
+ assert getattr(args, arg) is not None, "{} argument is None".format(arg)
+
+
+def _add_network_size_args(parser):
+ group = parser.add_argument_group(title="network size")
+
+ group.add_argument(
+ "--num-layers",
+ type=int,
+ default=None,
+ help="Number of transformer layers.",
+ )
+ group.add_argument(
+ "--hidden-size",
+ type=int,
+ default=None,
+ help="Transformer hidden size.",
+ )
+ group.add_argument(
+ "--reward-growth",
+ type=str,
+ default="constant",
+ choices=["constant", "linear", "quadratic"],
+ help="Reward growth function.",
+ )
+ group.add_argument(
+ "--ffn-hidden-size",
+ type=int,
+ default=None,
+ help="Transformer Feed-Forward Network hidden size. "
+ "This is set to 4*hidden-size if not provided",
+ )
+ group.add_argument(
+ "--num-attention-heads",
+ type=int,
+ default=None,
+ help="Number of transformer attention heads.",
+ )
+ group.add_argument(
+ "--kv-channels",
+ type=int,
+ default=None,
+ help="Projection weights dimension in multi-head "
+ "attention. This is set to "
+ " args.hidden_size // args.num_attention_heads "
+ "if not provided.",
+ )
+ group.add_argument(
+ "--scale-embeddings",
+ action="store_true",
+ help="Scale embeddings by sqrt(d_model).",
+ )
+ group.add_argument(
+ "--max-position-embeddings",
+ type=int,
+ default=None,
+ help="Maximum number of position embeddings to use. "
+ "This is the size of position embedding.",
+ )
+ group.add_argument(
+ "--no-learned-position-embeddings",
+ action="store_true",
+ help="Do not learn position embeddings. ",
+ )
+ group.add_argument(
+ "--make-vocab-size-divisible-by",
+ type=int,
+ default=128,
+ help="Pad the vocab size to be divisible by this value."
+ "This is added for computational efficieny reasons.",
+ )
+ group.add_argument(
+ "--layernorm-epsilon", type=float, default=1e-5, help="Layer norm epsilon."
+ )
+ group.add_argument(
+ "--apply-residual-connection-post-layernorm",
+ action="store_true",
+ help="If set, use original BERT residula connection " "ordering.",
+ )
+ group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
+ action='store_true',
+ help='Enable fusion of query_key_value_scaling '
+ 'time (upper diagonal) masking, softmax.')
+ group.add_argument(
+ "--openai-gelu",
+ action="store_true",
+ help="Use OpenAIs GeLU implementation. This option"
+ "should not be used unless for backward compatibility"
+ "reasons.",
+ )
+ group.add_argument(
+ "--onnx-safe",
+ type=bool,
+ required=False,
+ help="Use workarounds for known problems with " "Torch ONNX exporter",
+ )
+ group.add_argument(
+ "--bert-no-binary-head",
+ action="store_false",
+ help="Disable BERT binary head.",
+ dest="bert_binary_head",
+ )
+
+ return parser
+
+
+def _add_logging_args(parser):
+ group = parser.add_argument_group(title="logging")
+
+ group.add_argument(
+ "--log-params-norm",
+ action="store_true",
+ help="If set, calculate and log parameters norm.",
+ )
+ group.add_argument(
+ "--log-num-zeros-in-grad",
+ action="store_true",
+ help="If set, calculate and log the number of zeros in gradient.",
+ )
+ group.add_argument(
+ "--tensorboard-log-interval",
+ type=int,
+ default=1,
+ help="Report to tensorboard interval.",
+ )
+ group.add_argument(
+ "--tensorboard-queue-size",
+ type=int,
+ default=1000,
+ help="Size of the tensorboard queue for pending events "
+ "and summaries before one of the ‘add’ calls forces a "
+ "flush to disk.",
+ )
+ group.add_argument(
+ "--log-timers-to-tensorboard",
+ action="store_true",
+ help="If set, write timers to tensorboard.",
+ )
+ group.add_argument(
+ "--log-batch-size-to-tensorboard",
+ action="store_true",
+ help="If set, write batch-size to tensorboard.",
+ )
+ group.add_argument(
+ "--no-log-learnig-rate-to-tensorboard",
+ action="store_false",
+ help="Disable learning rate logging to tensorboard.",
+ dest="log_learning_rate_to_tensorboard",
+ )
+ group.add_argument(
+ "--no-log-loss-scale-to-tensorboard",
+ action="store_false",
+ help="Disable loss-scale logging to tensorboard.",
+ dest="log_loss_scale_to_tensorboard",
+ )
+ group.add_argument(
+ "--log-validation-ppl-to-tensorboard",
+ action="store_true",
+ help="If set, write validation perplexity to " "tensorboard.",
+ )
+ group.add_argument(
+ "--wandb-logging",
+ action="store_true",
+ help="If set, log training progress to wandb.",
+ )
+ group.add_argument(
+ "--wandb-log-interval",
+ type=int,
+ default=1,
+ help="Log to wandb every N steps.",
+ )
+
+ return parser
+
+
+def _add_regularization_args(parser):
+ group = parser.add_argument_group(title="regularization")
+
+ group.add_argument(
+ "--attention-dropout",
+ type=float,
+ default=0.1,
+ help="Post attention dropout probability.",
+ )
+ group.add_argument(
+ "--hidden-dropout",
+ type=float,
+ default=0.1,
+ help="Dropout probability for hidden state transformer.",
+ )
+ group.add_argument(
+ "--weight-decay",
+ type=float,
+ default=0.01,
+ help="Weight decay coefficient for L2 regularization.",
+ )
+ group.add_argument(
+ "--tempering",
+ type=float,
+ default=None,
+ help="Tempering coefficient for the model.",
+ )
+ group.add_argument(
+ "--gold",
+ action="store_true",
+ help="If set, use gold regularization.",
+ )
+ group.add_argument(
+ "--gold-beta",
+ type=float,
+ default=0.05,
+ help="Beta for GOLD tempering.",
+ )
+ group.add_argument(
+ "--play-tau",
+ type=float,
+ default=2.0
+ )
+ group.add_argument(
+ "--clip-grad",
+ type=float,
+ default=1.0,
+ help="Gradient clipping based on global L2 norm.",
+ )
+ group.add_argument(
+ "--adam-beta1",
+ type=float,
+ default=0.9,
+ help="First coefficient for computing running averages "
+ "of gradient and its square",
+ )
+ group.add_argument(
+ "--adam-beta2",
+ type=float,
+ default=0.999,
+ help="Second coefficient for computing running averages "
+ "of gradient and its square",
+ )
+ group.add_argument(
+ "--adam-eps",
+ type=float,
+ default=1e-08,
+ help="Term added to the denominator to improve" "numerical stability",
+ )
+ group.add_argument(
+ "--sgd-momentum", type=float, default=0.9, help="Momentum factor for sgd"
+ )
+
+ return parser
+
+
+def _add_training_args(parser):
+ group = parser.add_argument_group(title="training")
+
+ group.add_argument(
+ "--micro-batch-size",
+ type=int,
+ default=None,
+ help="Batch size per model instance (local batch size). "
+ "Global batch size is local batch size times data "
+ "parallel size times number of micro batches.",
+ )
+ group.add_argument(
+ "--batch-size",
+ type=int,
+ default=None,
+ help="Old batch size parameter, do not use. " "Use --micro-batch-size instead",
+ )
+ group.add_argument(
+ "--global-batch-size",
+ type=int,
+ default=None,
+ help="Training batch size. If set, it should be a "
+ "multiple of micro-batch-size times data-parallel-size. "
+ "If this value is None, then "
+ "use micro-batch-size * data-parallel-size as the "
+ "global batch size. This choice will result in 1 for "
+ "number of micro-batches.",
+ )
+ group.add_argument(
+ "--rampup-batch-size",
+ nargs="*",
+ default=None,
+ help="Batch size ramp up with the following values:"
+ " --rampup-batch-size "
+ " "
+ " "
+ "For example:"
+ " --rampup-batch-size 16 8 300000 \ "
+ " --global-batch-size 1024"
+ "will start with global batch size 16 and over "
+ " (1024 - 16) / 8 = 126 intervals will increase"
+ "the batch size linearly to 1024. In each interval"
+ "we will use approximately 300000 / 126 = 2380 samples.",
+ )
+ group.add_argument(
+ "--checkpoint-activations",
+ action="store_true",
+ help="Checkpoint activation to allow for training "
+ "with larger models, sequences, and batch sizes.",
+ )
+ group.add_argument(
+ "--distribute-checkpointed-activations",
+ action="store_true",
+ help="If set, distribute checkpointed activations "
+ "across model parallel group.",
+ )
+ group.add_argument(
+ "--checkpoint-num-layers",
+ type=int,
+ default=1,
+ help="chunk size (number of layers) for checkpointing.",
+ )
+ group.add_argument(
+ "--train-iters",
+ type=int,
+ default=None,
+ help="Total number of iterations to train over all "
+ "training runs. Note that either train-iters or "
+ "train-samples should be provided.",
+ )
+ group.add_argument(
+ "--train-samples",
+ type=int,
+ default=None,
+ help="Total number of samples to train over all "
+ "training runs. Note that either train-iters or "
+ "train-samples should be provided.",
+ )
+ group.add_argument(
+ "--train-tokens",
+ type=int,
+ default=None,
+ help="Total number of tokens to train over all " "training runs.",
+ )
+ group.add_argument(
+ "--log-interval", type=int, default=100, help="Report loss and timing interval."
+ )
+ group.add_argument(
+ "--exit-interval",
+ type=int,
+ default=None,
+ help="Exit the program after the iteration is divisible " "by this value.",
+ )
+ group.add_argument(
+ "--exit-duration-in-mins",
+ type=int,
+ default=None,
+ help="Exit the program after this many minutes.",
+ )
+ group.add_argument(
+ "--tensorboard-dir",
+ type=str,
+ default=None,
+ help="Write TensorBoard logs to this directory.",
+ )
+ group.add_argument(
+ "--no-masked-softmax-fusion",
+ action="store_false",
+ help="Disable fusion of query_key_value scaling, " "masking, and softmax.",
+ dest="masked_softmax_fusion",
+ )
+ group.add_argument(
+ "--no-bias-gelu-fusion",
+ action="store_false",
+ help="Disable bias and gelu fusion.",
+ dest="bias_gelu_fusion",
+ )
+ group.add_argument(
+ "--no-bias-dropout-fusion",
+ action="store_false",
+ help="Disable bias and dropout fusion.",
+ dest="bias_dropout_fusion",
+ )
+ group.add_argument(
+ "--optimizer",
+ type=str,
+ default="adam",
+ choices=["adam", "sgd"],
+ help="Optimizer function",
+ )
+ group.add_argument(
+ "--dataloader-type",
+ type=str,
+ default=None,
+ choices=["single", "cyclic"],
+ help="Single pass vs multiple pass data loader",
+ )
+ group.add_argument(
+ "--cpu-optimizer", action="store_true", help="Run optimizer on CPU"
+ )
+ group.add_argument(
+ "--cpu_torch_adam",
+ action="store_true",
+ help="Use Torch Adam as optimizer on CPU.",
+ )
+ group.add_argument(
+ "--no-pipeline-parallel",
+ action="store_true",
+ help="Disable pipeline parallelism",
+ )
+ group.add_argument(
+ "--ms-model",
+ action="store_true",
+ help="use model converted from Mindspore",
+ )
+
+ return parser
+
+
+def _add_initialization_args(parser):
+ group = parser.add_argument_group(title="initialization")
+
+ group.add_argument(
+ "--seed",
+ type=int,
+ default=1234,
+ help="Random seed used for python, numpy, " "pytorch, and cuda.",
+ )
+ group.add_argument(
+ "--init-method-std",
+ type=float,
+ default=0.02,
+ help="Standard deviation of the zero mean normal "
+ "distribution used for weight initialization.",
+ )
+ group.add_argument(
+ "--init-method-xavier-uniform",
+ action="store_true",
+ help="Enable Xavier uniform parameter initialization",
+ )
+
+ return parser
+
+
+def _add_inference_args(parser):
+ group = parser.add_argument_group(title="initialization")
+
+ group.add_argument(
+ '--beam-warmup',
+ action="store_true",
+ )
+ group.add_argument(
+ '--beam-warmup-length',
+ type=int,
+ default=0,
+ )
+ group.add_argument(
+ '--beam-search',
+ action="store_true",
+ )
+ group.add_argument(
+ '--beam-search-nucleus',
+ action="store_true",
+ )
+ group.add_argument(
+ '--num-beams',
+ type=int,
+ default=4,
+ )
+
+ return parser
+
+
+def _add_learning_rate_args(parser):
+ group = parser.add_argument_group(title="learning rate")
+
+ group.add_argument(
+ "--lr",
+ type=float,
+ default=None,
+ help="Initial learning rate. Depending on decay style "
+ "and initial warmup, the learing rate at each "
+ "iteration would be different.",
+ )
+ group.add_argument(
+ "--lr-decay-style",
+ type=str,
+ default="linear",
+ choices=["constant", "linear", "cosine"],
+ help="Learning rate decay function.",
+ )
+ group.add_argument(
+ "--lr-decay-iters",
+ type=int,
+ default=None,
+ help="number of iterations to decay learning rate over,"
+ " If None defaults to `--train-iters`",
+ )
+ group.add_argument(
+ "--lr-decay-samples",
+ type=int,
+ default=None,
+ help="number of samples to decay learning rate over,"
+ " If None defaults to `--train-samples`",
+ )
+ group.add_argument(
+ "--lr-decay-tokens",
+ type=int,
+ default=None,
+ help="number of tokens to decay learning rate over,"
+ " If not None will override iter/sample-based decay",
+ )
+ group.add_argument(
+ "--lr-warmup-fraction",
+ type=float,
+ default=None,
+ help="fraction of lr-warmup-(iters/samples) to use " "for warmup (as a float)",
+ )
+ group.add_argument(
+ "--lr-warmup-iters",
+ type=int,
+ default=0,
+ help="number of iterations to linearly warmup " "learning rate over.",
+ )
+ group.add_argument(
+ "--lr-warmup-samples",
+ type=int,
+ default=0,
+ help="number of samples to linearly warmup " "learning rate over.",
+ )
+ group.add_argument(
+ "--warmup",
+ type=int,
+ default=None,
+ help="Old lr warmup argument, do not use. Use one of the"
+ "--lr-warmup-* arguments above",
+ )
+ group.add_argument(
+ "--min-lr",
+ type=float,
+ default=0.0,
+ help="Minumum value for learning rate. The scheduler"
+ "clip values below this threshold.",
+ )
+ group.add_argument(
+ "--override-lr-scheduler",
+ action="store_true",
+ help="Reset the values of the scheduler (learning rate,"
+ "warmup iterations, minimum learning rate, maximum "
+ "number of iterations, and decay style from input "
+ "arguments and ignore values from checkpoints. Note"
+ "that all the above values will be reset.",
+ )
+ group.add_argument(
+ "--use-checkpoint-lr-scheduler",
+ action="store_true",
+ help="Use checkpoint to set the values of the scheduler "
+ "(learning rate, warmup iterations, minimum learning "
+ "rate, maximum number of iterations, and decay style "
+ "from checkpoint and ignore input arguments.",
+ )
+
+ return parser
+
+
+def _add_checkpointing_args(parser):
+ group = parser.add_argument_group(title="checkpointing")
+
+ group.add_argument(
+ "--save",
+ type=str,
+ default=None,
+ help="Output directory to save checkpoints to.",
+ )
+ group.add_argument(
+ "--save-interval",
+ type=int,
+ default=None,
+ help="Number of iterations between checkpoint saves.",
+ )
+ group.add_argument(
+ "--no-save-optim",
+ action="store_true",
+ default=None,
+ help="Do not save current optimizer.",
+ )
+ group.add_argument(
+ "--no-save-rng",
+ action="store_true",
+ default=None,
+ help="Do not save current rng state.",
+ )
+ group.add_argument(
+ "--load",
+ type=str,
+ default=None,
+ help="Directory containing a model checkpoint.",
+ )
+ group.add_argument(
+ "--low-memory-load",
+ action="store_true",
+ default=None,
+ help="Load model checkpoint in low memory mode."
+ "On each machine, workers load the checkpoint one at a time."
+ )
+ group.add_argument(
+ "--dist-timeout",
+ type=int,
+ default=30,
+ help="Timeout for Pytorch Distributed backend (in minutes).",
+ )
+ group.add_argument(
+ "--load-state",
+ type=str,
+ default=None,
+ help="Start training from a existing model state.",
+ )
+ group.add_argument(
+ "--no-load-optim",
+ action="store_true",
+ default=None,
+ help="Do not load optimizer when loading checkpoint.",
+ )
+ group.add_argument(
+ "--no-load-rng",
+ action="store_true",
+ default=None,
+ help="Do not load rng state when loading checkpoint.",
+ )
+ group.add_argument(
+ "--finetune",
+ action="store_true",
+ help="Load model for finetuning. Do not load optimizer "
+ "or rng state from checkpoint and set iteration to 0. "
+ "Assumed when loading a release checkpoint.",
+ )
+
+ return parser
+
+
+def _add_mixed_precision_args(parser):
+ group = parser.add_argument_group(title="mixed precision")
+
+ group.add_argument("--fp16", action="store_true", help="Run model in fp16 mode.")
+ group.add_argument("--ln-fp16", action="store_true", help="Run layernorm in fp16 mode.")
+ group.add_argument(
+ "--bf16", action="store_true", help="Run model in bfloat16 mode."
+ )
+ group.add_argument(
+ "--loss-scale",
+ type=float,
+ default=None,
+ help="Static loss scaling, positive power of 2 "
+ "values can improve fp16 convergence. If None, dynamic"
+ "loss scaling is used.",
+ )
+ group.add_argument(
+ "--initial-loss-scale",
+ type=float,
+ default=2 ** 32,
+ help="Initial loss-scale for dynamic loss scaling.",
+ )
+ group.add_argument(
+ "--min-loss-scale",
+ type=float,
+ default=1.0,
+ help="Minimum loss scale for dynamic loss scale.",
+ )
+ group.add_argument(
+ "--loss-scale-window",
+ type=float,
+ default=1000,
+ help="Window over which to raise/lower dynamic scale.",
+ )
+ group.add_argument(
+ "--hysteresis", type=int, default=2, help="hysteresis for dynamic loss scaling"
+ )
+ group.add_argument(
+ "--fp32-residual-connection",
+ action="store_true",
+ help="Move residual connections to fp32.",
+ )
+ group.add_argument('--apply-query-key-layer-scaling', action='store_true',
+ help='Scale Q * K^T by 1 / layer-number. If this flag '
+ 'is set, then it will automatically set '
+ 'attention-softmax-in-fp32 to true')
+ group.add_argument(
+ "--attention-softmax-in-fp32",
+ action="store_true",
+ help="Run attention masking and softmax in fp32. "
+ "This flag is ignored unless "
+ "--no-query-key-layer-scaling is specified.",
+ )
+ group.add_argument(
+ "--accumulate-allreduce-grads-in-fp32",
+ action="store_true",
+ help="Gradient accumulation and all-reduce in fp32.",
+ )
+ group.add_argument(
+ "--fp16-lm-cross-entropy",
+ action="store_true",
+ help="Move the cross entropy unreduced loss calculation" "for lm head to fp16.",
+ )
+
+ return parser
+
+
+def _add_distributed_args(parser):
+ group = parser.add_argument_group(title="distributed")
+
+ group.add_argument(
+ "--tensor-model-parallel-size",
+ type=int,
+ default=1,
+ help="Degree of tensor model parallelism.",
+ )
+ group.add_argument(
+ "--pipeline-model-parallel-size",
+ type=int,
+ default=1,
+ help="Degree of pipeline model parallelism.",
+ )
+ group.add_argument(
+ "--model-parallel-size",
+ type=int,
+ default=None,
+ help="Old model parallel argument, do not use. Use "
+ "--tensor-model-parallel-size instead.",
+ )
+ group.add_argument(
+ "--num-layers-per-virtual-pipeline-stage",
+ type=int,
+ default=None,
+ help="Number of layers per virtual pipeline stage",
+ )
+ group.add_argument(
+ "--distributed-backend",
+ default="nccl",
+ choices=["nccl", "gloo"],
+ help="Which backend to use for distributed training.",
+ )
+ group.add_argument(
+ "--DDP-impl",
+ default="local",
+ choices=["local", "torch"],
+ help="which DistributedDataParallel implementation " "to use.",
+ )
+ group.add_argument(
+ "--use-contiguous-buffers-in-ddp",
+ action="store_true",
+ help="If set, use contiguous buffer in DDP. Note that "
+ "this option only works woth local DDP.",
+ )
+ group.add_argument(
+ "--no-scatter-gather-tensors-in-pipeline",
+ action="store_false",
+ help="Use scatter/gather to optimize communication of tensors in pipeline",
+ dest="scatter_gather_tensors_in_pipeline",
+ )
+ group.add_argument(
+ "--local_rank",
+ type=int,
+ default=None,
+ help="local rank passed from distributed launcher.",
+ )
+ group.add_argument(
+ "--lazy-mpu-init",
+ type=bool,
+ required=False,
+ help="If set to True, initialize_megatron() "
+ "skips DDP initialization and returns function to "
+ "complete it instead.Also turns on "
+ "--use-cpu-initialization flag. This is for "
+ "external DDP manager.",
+ )
+ group.add_argument(
+ "--use-cpu-initialization",
+ action="store_true",
+ default=None,
+ help="If set, affine parallel weights " "initialization uses CPU",
+ )
+ group.add_argument(
+ "--force-device",
+ type=int,
+ default=None,
+ help="Force the model to run on a particular gpu",
+ )
+ group.add_argument(
+ "--force-default",
+ action="store_true",
+ help="Force setting default arguments for distributed training",
+ )
+ return parser
+
+
+def _add_validation_args(parser):
+ group = parser.add_argument_group(title="validation")
+
+ group.add_argument(
+ "--eval-iters",
+ type=int,
+ default=100,
+ help="Number of iterations to run for evaluation" "validation/test for.",
+ )
+ group.add_argument(
+ "--eval-interval",
+ type=int,
+ default=1000,
+ help="Interval between running evaluation on " "validation set.",
+ )
+ group.add_argument(
+ "--co-evaluation",
+ action="store_true",
+ help="If set, run evaluation on each part of the validation set"
+ )
+
+ return parser
+
+
+def _add_data_args(parser):
+ group = parser.add_argument_group(title="data and dataloader")
+
+ group.add_argument(
+ "--data-path",
+ nargs="*",
+ default=None,
+ help="Path to the training dataset. Accepted format:"
+ "1) a single data path, 2) multiple datasets in the"
+ "form: dataset1-weight dataset1-path dataset2-weight "
+ "dataset2-path ...",
+ )
+ group.add_argument(
+ "--valid-data-path",
+ nargs="*",
+ default=None,
+ help="Path to the validation dataset. Accepted format:"
+ "1) a single data path, 2) multiple datasets in the"
+ "form: dataset1-weight dataset1-path dataset2-weight "
+ "dataset2-path ...;"
+ "when co-evaluation is enabled, the form will be dataset1-tag dataset1-path ...",
+ )
+ group.add_argument("--index-cache-dir", type=str, default=None, help="Path to the index cache")
+ group.add_argument(
+ "--test-data-path",
+ nargs="*",
+ default=None,
+ help="Path to the test dataset. Accepted format:"
+ "1) a single data path, 2) multiple datasets in the"
+ "form: dataset1-tag dataset1-path dataset2-tag "
+ "dataset2-path ...",
+ )
+ group.add_argument(
+ "--split",
+ type=str,
+ default="969, 30, 1",
+ help="Comma-separated list of proportions for training,"
+ " validation, and test split. For example the split "
+ "`90,5,5` will use 90%% of data for training, 5%% for "
+ "validation and 5%% for test.",
+ )
+ group.add_argument(
+ "--vocab-file",
+ type=str,
+ default=None,
+ help="Path to the vocab file.",
+ )
+ group.add_argument(
+ "--merge-file",
+ type=str,
+ default=None,
+ help="Path to the BPE merge file.",
+ )
+ group.add_argument(
+ "--tokenizer-path",
+ type=str,
+ default=None,
+ help="Path to the tokenizer dir.",
+ )
+ group.add_argument(
+ "--vocab-extra-ids",
+ type=int,
+ default=0,
+ help="Number of additional vocabulary tokens. "
+ "They are used for span masking in the T5 model",
+ )
+ group.add_argument(
+ "--seq-length",
+ type=int,
+ default=None,
+ help="Maximum sequence length to process.",
+ )
+ group.add_argument(
+ "--encoder-seq-length",
+ type=int,
+ default=None,
+ help="Maximum encoder sequence length to process."
+ "This should be exclusive of --seq-length",
+ )
+ group.add_argument(
+ "--decoder-seq-length",
+ type=int,
+ default=None,
+ help="Maximum decoder sequence length to process.",
+ )
+ group.add_argument(
+ "--retriever-seq-length",
+ type=int,
+ default=256,
+ help="Maximum sequence length for the biencoder model " " for retriever",
+ )
+ group.add_argument(
+ "--sample-rate",
+ type=float,
+ default=1.0,
+ help="sample rate for training data. Supposed to be 0 " " < sample_rate < 1",
+ )
+ group.add_argument(
+ "--mask-prob",
+ type=float,
+ default=0.15,
+ help="Probability of replacing a token with mask.",
+ )
+ group.add_argument(
+ "--short-seq-prob",
+ type=float,
+ default=0.1,
+ help="Probability of producing a short sequence.",
+ )
+ group.add_argument("--mmap-warmup", action="store_true", help="Warm up mmap files.")
+ group.add_argument(
+ "--num-workers", type=int, default=2, help="Dataloader number of workers."
+ )
+ group.add_argument(
+ "--tokenizer-type",
+ type=str,
+ default=None,
+ choices=["BertWordPieceLowerCase", "BertWordPieceCase", "GPT2BPETokenizer"],
+ help="What type of tokenizer to use.",
+ )
+ group.add_argument(
+ "--data-impl",
+ type=str,
+ default="infer",
+ choices=["lazy", "cached", "mmap", "infer"],
+ help="Implementation of indexed datasets.",
+ )
+ group.add_argument(
+ "--reset-position-ids",
+ action="store_true",
+ help="Reset posistion ids after end-of-document token.",
+ )
+ group.add_argument(
+ "--reset-attention-mask",
+ action="store_true",
+ help="Reset self attention masks after " "end-of-document token.",
+ )
+ group.add_argument(
+ "--eod-mask-loss",
+ action="store_true",
+ help="Mask loss for the end of document tokens.",
+ )
+
+ return parser
+
+
+def _add_autoresume_args(parser):
+ group = parser.add_argument_group(title="autoresume")
+
+ group.add_argument(
+ "--adlr-autoresume",
+ action="store_true",
+ help="Enable autoresume on adlr cluster.",
+ )
+ group.add_argument(
+ "--adlr-autoresume-interval",
+ type=int,
+ default=1000,
+ help="Intervals over which check for autoresume" "termination signal",
+ )
+
+ return parser
+
+
+def _add_biencoder_args(parser):
+ group = parser.add_argument_group(title="biencoder")
+
+ # network size
+ group.add_argument(
+ "--ict-head-size",
+ type=int,
+ default=None,
+ help="Size of block embeddings to be used in ICT and "
+ "REALM (paper default: 128)",
+ )
+ group.add_argument(
+ "--biencoder-projection-dim",
+ type=int,
+ default=0,
+ help="Size of projection head used in biencoder (paper" " default: 128)",
+ )
+ group.add_argument(
+ "--biencoder-shared-query-context-model",
+ action="store_true",
+ help="Whether to share the parameters of the query "
+ "and context models or not",
+ )
+
+ # checkpointing
+ group.add_argument(
+ "--ict-load",
+ type=str,
+ default=None,
+ help="Directory containing an ICTBertModel checkpoint",
+ )
+ group.add_argument(
+ "--bert-load",
+ type=str,
+ default=None,
+ help="Directory containing an BertModel checkpoint "
+ "(needed to start ICT and REALM)",
+ )
+
+ # data
+ group.add_argument(
+ "--titles-data-path",
+ type=str,
+ default=None,
+ help="Path to titles dataset used for ICT",
+ )
+ group.add_argument(
+ "--query-in-block-prob",
+ type=float,
+ default=0.1,
+ help="Probability of keeping query in block for " "ICT dataset",
+ )
+ group.add_argument(
+ "--use-one-sent-docs",
+ action="store_true",
+ help="Whether to use one sentence documents in ICT",
+ )
+ group.add_argument(
+ "--evidence-data-path",
+ type=str,
+ default=None,
+ help="Path to Wikipedia Evidence frm DPR paper",
+ )
+
+ # training
+ group.add_argument(
+ "--retriever-report-topk-accuracies",
+ nargs="+",
+ type=int,
+ default=[],
+ help="Which top-k accuracies to report " "(e.g. '1 5 20')",
+ )
+ group.add_argument(
+ "--retriever-score-scaling",
+ action="store_true",
+ help="Whether to scale retriever scores by inverse "
+ "square root of hidden size",
+ )
+
+ # faiss index
+ group.add_argument(
+ "--block-data-path",
+ type=str,
+ default=None,
+ help="Where to save/load BlockData to/from",
+ )
+ group.add_argument(
+ "--embedding-path",
+ type=str,
+ default=None,
+ help="Where to save/load Open-Retrieval Embedding" " data to/from",
+ )
+
+ # indexer
+ group.add_argument(
+ "--indexer-batch-size",
+ type=int,
+ default=128,
+ help="How large of batches to use when doing indexing " "jobs",
+ )
+ group.add_argument(
+ "--indexer-log-interval",
+ type=int,
+ default=1000,
+ help="After how many batches should the indexer " "report progress",
+ )
+ return parser
+
+
+def _add_vit_args(parser):
+ group = parser.add_argument_group(title="vit")
+
+ group.add_argument(
+ "--num-classes",
+ type=int,
+ default=1000,
+ help="num of classes in vision classificaiton task",
+ )
+ group.add_argument(
+ "--img-dim",
+ type=int,
+ default=224,
+ help="Image size for vision classification task",
+ )
+ group.add_argument(
+ "--num-channels",
+ type=int,
+ default=3,
+ help="Number of channels in input image data",
+ )
+ group.add_argument(
+ "--patch-dim", type=int, default=16, help="patch dimension used in vit"
+ )
+
+ return parser
+
+
+def _add_zero_args(parser):
+ """Text generate arguments."""
+
+ group = parser.add_argument_group("ZeRO configurations", "configurations")
+ group.add_argument("--zero-stage", type=int, default=1.0)
+ group.add_argument(
+ "--zero-reduce-scatter",
+ action="store_true",
+ help="Use reduce scatter if specified",
+ )
+ group.add_argument(
+ "--zero-contigious-gradients",
+ action="store_true",
+ help="Use contigious memory optimizaiton if specified",
+ )
+ group.add_argument("--zero-reduce-bucket-size", type=int, default=0.0)
+ group.add_argument("--zero-allgather-bucket-size", type=int, default=0.0)
+ group.add_argument(
+ "--remote-device",
+ type=str,
+ default="none",
+ choices=["none", "cpu", "nvme"],
+ help="Remote device for ZeRO-3 initialized parameters.",
+ )
+ group.add_argument(
+ "--use-pin-memory",
+ action="store_true",
+ help="Use pinned CPU memory for ZeRO-3 initialized model parameters.",
+ )
+ return parser
+
+
+def _add_memoryopt_args(parser):
+ """Memory optimization arguments."""
+
+ group = parser.add_argument_group("Memory optimizations", "configurations")
+ group.add_argument(
+ "--scattered-embeddings",
+ action="store_true",
+ help="Save memory by scattering embedding activations. "
+ "Introduces dropout differences across MP configurations.",
+ )
+ group.add_argument(
+ "--split-transformers",
+ action="store_true",
+ help="Save memory by splitting transformer layers into two parts, "
+ "allowing for more frequent activation checkpoint savings.",
+ )
+ group.add_argument(
+ "--memory-centric-tiled-linear",
+ action="store_true",
+ help="Save memory by tiling with deepspeed.zero.TiledLinear.",
+ )
+ group.add_argument(
+ "--tile-factor",
+ type=int,
+ default=1,
+ help="Make all linear layers the same size of [hidden/tile_factor, hidden/tile_factor]. "
+ "Must be enabled with --memory-centric-tiled-linear. "
+ "Example A: if tile_factor=1, the qkv layer [hidden, 3* hidden] would be converted into [1,3] tiles of size [hidden,hidden]. "
+ "Example B: if tile_factor=2, the intermediate layer [4*hidden, hidden] will be converted into [8, 2] tiles of size [hidden/2, hidden/2]. "
+ "Default is 1.",
+ )
+
+ return parser
+
+
+def _add_activation_checkpoint_args(parser):
+ group = parser.add_argument_group(
+ "Activation Checkpointing", "Checkpointing Configurations"
+ )
+ group.add_argument(
+ "--deepspeed-activation-checkpointing",
+ action="store_true",
+ help="uses activation checkpointing from deepspeed",
+ )
+ group.add_argument(
+ "--partition-activations",
+ action="store_true",
+ help="partition Activations across GPUs before checkpointing.",
+ )
+ group.add_argument(
+ "--contigious-checkpointing",
+ action="store_true",
+ help="Contigious memory checkpointing for activatoins.",
+ )
+ group.add_argument(
+ "--checkpoint-in-cpu",
+ action="store_true",
+ help="Move the activation checkpoints to CPU.",
+ )
+ group.add_argument(
+ "--synchronize-each-layer",
+ action="store_true",
+ help="does a synchronize at the beginning and end of each checkpointed layer.",
+ )
+ group.add_argument(
+ "--profile-backward",
+ action="store_true",
+ help="Enables backward pass profiling for checkpointed layers.",
+ )
+ return parser
diff --git a/codegeex/megatron/checkpointing.py b/codegeex/megatron/checkpointing.py
new file mode 100644
index 0000000..8b417d8
--- /dev/null
+++ b/codegeex/megatron/checkpointing.py
@@ -0,0 +1,528 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Input/output checkpointing."""
+
+import os
+import random
+import sys
+import numpy as np
+from glob import glob
+
+import torch
+
+from megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils
+
+_CHECKPOINT_VERSION = None
+
+
+def set_checkpoint_version(value):
+ global _CHECKPOINT_VERSION
+ if _CHECKPOINT_VERSION is not None:
+ assert _CHECKPOINT_VERSION == value, "checkpoint versions do not match"
+ _CHECKPOINT_VERSION = value
+
+
+def get_checkpoint_version():
+ global _CHECKPOINT_VERSION
+ return _CHECKPOINT_VERSION
+
+
+def check_checkpoint_args(checkpoint_args):
+ """Ensure fixed arguments for a model are the same for the input
+ arguments and the one retrieved from checkpoint."""
+ args = get_args()
+
+ def _compare(arg_name, old_arg_name=None):
+ if old_arg_name is not None:
+ checkpoint_value = getattr(checkpoint_args, old_arg_name)
+ else:
+ checkpoint_value = getattr(checkpoint_args, arg_name)
+ args_value = getattr(args, arg_name)
+ error_message = (
+ "{} value from checkpoint ({}) is not equal to the "
+ "input argument value ({}).".format(arg_name, checkpoint_value, args_value)
+ )
+ assert checkpoint_value == args_value, error_message
+
+ _compare("num_layers")
+ _compare("hidden_size")
+ _compare("num_attention_heads")
+ _compare("max_position_embeddings")
+ if args.vocab_file:
+ _compare("make_vocab_size_divisible_by")
+ _compare("padded_vocab_size")
+ _compare("tokenizer_type")
+ if get_checkpoint_version() < 3.0:
+ _compare("tensor_model_parallel_size", old_arg_name="model_parallel_size")
+ if get_checkpoint_version() >= 3.0:
+ _compare("tensor_model_parallel_size")
+ _compare("pipeline_model_parallel_size")
+
+
+def ensure_directory_exists(filename):
+ """Build filename's path if it does not already exists."""
+ dirname = os.path.dirname(filename)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+
+def get_checkpoint_name(checkpoints_path, iteration, release=False):
+ """A unified checkpoint name."""
+ if release:
+ directory = "release"
+ else:
+ directory = "iter_{:07d}".format(iteration)
+ # Use both the tensor and pipeline MP rank.
+ if mpu.get_pipeline_model_parallel_world_size() == 1:
+ return os.path.join(
+ checkpoints_path,
+ directory,
+ "mp_rank_{:02d}".format(mpu.get_tensor_model_parallel_rank()),
+ "model_optim_rng.pt",
+ )
+ return os.path.join(
+ checkpoints_path,
+ directory,
+ "mp_rank_{:02d}_{:03d}".format(
+ mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank()
+ ),
+ "model_optim_rng.pt",
+ )
+
+
+def get_checkpoint_tracker_filename(checkpoints_path):
+ """Tracker file rescords the latest chckpoint during
+ training to restart from."""
+ return os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")
+
+
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+ """Save a model checkpoint."""
+ args = get_args()
+
+ # Only rank zero of the data parallel writes to the disk.
+ if not args.deepspeed:
+ model = utils.unwrap_model(model)
+
+ print_rank_0(
+ "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
+ )
+
+ if (
+ not torch.distributed.is_initialized()
+ or mpu.get_data_parallel_rank() == 0
+ or args.deepspeed
+ ):
+
+ # Arguments, iteration, and model.
+ state_dict = {}
+ state_dict["args"] = args
+ state_dict["checkpoint_version"] = 3.0
+ state_dict["iteration"] = iteration
+ state_dict["tokens"] = args.consumed_train_tokens
+
+ # DeepSpeed saves the model/optimizer/scheduler
+ if not args.deepspeed:
+ if len(model) == 1:
+ state_dict["model"] = model[0].state_dict_for_save_checkpoint()
+ else:
+ for i in range(len(model)):
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
+ state_dict["model%d" % i] = model[
+ i
+ ].state_dict_for_save_checkpoint()
+
+ # Optimizer stuff.
+ if not args.no_save_optim:
+ if optimizer is not None:
+ state_dict["optimizer"] = optimizer.state_dict()
+ if lr_scheduler is not None:
+ state_dict["lr_scheduler"] = lr_scheduler.state_dict()
+
+ # RNG states.
+ if not args.no_save_rng:
+ state_dict["random_rng_state"] = random.getstate()
+ state_dict["np_rng_state"] = np.random.get_state()
+ state_dict["torch_rng_state"] = torch.get_rng_state()
+ state_dict["cuda_rng_state"] = torch.cuda.get_rng_state()
+ state_dict["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()
+
+ # Save.
+ checkpoint_name = get_checkpoint_name(args.save, iteration)
+ if not args.deepspeed:
+ ensure_directory_exists(checkpoint_name)
+ torch.save(state_dict, checkpoint_name)
+
+ if args.deepspeed:
+ # megatron model uses state_dict_for_save_checkpointing instead of the standard state_dict
+ # state_dict is used by deepspeed for module saving so it needs to point to the right function
+ if args.no_pipeline_parallel:
+ original_state_dict = model[0].module.state_dict
+ model[0].module.state_dict = model[0].module.state_dict_for_save_checkpoint
+
+ # Saving is a collective communication
+ checkpoint_name = get_checkpoint_name(args.save, iteration)
+ # Trim off the filename and mp_rank_* directory.
+ for _ in range(3):
+ checkpoint_name = os.path.dirname(checkpoint_name)
+ model[0].save_checkpoint(checkpoint_name, client_state=state_dict)
+
+ if args.no_pipeline_parallel:
+ model[0].module.state_dict = original_state_dict
+
+ # Wait so everyone is done (necessary)
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+ print_rank_0(
+ " successfully saved checkpoint at iteration {:7d} to {}".format(
+ iteration, args.save
+ )
+ )
+
+ # And update the latest iteration
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+ tracker_filename = get_checkpoint_tracker_filename(args.save)
+ with open(tracker_filename, "w") as f:
+ f.write(str(iteration))
+
+ # Wait so everyone is done (not necessary)
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+
+def _transpose_first_dim(t, num_splits, num_splits_first, model):
+ input_shape = t.size()
+ # We use a self_attention module but the values extracted aren't
+ # specific to self attention so should work for cross attention as well
+ while hasattr(model, "module"):
+ model = model.module
+ attention_module = model.language_model.encoder.layers[0].self_attention
+ hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
+ num_attention_heads_per_partition = (
+ attention_module.num_attention_heads_per_partition
+ )
+ if num_splits_first:
+ """[num_splits * np * hn, h]
+ -->(view) [num_splits, np, hn, h]
+ -->(tranpose) [np, num_splits, hn, h]
+ -->(view) [np * num_splits * hn, h]"""
+
+ intermediate_shape = (
+ num_splits,
+ num_attention_heads_per_partition,
+ hidden_size_per_attention_head,
+ ) + input_shape[1:]
+
+ t = t.view(*intermediate_shape)
+ t = t.transpose(0, 1).contiguous()
+ else:
+ """[np * hn * num_splits, h]
+ -->(view) [np, hn, num_splits, h]
+ -->(tranpose) [np, num_splits, hn, h]
+ -->(view) [np * num_splits * hn, h]"""
+
+ intermediate_shape = (
+ num_attention_heads_per_partition,
+ hidden_size_per_attention_head,
+ num_splits,
+ ) + input_shape[1:]
+
+ t = t.view(*intermediate_shape)
+ t = t.transpose(1, 2).contiguous()
+ t = t.view(*input_shape)
+
+ return t
+
+
+def fix_query_key_value_ordering(model, checkpoint_version):
+ """Fix up query/key/value matrix ordering if checkpoint
+ version is smaller than 2.0
+ """
+ if checkpoint_version < 2.0:
+ if isinstance(model, list):
+ assert len(model) == 1
+ model = model[0]
+ for name, param in model.named_parameters():
+ if name.endswith((".query_key_value.weight", ".query_key_value.bias")):
+ if checkpoint_version == 0:
+ fixed_param = _transpose_first_dim(param.data, 3, True, model)
+ elif checkpoint_version == 1.0:
+ fixed_param = _transpose_first_dim(param.data, 3, False, model)
+ else:
+ print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+ sys.exit()
+ param.data.copy_(fixed_param)
+ if name.endswith((".key_value.weight", ".key_value.bias")):
+ if checkpoint_version == 0:
+ fixed_param = _transpose_first_dim(param.data, 2, True, model)
+ elif checkpoint_version == 1.0:
+ fixed_param = _transpose_first_dim(param.data, 2, False, model)
+ else:
+ print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+ sys.exit()
+ param.data.copy_(fixed_param)
+ print_rank_0(
+ " succesfully fixed query-key-values ordering for"
+ " checkpoint version {}".format(checkpoint_version)
+ )
+
+
+def load_deepspeed_state(model):
+ model = utils.unwrap_model(model)
+ args = get_args()
+ load_dir = args.load
+ if os.path.isdir(load_dir):
+ model_state_paths = glob(os.path.join(load_dir, "*model_states.pt"))
+ assert len(model_state_paths) == 1, (
+ "only support loading deepspeed checkpoint of model parallel size 1"
+ ", but got {}".format(model_state_paths)
+ )
+ model_state_path = model_state_paths[0]
+ else:
+ model_state_path = load_dir
+ state_dict = torch.load(model_state_path, map_location="cpu")
+ state_dict = state_dict["module"]
+
+ model[0].load_state_dict(state_dict, strict=True)
+
+
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load", strict=True):
+ """Load a model checkpoint and return the iteration.
+ strict (bool): whether to strictly enforce that the keys in
+ :attr:`state_dict` of the checkpoint match the names of
+ parameters and buffers in model.
+ """
+ args = get_args()
+ load_dir = getattr(args, load_arg)
+
+ if args.deepspeed:
+ loaded_dir, state_dict = model[0].load_checkpoint(load_dir)
+ if loaded_dir is None:
+ print_rank_0(
+ "WARNING: could not find the metadata file {} ".format(load_dir)
+ )
+ print_rank_0(
+ " will not load any checkpoints and will start from " "random"
+ )
+ return 0
+ release = False
+ else:
+ model = utils.unwrap_model(model)
+
+ # Read the tracker file and set the iteration.
+ tracker_filename = get_checkpoint_tracker_filename(load_dir)
+
+ # If no tracker file, return iretation zero.
+ if not os.path.isfile(tracker_filename):
+ print_rank_0(
+ "WARNING: could not find the metadata file {} ".format(tracker_filename)
+ )
+ print_rank_0(
+ " will not load any checkpoints and will start from " "random"
+ )
+ return 0
+
+ # Otherwise, read the tracker file and either set the iteration or
+ # mark it as a release checkpoint.
+ iteration = 0
+ release = False
+ with open(tracker_filename, "r") as f:
+ metastring = f.read().strip()
+ try:
+ iteration = int(metastring)
+ except ValueError:
+ release = metastring == "release"
+ if not release:
+ print_rank_0(
+ "ERROR: Invalid metadata file {}. Exiting".format(
+ tracker_filename
+ )
+ )
+ sys.exit()
+
+ assert iteration > 0 or release, "error parsing metadata file {}".format(
+ tracker_filename
+ )
+
+ # Checkpoint.
+ checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
+ print_rank_0(f" loading checkpoint from {args.load} at iteration {iteration}")
+
+ # Load the checkpoint.
+ try:
+ state_dict = torch.load(checkpoint_name, map_location="cpu")
+ except ModuleNotFoundError:
+ from megatron.fp16_deprecated import loss_scaler
+
+ # For backward compatibility.
+ print_rank_0(" > deserializing using the old code structure ...")
+ sys.modules["fp16.loss_scaler"] = sys.modules[
+ "megatron.fp16_deprecated.loss_scaler"
+ ]
+ sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
+ "megatron.fp16_deprecated.loss_scaler"
+ ]
+ state_dict = torch.load(checkpoint_name, map_location="cpu")
+ sys.modules.pop("fp16.loss_scaler", None)
+ sys.modules.pop("megatron.fp16.loss_scaler", None)
+ except BaseException as e:
+ print_rank_0("could not load the checkpoint")
+ print_rank_0(e)
+ sys.exit()
+
+ # set checkpoint version
+ set_checkpoint_version(state_dict.get("checkpoint_version", 0))
+
+ # Set iteration.
+ if args.finetune or release:
+ iteration = 0
+ else:
+ try:
+ iteration = state_dict["iteration"]
+ if "tokens" in state_dict:
+ args.consumed_train_tokens = state_dict["tokens"]
+ except KeyError:
+ try: # Backward compatible with older checkpoints
+ iteration = state_dict["total_iters"]
+ except KeyError:
+ print_rank_0(
+ "A metadata file exists but unable to load "
+ "iteration from checkpoint {}, exiting".format(checkpoint_name)
+ )
+ sys.exit()
+
+ # Check arguments.
+ assert args.consumed_train_samples == 0
+ assert args.consumed_valid_samples == 0
+ if "args" in state_dict:
+ checkpoint_args = state_dict["args"]
+ check_checkpoint_args(checkpoint_args)
+ args.consumed_train_samples = getattr(
+ checkpoint_args, "consumed_train_samples", 0
+ )
+ update_num_microbatches(consumed_samples=args.consumed_train_samples)
+ args.consumed_valid_samples = getattr(
+ checkpoint_args, "consumed_valid_samples", 0
+ )
+ else:
+ print_rank_0("could not find arguments in the checkpoint ...")
+
+ # Model.
+ if not args.deepspeed:
+ if len(model) == 1:
+ model[0].load_state_dict(state_dict["model"], strict=strict)
+ else:
+ for i in range(len(model)):
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
+ model[i].load_state_dict(state_dict["model%d" % i], strict=strict)
+
+ # Fix up query/key/value matrix ordering if needed
+ checkpoint_version = get_checkpoint_version()
+ print_rank_0(f" checkpoint version {checkpoint_version}")
+ fix_query_key_value_ordering(model, checkpoint_version)
+
+ # Optimizer.
+ if not args.deepspeed:
+ if not release and not args.finetune and not args.no_load_optim:
+ try:
+ if optimizer is not None:
+ optimizer.load_state_dict(state_dict["optimizer"])
+ if lr_scheduler is not None:
+ lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+ except KeyError:
+ print_rank_0(
+ "Unable to load optimizer from checkpoint {}. "
+ "Specify --no-load-optim or --finetune to prevent "
+ "attempting to load the optimizer state, "
+ "exiting ...".format(checkpoint_name)
+ )
+ sys.exit()
+
+ # rng states.
+ if not release and not args.finetune and not args.no_load_rng:
+ try:
+ random.setstate(state_dict["random_rng_state"])
+ np.random.set_state(state_dict["np_rng_state"])
+ torch.set_rng_state(state_dict["torch_rng_state"])
+ torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
+ # Check for empty states array
+ if not state_dict["rng_tracker_states"]:
+ raise KeyError
+ mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
+ except KeyError:
+ print_rank_0(
+ "Unable to load rng state from checkpoint {}. "
+ "Specify --no-load-rng or --finetune to prevent "
+ "attempting to load the rng state, "
+ "exiting ...".format(checkpoint_name)
+ )
+ sys.exit()
+
+ # Some utilities want to load a checkpoint without distributed being initialized
+ # if torch.distributed.is_initialized():
+ # torch.distributed.barrier()
+
+ print_rank_0(
+ f" successfully loaded checkpoint from {args.load} "
+ f"at iteration {iteration}"
+ )
+
+ return iteration
+
+
+def load_biencoder_checkpoint(
+ model, only_query_model=False, only_context_model=False, custom_load_path=None
+):
+ """
+ selectively load retrieval models for indexing/retrieving
+ from saved checkpoints
+ """
+
+ args = get_args()
+
+ model = utils.unwrap_model(model)
+
+ load_path = custom_load_path if custom_load_path is not None else args.load
+
+ tracker_filename = get_checkpoint_tracker_filename(load_path)
+ with open(tracker_filename, "r") as f:
+ iteration = int(f.read().strip())
+
+ checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+ if mpu.get_data_parallel_rank() == 0:
+ print(
+ "global rank {} is loading checkpoint {}".format(
+ torch.distributed.get_rank(), checkpoint_name
+ )
+ )
+
+ state_dict = torch.load(checkpoint_name, map_location="cpu")
+ ret_state_dict = state_dict["model"]
+
+ if only_query_model:
+ ret_state_dict.pop("context_model")
+ if only_context_model:
+ ret_state_dict.pop("query_model")
+
+ assert len(model) == 1
+ model[0].load_state_dict(ret_state_dict)
+ torch.distributed.barrier()
+
+ if mpu.get_data_parallel_rank() == 0:
+ print(" successfully loaded {}".format(checkpoint_name))
+
+ return model
diff --git a/codegeex/megatron/code_generation_utils.py b/codegeex/megatron/code_generation_utils.py
new file mode 100644
index 0000000..31df158
--- /dev/null
+++ b/codegeex/megatron/code_generation_utils.py
@@ -0,0 +1,1240 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for generating text."""
+
+import copy
+import json
+import os
+import time
+from typing import *
+
+import torch
+import torch.nn.functional as F
+from dataclasses import dataclass
+
+from codegeex.megatron import get_args
+from codegeex.megatron import get_tokenizer
+from codegeex.megatron import mpu
+from codegeex.megatron.utils import get_ltor_masks_and_position_ids
+
+
+def get_batch(context_tokens, micro_batch_size=None):
+ """Generate batch from context tokens."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ # Move to GPU.
+ if micro_batch_size is None:
+ micro_batch_size = args.micro_batch_size
+ tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
+ # Get the attention mask and postition ids.
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
+ tokens,
+ tokenizer.eod,
+ args.reset_position_ids,
+ args.reset_attention_mask,
+ args.eod_mask_loss,
+ )
+
+ return tokens, attention_mask, position_ids
+
+
+def get_batch_(context_tokens):
+ """Generate batch from context tokens."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ # Move to GPU.
+ tokens = context_tokens.contiguous().cuda()
+ # Get the attention mask and postition ids.
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
+ tokens,
+ tokenizer.eod,
+ args.reset_position_ids,
+ args.reset_attention_mask,
+ args.eod_mask_loss,
+ )
+
+ return tokens, attention_mask, position_ids
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
+ """This function has been mostly taken from huggingface conversational
+ ai code at
+ https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+ conversational-ai-with-transfer-learning-2d818ac26313"""
+
+ if top_k > 0:
+ # Remove all tokens with a probability less than the
+ # last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p > 0.0:
+ # Cconvert to 1D
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold
+ sorted_indices_to_remove = cumulative_probs > top_p
+ # Shift the indices to the right to keep also the first token
+ # above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+ for i in range(sorted_indices.size(0)):
+ indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+ logits[i][indices_to_remove] = filter_value
+
+ return logits
+
+
+def generate_samples_input_from_file(model):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ # Read the sample file and open the output file.
+ assert args.sample_input_file is not None, "sample input file is not provided."
+ if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+ fname = open(args.sample_input_file, "r")
+ all_raw_text = fname.readlines()
+ input_count = len(all_raw_text)
+ input_pos = 0
+ if args.sample_output_file is None:
+ sample_output_file = args.sample_input_file + ".out"
+ print(
+ "`sample-output-file` not specified, setting "
+ "it to {}".format(sample_output_file)
+ )
+ else:
+ sample_output_file = args.sample_output_file
+ fname_out = open(sample_output_file, "w+")
+
+ context_count = 0
+ model.eval()
+ with torch.no_grad():
+ while True:
+ terminate_runs = 0
+ raw_text_len = 0
+
+ if (
+ mpu.is_pipeline_first_stage()
+ and mpu.get_tensor_model_parallel_rank() == 0
+ ):
+ raw_text = all_raw_text[input_pos]
+ input_pos += 1
+ if input_pos == input_count:
+ raw_text = "stop"
+ raw_text_len = len(raw_text)
+
+ if "stop" in raw_text:
+ terminate_runs = 1
+ else:
+ context_tokens = tokenizer.tokenize(raw_text)
+ context_length = len(context_tokens)
+
+ if context_length >= (args.seq_length // 2):
+ print(
+ "\nContext length",
+ context_length,
+ "\nPlease give smaller context (half of the "
+ "sequence length)!",
+ flush=True,
+ )
+ continue
+ else:
+ context_tokens = tokenizer.tokenize("EMPTY TEXT")
+ context_length = 0
+
+ input_info = [terminate_runs, raw_text_len, context_length]
+ input_info_tensor = torch.cuda.LongTensor(input_info)
+ torch.distributed.all_reduce(
+ input_info_tensor, group=mpu.get_model_parallel_group()
+ )
+ terminate_runs = input_info_tensor[0].item()
+ raw_text_len = input_info_tensor[1].item()
+ context_length = input_info_tensor[2].item()
+
+ if terminate_runs == 1:
+ return
+
+ # For pipeline parallel we send context tokens to other stages
+ # so they get the lengths correct
+ if (
+ mpu.get_tensor_model_parallel_rank() == 0
+ and args.pipeline_model_parallel_size > 1
+ ):
+ if mpu.is_pipeline_first_stage():
+ src = mpu.get_pipeline_model_parallel_first_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+ torch.distributed.broadcast(context_tokens_tensor, src, group)
+ else:
+ src = mpu.get_pipeline_model_parallel_first_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ context_tokens_tensor = torch.empty(
+ context_length, dtype=torch.int64, device=torch.device("cuda")
+ )
+ torch.distributed.broadcast(context_tokens_tensor, src, group)
+ context_tokens = context_tokens_tensor.cpu().numpy().tolist()
+
+ token_stream = get_token_stream(model, [context_tokens])
+ for _, decode_tokens in enumerate(token_stream):
+ pass
+
+ if mpu.get_tensor_model_parallel_rank() == 0:
+ if mpu.is_pipeline_first_stage():
+ os.system("clear")
+ print("\nContext:", raw_text, flush=True)
+
+ fname_out.write("\nContext:")
+ fname_out.write(raw_text)
+
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+ trim_decode_tokens = tokenizer.detokenize(decode_tokens)[
+ raw_text_len:
+ ]
+ print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+ fname_out.write("\n\nMegatron-LM:")
+ fname_out.write(trim_decode_tokens)
+ fname_out.write("\n")
+
+ raw_text = None
+ context_count += 1
+
+
+# We added this function to support the tasks evaluation such as squad
+# and drop in the https://github.com/EleutherAI/lm-evaluation-harness
+# codebase. The lm-evaluation-harness code can now call this function
+# similar to their current generate function call used for gpt style models.
+def generate_samples_eval(model, context, max_gen_length, eos_token_id):
+ # Generate samples for lm evaluation
+ # NEED TO THINK ABOUT eos token
+
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ raw_text_len = len(context)
+ model.eval()
+
+ context_tokens = tokenizer.tokenize(context)
+ args.out_seq_length = max_gen_length + len(context_tokens)
+ args.eos_id = eos_token_id
+
+ with torch.no_grad():
+ token_stream = get_token_stream(model, [context_tokens])
+ for counter, decode_tokens in enumerate(token_stream):
+ if counter == args.out_seq_length:
+ break
+
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+ trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
+
+ return trim_decode_tokens
+
+
+def generate_samples_interactive_code_contest(model, print_frequency=10):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ context_count = 0
+ model.eval()
+ with torch.no_grad():
+ while True:
+ terminate_runs = 0
+ raw_text_len = 0
+
+ if (
+ mpu.is_pipeline_first_stage()
+ and mpu.get_tensor_model_parallel_rank() == 0
+ ):
+ # os.system("clear")
+ raw_text = []
+ input_line = input("\nContext prompt (EOF to exit) >>> ")
+
+ if input_line == ":recompute":
+ args.recompute = True
+ print(f"set recompute: {args.recompute}")
+ continue
+
+ if input_line == ":no-recompute":
+ args.recompute = False
+ print(f"set recompute: {args.recompute}")
+ continue
+
+ while input_line != "EOF":
+ raw_text.append(input_line)
+ input_line = input("\nContext prompt (EOF to exit) >>> ")
+ raw_text = "\n".join(raw_text)
+
+ raw_text_len = len(raw_text)
+
+ if "stop" in raw_text:
+ # terminate_runs = 1
+ pass
+ else:
+ context_tokens = tokenizer.tokenize(raw_text)
+ context_length = len(context_tokens)
+
+ if context_length >= (args.seq_length // 2):
+ print(
+ "\nContext length",
+ context_length,
+ "\nPlease give smaller context (half of the "
+ "sequence length)!",
+ flush=True,
+ )
+ continue
+ else:
+ context_tokens = tokenizer.tokenize("EMPTY TEXT")
+ context_length = 0
+
+ input_info = [terminate_runs, raw_text_len, context_length]
+ input_info_tensor = torch.cuda.LongTensor(input_info)
+ torch.distributed.all_reduce(
+ input_info_tensor, group=mpu.get_model_parallel_group()
+ )
+ terminate_runs = input_info_tensor[0].item()
+ raw_text_len = input_info_tensor[1].item()
+ context_length = input_info_tensor[2].item()
+
+ if terminate_runs == 1:
+ return
+
+ # For pipeline parallel we send context tokens to other stages
+ # so they get the lengths correct
+ if (
+ mpu.get_tensor_model_parallel_rank() == 0
+ and args.pipeline_model_parallel_size > 1
+ ):
+ if mpu.is_pipeline_first_stage():
+ src = mpu.get_pipeline_model_parallel_first_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+ torch.distributed.broadcast(context_tokens_tensor, src, group)
+ else:
+ src = mpu.get_pipeline_model_parallel_first_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ context_tokens_tensor = torch.empty(
+ context_length, dtype=torch.int64, device=torch.device("cuda")
+ )
+ torch.distributed.broadcast(context_tokens_tensor, src, group)
+ context_tokens = context_tokens_tensor.cpu().numpy().tolist()
+
+ token_stream = get_token_stream(model, [context_tokens for _ in range(args.micro_batch_size)])
+
+ for counter, decode_tokens in enumerate(token_stream):
+ if (
+ counter % print_frequency != 0
+ or mpu.get_tensor_model_parallel_rank() != 0
+ or not mpu.is_pipeline_first_stage()
+ ):
+ continue
+
+ os.system("clear")
+ print("\nContext:", raw_text, flush=True)
+
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+ trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
+ print(f"\nMegatron-LM (gen len: {counter}):", trim_decode_tokens, flush=True)
+
+ if (
+ mpu.is_pipeline_first_stage()
+ and mpu.get_tensor_model_parallel_rank() == 0
+ ):
+ os.system("clear")
+ print("\nContext:", raw_text, flush=True)
+
+ if not isinstance(decode_tokens, list):
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+ trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
+ print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+ input("\nPress Enter to continue >>>")
+
+ raw_text = None
+ context_count += 1
+
+
+def generate_samples_interactive(model, print_frequency=24):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ context_count = 0
+ model.eval()
+ with torch.no_grad():
+ while True:
+ terminate_runs = 0
+ raw_text_len = 0
+
+ if (
+ mpu.is_pipeline_first_stage()
+ and mpu.get_tensor_model_parallel_rank() == 0
+ ):
+ os.system("clear")
+ raw_text = input("\nContext prompt (stop to exit) >>> ")
+ while not raw_text:
+ print("Prompt should not be empty!")
+ raw_text = input("\nContext prompt (stop to exit) >>> ")
+ raw_text_len = len(raw_text)
+
+ if "stop" in raw_text:
+ terminate_runs = 1
+ else:
+ context_tokens = tokenizer.tokenize(raw_text)
+ context_length = len(context_tokens)
+
+ if context_length >= (args.seq_length // 2):
+ print(
+ "\nContext length",
+ context_length,
+ "\nPlease give smaller context (half of the "
+ "sequence length)!",
+ flush=True,
+ )
+ continue
+ else:
+ context_tokens = tokenizer.tokenize("EMPTY TEXT")
+ context_length = 0
+
+ input_info = [terminate_runs, raw_text_len, context_length]
+ input_info_tensor = torch.cuda.LongTensor(input_info)
+ torch.distributed.all_reduce(
+ input_info_tensor, group=mpu.get_model_parallel_group()
+ )
+ terminate_runs = input_info_tensor[0].item()
+ raw_text_len = input_info_tensor[1].item()
+ context_length = input_info_tensor[2].item()
+
+ if terminate_runs == 1:
+ return
+
+ # For pipeline parallel we send context tokens to other stages
+ # so they get the lengths correct
+ if (
+ mpu.get_tensor_model_parallel_rank() == 0
+ and args.pipeline_model_parallel_size > 1
+ ):
+ if mpu.is_pipeline_first_stage():
+ src = mpu.get_pipeline_model_parallel_first_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+ torch.distributed.broadcast(context_tokens_tensor, src, group)
+ else:
+ src = mpu.get_pipeline_model_parallel_first_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ context_tokens_tensor = torch.empty(
+ context_length, dtype=torch.int64, device=torch.device("cuda")
+ )
+ torch.distributed.broadcast(context_tokens_tensor, src, group)
+ context_tokens = context_tokens_tensor.cpu().numpy().tolist()
+
+ token_stream = get_token_stream(model, [context_tokens])
+
+ for counter, decode_tokens in enumerate(token_stream):
+ if (
+ counter % print_frequency != 0
+ or mpu.get_tensor_model_parallel_rank() != 0
+ or not mpu.is_pipeline_first_stage()
+ ):
+ continue
+
+ os.system("clear")
+ print("\nContext:", raw_text, flush=True)
+
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+ trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
+ print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+ if (
+ mpu.is_pipeline_first_stage()
+ and mpu.get_tensor_model_parallel_rank() == 0
+ ):
+ os.system("clear")
+ print("\nContext:", raw_text, flush=True)
+
+ if not isinstance(decode_tokens, list):
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+ trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
+ print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+ input("\nPress Enter to continue >>>")
+
+ raw_text = None
+ context_count += 1
+
+
+def generate_samples_unconditional(model):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ num_samples = args.num_samples
+ context_tokens = [[tokenizer.eod] for _ in range(args.micro_batch_size)]
+ ctr = 0
+ while True:
+ start_time = time.time()
+ for token_stream in get_token_stream(model, copy.deepcopy(context_tokens)):
+ pass
+ if mpu.is_pipeline_last_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+ if ctr % args.log_interval == 0:
+ print(
+ "Avg s/batch:",
+ (time.time() - start_time) / min(args.log_interval, ctr + 1),
+ )
+ start_time = time.time()
+ length = len(token_stream)
+ token_batch = token_stream[0].cpu().numpy().tolist()
+ length_batch = token_stream[1].cpu().numpy().tolist()
+ assert len(length_batch) == args.micro_batch_size
+ for tokens, length in zip(token_batch, length_batch):
+ tokens = tokens[1: length - 1]
+ text = tokenizer.detokenize(tokens)
+ is_finished = length < args.seq_length - 1
+ datum = {"text": text, "length": length - 1, "finished": is_finished}
+ yield datum
+ ctr += 1
+ if ctr >= num_samples:
+ break
+ else:
+ for _ in range(args.micro_batch_size):
+ yield None
+ ctr += 1
+ if ctr >= num_samples:
+ break
+ if ctr >= num_samples:
+ break
+
+
+def generate_and_write_samples_unconditional(model):
+ args = get_args()
+ assert args.genfile is not None
+ with open(args.genfile, "w") as f:
+ for datum in generate_samples_unconditional(model):
+ if (
+ mpu.is_pipeline_last_stage()
+ and mpu.get_tensor_model_parallel_rank() == 0
+ ):
+ f.write(json.dumps(datum) + "\n")
+
+
+def pad_batch(batch, pad_id, args):
+ context_lengths = []
+ for tokens in batch:
+ context_length = len(tokens)
+ if context_length < args.seq_length:
+ tokens.extend([pad_id] * (args.seq_length - context_length))
+ context_lengths.append(context_length)
+ return batch, context_lengths
+
+
+def topk_sampling(logits: torch.FloatTensor, num_samples: int):
+ """
+ Samples from a multinomial distribution using the top-k sampling strategy.
+
+ Args:
+ logits: A tensor of shape (batch_size, vocab_size) containing the logits.
+ num_samples: The number of samples to draw.
+ """
+ log_prob = F.log_softmax(logits, dim=-1)
+ topk = torch.topk(log_prob, num_samples, dim=-1)
+ topk_tokens = topk.indices
+ topk_log_prob = topk.values
+
+ return topk_tokens, topk_log_prob
+
+
+def nuclear_sampling(logits: torch.FloatTensor, temperature: float, top_p: float = None, top_k: int = None):
+ orig_log_probs = F.log_softmax(logits, dim=-1)
+ logits /= temperature
+ logits = top_k_logits(logits, top_k, top_p)
+ log_probs = F.softmax(logits, dim=-1)
+ tokens = torch.multinomial(log_probs, num_samples=1).view(-1)
+
+ indices = tokens.view(-1, 1)
+ new_scores = orig_log_probs.gather(1, indices).view(-1)
+
+ return tokens, new_scores
+
+
+def sample_topk_tokens(model,
+ input_tokens, attention_mask, position_ids,
+ context_length: int, num_samples: int):
+ assert context_length < input_tokens.shape[-1], "context_length must be smaller than seq_length"
+
+ model.eval()
+ with torch.no_grad():
+ output = forward_step(
+ model,
+ input_tokens,
+ position_ids,
+ attention_mask,
+ tokentype_ids=None,
+ forward_method_parallel_output=False,
+ )
+ assert output is not None
+ logits = output[:, context_length - 1, :]
+
+ return topk_sampling(logits, num_samples)
+
+
+def nuclear_sample_tokens(model,
+ input_tokens, attention_mask, position_ids,
+ context_length: int, temperature: float, top_p: float, top_k: int):
+ assert context_length < input_tokens.shape[-1], "context_length must be smaller than seq_length"
+
+ model.eval()
+ with torch.no_grad():
+ output = forward_step(
+ model,
+ input_tokens,
+ position_ids,
+ attention_mask,
+ tokentype_ids=None,
+ forward_method_parallel_output=False,
+ )
+ assert output is not None
+ logits = output[:, context_length - 1, :]
+ return nuclear_sampling(logits, temperature, top_p, top_k)
+
+
+@dataclass
+class Beam:
+ tokens: List[int]
+ score: float
+
+ def __repr__(self):
+ return f""
+
+ def get_code(self):
+ return get_tokenizer().detokenize(self.tokens)
+
+
+def expand_beams(beams: List[Beam], num_beams: int, model) -> List[Beam]:
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ context_tokens = [b.tokens.copy() for b in beams]
+ context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)
+
+ context_lengths = set(context_lengths)
+ assert len(context_lengths) == 1, "context_lengths must be the same"
+ context_length = list(context_lengths)[0]
+
+ context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+ tokens, attention_mask, position_ids = get_batch_(context_tokens_tensor)
+ tokens, scores = sample_topk_tokens(model, tokens, attention_mask, position_ids, context_length, num_beams)
+ tokens = tokens.detach().cpu().tolist()
+ scores = scores.detach().cpu().tolist()
+ assert len(tokens) == len(beams), "output tokens and input beams must have the same length"
+
+ all_beams = []
+ for i in range(len(beams)):
+ this_tokens = tokens[i]
+ this_scores = scores[i]
+
+ for token, score in zip(this_tokens, this_scores):
+ all_beams.append(Beam(beams[i].tokens + [token], beams[i].score + score))
+
+ return all_beams
+
+
+def beam_search(model, context_tokens, num_beams: int):
+ """Beam search.
+
+ Note that this function does not support model parallel!
+ """
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ assert not isinstance(context_tokens[0], list), "batched beam search not supported"
+
+ initial_beam = Beam(context_tokens, 0.0)
+ context_len = len(context_tokens)
+ org_context_len = context_len
+ finished_beams = []
+
+ # first expansion
+ beams = expand_beams([initial_beam], num_beams, model)
+ context_len += 1
+
+ # print(f"initial beam: {initial_beam}")
+
+ while len(beams) > 0 and context_len < args.seq_length:
+ expanded_beams = expand_beams(beams, num_beams, model)
+ next_beams = []
+ for beam in expanded_beams:
+ if args.beam_warmup_length > 0:
+ if len(beam.tokens) >= org_context_len + args.beam_warmup_length or beam.tokens[-1] == tokenizer.eod:
+ finished_beams.append(beam)
+ else:
+ next_beams.append(beam)
+ else:
+ if beam.tokens[-1] == tokenizer.eod:
+ finished_beams.append(beam)
+ else:
+ next_beams.append(beam)
+ # only keep top-k beams
+ next_beams.sort(key=lambda b: b.score, reverse=True)
+ beams = next_beams[:num_beams]
+ context_len += 1
+
+ if len(finished_beams) >= num_beams:
+ # first, only keep top-k beams
+ finished_beams.sort(key=lambda b: b.score, reverse=True)
+ finished_beams = finished_beams[:num_beams]
+ return finished_beams # return finished beams with highest scores
+ # stop if all currently expanding beams has a score lower than the minimal score of finished ones
+ min_score = min([b.score for b in finished_beams])
+ if min_score >= beams[0].score:
+ break
+ else:
+ print(f"we have got enough finished beams, but the minimal score is {min_score}")
+ print(f"and the maximum searching score is {beams[0].score}")
+
+ # return top-k finished and unfinished beams
+ all_beams = finished_beams + beams
+ all_beams.sort(key=lambda b: b.score, reverse=True)
+
+ return all_beams[:num_beams]
+
+
+@dataclass
+class Handle:
+ tokens: List[int]
+ score: float
+
+ def __repr__(self):
+ return f""
+
+ def is_finished(self):
+ return len(self.tokens) and self.tokens[-1] == get_tokenizer().eod
+
+ def derived(self, new_token: int, log_prob: float):
+ assert not self.is_finished(), "cannot derive from a finished handle"
+ return Handle(self.tokens + [new_token], self.score + log_prob)
+
+
+def expand_handles(handles: List[Handle], temperature: float, top_p: float, top_k: int, model):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ context_tokens = [b.tokens.copy() for b in handles]
+ context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)
+
+ context_lengths = set(context_lengths)
+ assert len(context_lengths) == 1, "context_lengths must be the same"
+ context_length = list(context_lengths)[0]
+
+ context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+ tokens, attention_mask, position_ids = get_batch_(context_tokens_tensor)
+ tokens, scores = nuclear_sample_tokens(model, tokens, attention_mask, position_ids, context_length, temperature,
+ top_p, top_k)
+ tokens = tokens.detach().cpu().tolist()
+ scores = scores.detach().cpu().tolist()
+ assert len(tokens) == len(handles), "output tokens and input must have the same length"
+
+ all_beams = []
+ for i in range(len(handles)):
+ this_tokens = tokens[i]
+ this_scores = scores[i]
+
+ all_beams.append(handles[i].derived(this_tokens, this_scores))
+
+ return all_beams
+
+
+def generate_nuclear_sampling(model, context_tokens, num_samples: int, temperature: float, top_p: float, top_k: int):
+ """Beam search.
+
+ Note that this function does not support model parallel!
+ """
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ assert not isinstance(context_tokens[0], list), "batched beam search not supported"
+
+ handles = [Handle(tokens=context_tokens, score=0) for _ in range(num_samples)]
+ context_len = len(context_tokens)
+ finished_handles = []
+
+ while len(handles) > 0 and context_len < args.seq_length:
+ expanded_handles = expand_handles(handles, temperature, top_p, top_k, model)
+
+ new_handles = []
+ for h in expanded_handles:
+ if h.is_finished():
+ finished_handles.append(h)
+ else:
+ new_handles.append(h)
+
+ context_len += 1
+ handles = new_handles
+
+ return handles + finished_handles
+
+
+def forward_step(
+ model,
+ tokens,
+ position_ids,
+ attention_mask,
+ tokentype_ids,
+ layer_past=None,
+ get_key_value=None,
+ forward_method_parallel_output=None,
+ prompt_length=None,
+ context_length=None,
+):
+ # Hidden size changes when not using recompute, need to tell p2p_communicate
+ # functions the correct size
+ args = get_args()
+ orig_seq_length = args.seq_length
+ args.seq_length = tokens.shape[1]
+
+ # Forward pass through the model.
+ output_tensor = model(
+ tokens,
+ position_ids,
+ attention_mask,
+ tokentype_ids=tokentype_ids,
+ layer_past=layer_past,
+ get_key_value=get_key_value,
+ prompt_length=prompt_length,
+ context_length=context_length,
+ )
+
+ if get_key_value:
+ output_tensor, layer_past = output_tensor
+
+ args.seq_length = orig_seq_length
+ if get_key_value:
+ return output_tensor, layer_past
+
+ return output_tensor
+
+
+def get_token_stream(
+ model,
+ context_tokens,
+ return_scores: bool = False,
+ prompt_length: int = None,
+ micro_batch_size: int = None,
+ bad_ids: List = None,
+ temperature: float = None,
+ topp: float = None,
+ topk: int = None,
+ beam_warmup: bool = False,
+):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)
+
+ context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+ context_length_tensor = torch.cuda.LongTensor(context_lengths)
+
+ torch.distributed.broadcast(
+ context_length_tensor,
+ mpu.get_tensor_model_parallel_src_rank(),
+ group=mpu.get_tensor_model_parallel_group(),
+ )
+ torch.distributed.broadcast(
+ context_tokens_tensor,
+ mpu.get_tensor_model_parallel_src_rank(),
+ group=mpu.get_tensor_model_parallel_group(),
+ )
+
+ context_length = context_length_tensor.min().item()
+ tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, micro_batch_size)
+
+ if beam_warmup:
+ batch_token_iterator = sample_sequence_batch_beam(
+ model,
+ context_tokens_tensor,
+ context_length_tensor,
+ attention_mask,
+ position_ids,
+ return_scores=return_scores,
+ prompt_length=prompt_length,
+ bad_ids=bad_ids,
+ temperature=temperature,
+ topp=topp,
+ topk=topk,
+ beam_warmup=True,
+ )
+ else:
+ batch_token_iterator = sample_sequence_batch(
+ model,
+ context_tokens_tensor,
+ context_length_tensor,
+ attention_mask,
+ position_ids,
+ return_scores=return_scores,
+ prompt_length=prompt_length,
+ bad_ids=bad_ids,
+ temperature=temperature,
+ topp=topp,
+ topk=topk,
+ )
+
+ for tokens, lengths in batch_token_iterator:
+ context_length += 1
+ if tokens is not None:
+ yield tokens[:, :context_length], lengths
+ else:
+ yield None, None
+
+
+def switch(val1, val2, boolean):
+ boolean = boolean.type_as(val1)
+ return (1 - boolean) * val1 + boolean * val2
+
+
+def sample_sequence_batch(
+ model,
+ context_tokens,
+ context_lengths,
+ attention_mask,
+ position_ids,
+ maxlen=None,
+ type_ids=None,
+ return_scores: bool = False,
+ prompt_length: int = None,
+ bad_ids: List = None,
+ temperature: float = None,
+ topp: float = None,
+ topk: int = None,
+):
+ args = get_args()
+ tokenizer = get_tokenizer()
+ temperature = temperature if temperature is not None else args.temperature
+ topp = topp if topp is not None else args.top_p
+ topk = topk if topk is not None else args.top_k
+
+ model.eval()
+ with torch.no_grad():
+ context_length = context_lengths.min().item()
+
+ # added eos_id to support the function generate_samples_eval that passes
+ # eos_id as an argument and needs termination when that id id found.
+ if hasattr(args, "eos_id"):
+ eos_id = args.eos_id
+ else:
+ eos_id = tokenizer.eod
+
+ counter = 0
+ org_context_length = context_length
+
+ layer_past = None
+ batch_size = context_tokens.size(0)
+ is_done = torch.zeros([batch_size]).byte().cuda()
+ tokens = context_tokens
+ if maxlen is None:
+ maxlen = args.seq_length - 1
+ if maxlen > (org_context_length + args.out_seq_length):
+ maxlen = org_context_length + args.out_seq_length
+
+ lengths = torch.ones([batch_size]).long().cuda() * maxlen
+ if return_scores:
+ scores = torch.zeros([batch_size]).float().cuda()
+
+ while context_length <= (maxlen):
+
+ if args.recompute:
+ logits = model(tokens,
+ position_ids,
+ attention_mask,
+ tokentype_ids=type_ids,
+ forward_method_parallel_output=False,
+ prompt_length=prompt_length,
+ context_length=context_length,
+ )
+ logits = logits[:, context_length - 1, :]
+ else:
+ types2use = None
+ if counter == 0:
+ tokens2use = tokens[:, :context_length]
+ positions2use = position_ids[:, :context_length]
+ if type_ids is not None:
+ types2use = type_ids[:, :context_length]
+ else:
+ tokens2use = tokens[:, context_length - 1].view(
+ batch_size, -1)
+ positions2use = position_ids[:, context_length - 1].view(
+ batch_size, -1)
+ if type_ids is not None:
+ types2use = type_ids[:, context_length - 1].view(
+ batch_size, -1)
+ logits, layer_past = model(tokens2use,
+ positions2use,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=True,
+ tokentype_ids=types2use,
+ forward_method_parallel_output=False,
+ prompt_length=prompt_length,
+ context_length=context_length,
+ )
+ logits = logits[:, -1].view(batch_size, -1).contiguous()
+
+ if mpu.is_pipeline_last_stage():
+ if bad_ids is not None:
+ for bad_id in bad_ids:
+ logits[:, bad_id] = -10000
+ if args.greedy:
+ prev = torch.argmax(logits, dim=-1).view(-1)
+ else:
+ logits = logits.float()
+ if return_scores:
+ orig_log_probs = torch.log_softmax(logits, dim=-1)
+ logits /= temperature
+ logits = top_k_logits(logits, top_k=topk, top_p=topp)
+ log_probs = F.softmax(logits, dim=-1)
+ prev = torch.multinomial(log_probs, num_samples=1).view(-1)
+
+ started = context_lengths <= context_length
+
+ new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
+
+ if not args.greedy and return_scores:
+ indices = prev.view(-1, 1)
+ new_scores = orig_log_probs.gather(1, indices).view(-1)
+ new_scores = new_scores * started
+ new_scores = new_scores * is_done.bool().logical_not()
+ scores += new_scores
+
+ tokens[:, context_length] = new_tokens
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_embedding_group()
+ torch.distributed.broadcast(new_tokens, src, group)
+
+ done_token = (prev == eos_id).byte() & started.byte()
+ just_finished = (done_token & ~is_done).bool()
+ lengths[just_finished.view(-1)] = context_length
+ is_done = is_done | done_token
+
+ done = torch.all(is_done)
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ torch.distributed.broadcast(done, src, group)
+
+ if return_scores:
+ yield tokens, (lengths, scores)
+ else:
+ yield tokens, lengths
+
+ else:
+ if mpu.is_pipeline_first_stage():
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_embedding_group()
+ new_tokens = torch.empty_like(tokens[:, context_length])
+ torch.distributed.broadcast(new_tokens, src, group)
+ tokens[:, context_length] = new_tokens
+ yield tokens, None
+ else:
+ yield None, None
+
+ done = torch.cuda.ByteTensor([0])
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ torch.distributed.broadcast(done, src, group)
+
+ context_length += 1
+ counter += 1
+ if done:
+ break
+
+
+def sample_sequence_batch_beam(
+ model,
+ context_tokens,
+ context_lengths,
+ attention_mask,
+ position_ids,
+ maxlen=None,
+ type_ids=None,
+ return_scores: bool = False,
+ prompt_length: int = None,
+ bad_ids: List = None,
+ temperature: float = None,
+ topp: float = None,
+ topk: int = None,
+ beam_warmup: bool = False,
+):
+ args = get_args()
+ tokenizer = get_tokenizer()
+ temperature = temperature if temperature is not None else args.temperature
+ topp = topp if topp is not None else args.top_p
+ topk = topk if topk is not None else args.top_k
+
+ model.eval()
+ with torch.no_grad():
+ context_length = context_lengths.min().item()
+
+ # added eos_id to support the function generate_samples_eval that passes
+ # eos_id as an argument and needs termination when that id id found.
+ if hasattr(args, "eos_id"):
+ eos_id = args.eos_id
+ else:
+ eos_id = tokenizer.eod
+
+ counter = 0
+ org_context_length = context_length
+
+ layer_past = None
+ batch_size = context_tokens.size(0)
+ is_done = torch.zeros([batch_size]).byte().cuda()
+ tokens = context_tokens
+ if maxlen is None:
+ maxlen = args.seq_length - 1
+ if maxlen > (org_context_length + args.out_seq_length):
+ maxlen = org_context_length + args.out_seq_length
+
+ lengths = torch.ones([batch_size]).long().cuda() * maxlen
+ if return_scores:
+ scores = torch.zeros([batch_size]).float().cuda()
+
+ if beam_warmup:
+ beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length],
+ num_beams=args.num_beams)
+ beam = beams[0]
+ tokens_ = beam.tokens
+ tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1])
+ tokens_warmup = []
+ for i in range(batch_size):
+ tokens_warmup.append(tokens_.copy())
+ tokens, context_lengths = pad_batch(tokens_warmup, tokenizer.eod, args)
+ tokens = torch.cuda.LongTensor(tokens)
+ context_lengths = torch.cuda.LongTensor(context_lengths)
+ context_length = len(tokens_)
+ org_context_length = context_length
+ if maxlen is None:
+ maxlen = args.seq_length - 1
+ if maxlen > (org_context_length + args.out_seq_length):
+ maxlen = org_context_length + args.out_seq_length
+ lengths = torch.ones([batch_size]).long().cuda() * maxlen
+ tokens, attention_mask, position_ids = get_batch(tokens, batch_size)
+
+ while context_length <= (maxlen):
+ if args.recompute:
+ logits = model(tokens,
+ position_ids,
+ attention_mask,
+ tokentype_ids=type_ids,
+ forward_method_parallel_output=False,
+ prompt_length=prompt_length,
+ context_length=context_length,
+ )
+ logits = logits[:, context_length - 1, :]
+ else:
+ types2use = None
+ if counter == 0:
+ tokens2use = tokens[:, :context_length]
+ positions2use = position_ids[:, :context_length]
+ if type_ids is not None:
+ types2use = type_ids[:, :context_length]
+ else:
+ tokens2use = tokens[:, context_length - 1].view(
+ batch_size, -1)
+ positions2use = position_ids[:, context_length - 1].view(
+ batch_size, -1)
+ if type_ids is not None:
+ types2use = type_ids[:, context_length - 1].view(
+ batch_size, -1)
+ logits, layer_past = model(tokens2use,
+ positions2use,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=True,
+ tokentype_ids=types2use,
+ forward_method_parallel_output=False,
+ prompt_length=prompt_length,
+ context_length=context_length,
+ )
+ logits = logits[:, -1].view(batch_size, -1).contiguous()
+
+ if mpu.is_pipeline_last_stage():
+ if bad_ids is not None:
+ for bad_id in bad_ids:
+ logits[:, bad_id] = -10000
+ if args.greedy:
+ prev = torch.argmax(logits, dim=-1).view(-1)
+ else:
+ logits = logits.float()
+ if return_scores:
+ orig_log_probs = torch.log_softmax(logits, dim=-1)
+ logits /= temperature
+ logits = top_k_logits(logits, top_k=topk, top_p=topp)
+ log_probs = F.softmax(logits, dim=-1)
+ prev = torch.multinomial(log_probs, num_samples=1).view(-1)
+
+ started = context_lengths <= context_length
+
+ new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
+
+ if not args.greedy and return_scores:
+ indices = prev.view(-1, 1)
+ new_scores = orig_log_probs.gather(1, indices).view(-1)
+ new_scores = new_scores * started
+ new_scores = new_scores * is_done.bool().logical_not()
+ scores += new_scores
+
+ tokens[:, context_length] = new_tokens
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_embedding_group()
+ torch.distributed.broadcast(new_tokens, src, group)
+
+ done_token = (prev == eos_id).byte() & started.byte()
+ just_finished = (done_token & ~is_done).bool()
+ lengths[just_finished.view(-1)] = context_length
+ is_done = is_done | done_token
+
+ done = torch.all(is_done)
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ torch.distributed.broadcast(done, src, group)
+
+ if return_scores:
+ yield tokens, (lengths, scores)
+ else:
+ yield tokens, lengths
+
+ else:
+ if mpu.is_pipeline_first_stage():
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_embedding_group()
+ new_tokens = torch.empty_like(tokens[:, context_length])
+ torch.distributed.broadcast(new_tokens, src, group)
+ tokens[:, context_length] = new_tokens
+ yield tokens, None
+ else:
+ yield None, None
+
+ done = torch.cuda.ByteTensor([0])
+ src = mpu.get_pipeline_model_parallel_last_rank()
+ group = mpu.get_pipeline_model_parallel_group()
+ torch.distributed.broadcast(done, src, group)
+
+ context_length += 1
+ counter += 1
+ if done:
+ break
diff --git a/codegeex/megatron/global_vars.py b/codegeex/megatron/global_vars.py
new file mode 100644
index 0000000..7920bc3
--- /dev/null
+++ b/codegeex/megatron/global_vars.py
@@ -0,0 +1,256 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron global variables."""
+
+import os
+import sys
+import time
+import torch
+
+from codegeex.megatron.tokenizer import build_tokenizer
+from codegeex.megatron.arguments import parse_args
+
+_GLOBAL_ARGS = None
+_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
+_GLOBAL_TOKENIZER = None
+_GLOBAL_TENSORBOARD_WRITER = None
+_GLOBAL_ADLR_AUTORESUME = None
+_GLOBAL_TIMERS = None
+
+
+def get_args():
+ """Return arguments."""
+ _ensure_var_is_initialized(_GLOBAL_ARGS, "args")
+ return _GLOBAL_ARGS
+
+
+def get_num_microbatches():
+ return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
+
+
+def get_current_global_batch_size():
+ return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
+
+
+def update_num_microbatches(consumed_samples, consistency_check=True):
+ _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)
+
+
+def get_tokenizer():
+ """Return tokenizer."""
+ _ensure_var_is_initialized(_GLOBAL_TOKENIZER, "tokenizer")
+ return _GLOBAL_TOKENIZER
+
+
+def get_tensorboard_writer():
+ """Return tensorboard writer. It can be None so no need
+ to check if it is initialized."""
+ return _GLOBAL_TENSORBOARD_WRITER
+
+
+def get_adlr_autoresume():
+ """ADLR autoresume object. It can be None so no need
+ to check if it is initialized."""
+ return _GLOBAL_ADLR_AUTORESUME
+
+
+def get_timers():
+ """Return timers."""
+ _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
+ return _GLOBAL_TIMERS
+
+
+def set_global_variables(
+ extra_args_provider=None, args_defaults={}, ignore_unknown_args=False
+):
+ """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
+ args = _parse_args(
+ extra_args_provider=extra_args_provider,
+ defaults=args_defaults,
+ ignore_unknown_args=ignore_unknown_args,
+ )
+ if args.vocab_file or args.tokenizer_path:
+ _ = _build_tokenizer(args)
+ _set_tensorboard_writer(args)
+ _set_adlr_autoresume(args)
+ _set_timers()
+
+
+def _parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False):
+ """Parse entire arguments."""
+ global _GLOBAL_ARGS
+ _ensure_var_is_not_initialized(_GLOBAL_ARGS, "args")
+ _GLOBAL_ARGS = parse_args(
+ extra_args_provider=extra_args_provider,
+ defaults=defaults,
+ ignore_unknown_args=ignore_unknown_args,
+ )
+ return _GLOBAL_ARGS
+
+
+def _build_tokenizer(args):
+ """Initialize tokenizer."""
+ global _GLOBAL_TOKENIZER
+ _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, "tokenizer")
+ _GLOBAL_TOKENIZER = build_tokenizer(args)
+ return _GLOBAL_TOKENIZER
+
+
+def rebuild_tokenizer(args):
+ global _GLOBAL_TOKENIZER
+ _GLOBAL_TOKENIZER = None
+ return _build_tokenizer(args)
+
+
+def _set_tensorboard_writer(args):
+ """Set tensorboard writer."""
+ global _GLOBAL_TENSORBOARD_WRITER
+ _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer")
+
+ if (
+ hasattr(args, "tensorboard_dir")
+ and args.tensorboard_dir
+ and args.rank == (args.world_size - 1)
+ ):
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+
+ print("> setting tensorboard ...")
+ _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
+ log_dir=args.tensorboard_dir, max_queue=args.tensorboard_queue_size
+ )
+ except ModuleNotFoundError:
+ print(
+ "WARNING: TensorBoard writing requested but is not "
+ "available (are you using PyTorch 1.1.0 or later?), "
+ "no TensorBoard logs will be written.",
+ flush=True,
+ )
+
+
+def _set_adlr_autoresume(args):
+ """Initialize ADLR autoresume."""
+ global _GLOBAL_ADLR_AUTORESUME
+ _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, "adlr autoresume")
+
+ if args.adlr_autoresume:
+ if args.rank == 0:
+ print("enabling autoresume ...", flush=True)
+ sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
+ try:
+ from userlib.auto_resume import AutoResume
+ except BaseException:
+ print("ADLR autoresume is not available, exiting ...")
+ sys.exit()
+
+ _GLOBAL_ADLR_AUTORESUME = AutoResume
+
+
+def _set_timers():
+ """Initialize timers."""
+ global _GLOBAL_TIMERS
+ _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
+ _GLOBAL_TIMERS = Timers()
+
+
+def _ensure_var_is_initialized(var, name):
+ """Make sure the input variable is not None."""
+ assert var is not None, "{} is not initialized.".format(name)
+
+
+def _ensure_var_is_not_initialized(var, name):
+ """Make sure the input variable is not None."""
+ assert var is None, "{} is already initialized.".format(name)
+
+
+class _Timer:
+ """Timer."""
+
+ def __init__(self, name):
+ self.name_ = name
+ self.elapsed_ = 0.0
+ self.started_ = False
+ self.start_time = time.time()
+
+ def start(self):
+ """Start the timer."""
+ assert not self.started_, "timer has already been started"
+ torch.cuda.synchronize()
+ self.start_time = time.time()
+ self.started_ = True
+
+ def stop(self):
+ """Stop the timer."""
+ assert self.started_, "timer is not started"
+ torch.cuda.synchronize()
+ self.elapsed_ += time.time() - self.start_time
+ self.started_ = False
+
+ def reset(self):
+ """Reset timer."""
+ self.elapsed_ = 0.0
+ self.started_ = False
+
+ def elapsed(self, reset=True):
+ """Calculate the elapsed time."""
+ started_ = self.started_
+ # If the timing in progress, end it first.
+ if self.started_:
+ self.stop()
+ # Get the elapsed time.
+ elapsed_ = self.elapsed_
+ # Reset the elapsed time
+ if reset:
+ self.reset()
+ # If timing was in progress, set it back.
+ if started_:
+ self.start()
+ return elapsed_
+
+
+class Timers:
+ """Group of timers."""
+
+ def __init__(self):
+ self.timers = {}
+
+ def __call__(self, name):
+ if name not in self.timers:
+ self.timers[name] = _Timer(name)
+ return self.timers[name]
+
+ def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+ """Write timers to a tensorboard writer"""
+ # currently when using add_scalars,
+ # torch.utils.add_scalars makes each timer its own run, which
+ # polutes the runs list, so we just add each as a scalar
+ assert normalizer > 0.0
+ for name in names:
+ value = self.timers[name].elapsed(reset=reset) / normalizer
+ writer.add_scalar(name + "-time", value, iteration)
+
+ def log(self, names, normalizer=1.0, reset=True):
+ """Log a group of timers."""
+ assert normalizer > 0.0
+ string = "time (ms)"
+ for name in names:
+ elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
+ string += " | {}: {:.2f}".format(name, elapsed_time)
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
+ print(string, flush=True)
+ else:
+ print(string, flush=True)
diff --git a/codegeex/megatron/inference.py b/codegeex/megatron/inference.py
new file mode 100644
index 0000000..53b018f
--- /dev/null
+++ b/codegeex/megatron/inference.py
@@ -0,0 +1,244 @@
+import copy
+import json
+import random
+import traceback
+from typing import *
+
+import numpy
+import torch
+import zmq
+
+from codegeex.benchmark.utils import is_code_generation_finished, cleanup_code
+from codegeex.megatron import get_args, get_tokenizer
+from codegeex.megatron import mpu
+from codegeex.megatron.code_generation_utils import get_token_stream
+from codegeex.megatron.model import CodeGeeXModel
+
+
+def model_provider():
+ """Build the model."""
+
+ model = CodeGeeXModel(num_tokentypes=0,
+ parallel_output=False)
+
+ return model
+
+
+def set_random_seed(seed):
+ """Set random seed for reproducability."""
+ random.seed(seed)
+ numpy.random.seed(seed)
+ torch.manual_seed(seed)
+ mpu.model_parallel_cuda_manual_seed(seed)
+
+
+def run_generation_distributed(model):
+ args = get_args()
+ if hasattr(args, "language_tgt_type"):
+ language_type = args.language_tgt_type
+ else:
+ language_type = args.language_type
+ print(f"Connecting to tcp://{args.channel_ip}:{args.channel_port}")
+ context = zmq.Context()
+ socket = context.socket(zmq.REQ)
+ socket.connect(f"tcp://{args.channel_ip}:{args.channel_port}")
+ output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
+ unfinished_output_file_path = args.output_prefix + f"_unfinished_rank{args.gen_rank}.jsonl"
+ problems = {}
+ print("Building tokenizer...")
+ tokenizer = get_tokenizer()
+
+ with open(output_file_path, "w") as f:
+ with open(unfinished_output_file_path, "w") as unfinished_f:
+ while True:
+ socket.send_json({"rank": args.gen_rank, "action": "pull"})
+ resp = socket.recv_json()
+ try:
+ if "codecontest" in args.dataset.lower():
+ if resp["contest_name"] is None:
+ break
+ elif resp["task_id"] is None:
+ break
+
+ if "codecontest" in args.dataset.lower():
+ current_spec = problems[resp["contest_name"]]
+ prompt = current_spec.prompt
+ else:
+ current_spec = resp["task_id"]
+ prompt = current_spec["prompt"]
+
+ temperature = None if "temperature" not in resp else resp["temperature"]
+ topp = None if "topp" not in resp else resp["topp"]
+
+ f.flush()
+ unfinished_f.flush()
+ tokens = tokenizer.tokenize(prompt)
+ n_token_prompt = len(tokens)
+ if n_token_prompt >= args.seq_length:
+ continue
+ if "micro_batch_size" in resp:
+ micro_batch_size = resp["micro_batch_size"]
+ else:
+ micro_batch_size = args.micro_batch_size
+ if args.beam_search:
+ beams = get_token_stream(
+ model,
+ [
+ copy.deepcopy(tokens)
+ for _ in range(micro_batch_size)
+ ],
+ return_scores=args.return_scores,
+ prompt_length=n_token_prompt,
+ micro_batch_size=micro_batch_size,
+ bad_ids=args.bad_ids,
+ temperature=temperature,
+ topp=topp,
+ beam_warmup=args.beam_warmup,
+ )
+ for beam in beams:
+ generated_tokens_ = beam.tokens
+ generated_tokens_ = (
+ generated_tokens_
+ if generated_tokens_[-1] != tokenizer.eod
+ else generated_tokens_[:-1]
+ )
+ generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
+ generated_code = cleanup_code(generated_code,
+ language_type=language_type,
+ dataset=args.dataset)
+ f.write(
+ json.dumps(
+ {
+ "task_id" : current_spec['task_id'],
+ "prompt" : prompt,
+ "generation": generated_code,
+ "scores" : beam.score,
+ "finish" : 2 if generated_tokens[i].cpu().numpy()[
+ -1] == tokenizer.eod else 1,
+ "output" : beam.tokens,
+ }
+ )
+ + "\n"
+ )
+ socket.send_json(
+ {
+ "rank" : args.gen_rank,
+ "action" : "success",
+ "task_id": current_spec['task_id']
+ }
+ )
+ socket.recv()
+ continue
+
+ token_stream = get_token_stream(
+ model,
+ [
+ copy.deepcopy(tokens)
+ for _ in range(micro_batch_size)
+ ],
+ return_scores=args.return_scores,
+ prompt_length=n_token_prompt,
+ micro_batch_size=micro_batch_size,
+ bad_ids=args.bad_ids,
+ temperature=temperature,
+ topp=topp,
+ beam_warmup=args.beam_warmup,
+ )
+ is_finished = [False for _ in range(micro_batch_size)]
+ for generated in token_stream:
+ generated_tokens = generated[0]
+ if args.return_scores:
+ scores = generated[1][1]
+ else:
+ scores = None
+
+ for i in range(micro_batch_size):
+ if is_finished[i]:
+ continue
+
+ generated_tokens_ = generated_tokens[i].cpu().numpy().tolist()
+ generated_tokens_ = (
+ generated_tokens_
+ if generated_tokens_[-1] != tokenizer.eod
+ else generated_tokens_[:-1]
+ )
+ generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
+ if generated_tokens[i].cpu().numpy()[-1] == tokenizer.eod or \
+ is_code_generation_finished(
+ generated_code,
+ language_type=language_type,
+ dataset=args.dataset,
+ ):
+ is_finished[i] = True
+ generated_code = cleanup_code(generated_code,
+ language_type=language_type,
+ dataset=args.dataset)
+ f.write(
+ json.dumps(
+ {
+ "task_id" : current_spec['task_id'],
+ "prompt" : prompt,
+ "generation": generated_code,
+ "scores" : 0.0 if scores is None else scores[i].detach().cpu().item(),
+ "finish" : 2 if generated_tokens[i].cpu().numpy()[
+ -1] == tokenizer.eod else 1,
+ "output" : generated_tokens[i].cpu().numpy().tolist(),
+ }
+ )
+ + "\n"
+ )
+
+ if len(generated_tokens[i]) >= args.out_seq_length:
+ break
+
+ if all(is_finished):
+ break
+
+ for i in range(micro_batch_size):
+ if not is_finished[i]:
+ generated_tokens_ = generated_tokens[i].cpu().numpy().tolist()
+ generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
+ unfinished_f.write(
+ json.dumps(
+ {
+ "task_id" : current_spec['task_id'],
+ "prompt" : prompt,
+ "generation": generated_code,
+ "scores" : 0.0 if scores is None else scores[i].detach().cpu().item(),
+ "finish" : 0,
+ "output" : generated_tokens_,
+ }
+ )
+ + "\n"
+ )
+
+ socket.send_json(
+ {
+ "rank" : args.gen_rank,
+ "action" : "success",
+ "task_id": current_spec['task_id']
+ }
+ )
+ socket.recv()
+
+ except Exception as e:
+ print(f"*** (rank={args.gen_rank}) crashed.")
+ print(f" error: {repr(e)}")
+ traceback.print_exc()
+ if args.dataset.lower() == "codecontest":
+ socket.send_json({
+ "rank" : args.gen_rank,
+ "action" : "fail",
+ "contest_name" : current_spec.name,
+ "micro_batch_size": micro_batch_size
+ })
+ else:
+ socket.send_json(
+ {
+ "rank" : args.gen_rank,
+ "action" : "fail",
+ "task_id": current_spec['task_id']
+ }
+ )
+ socket.recv()
+ continue
diff --git a/codegeex/megatron/initialize.py b/codegeex/megatron/initialize.py
new file mode 100644
index 0000000..f8ab529
--- /dev/null
+++ b/codegeex/megatron/initialize.py
@@ -0,0 +1,337 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron initialization."""
+
+import random
+import os
+import time
+import datetime
+
+import numpy as np
+import torch
+
+from codegeex.megatron import get_adlr_autoresume
+from codegeex.megatron import get_args
+from codegeex.megatron import get_tensorboard_writer
+from codegeex.megatron import mpu
+from codegeex.megatron.global_vars import set_global_variables
+from codegeex.megatron.mpu import (
+ set_tensor_model_parallel_rank,
+ set_tensor_model_parallel_world_size,
+)
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+import deepspeed
+
+
+def initialize_megatron(
+ extra_args_provider=None,
+ args_defaults={},
+ ignore_unknown_args=False,
+ allow_no_cuda=False,
+):
+ """Set global variables, initialize distributed, and
+ set autoresume and random seeds.
+ `allow_no_cuda` should not be set unless using megatron for cpu only
+ data processing. In general this arg should not be set unless you know
+ what you are doing.
+ Returns a function to finalize distributed env initialization
+ (optionally, only when args.lazy_mpu_init == True)
+ """
+ if not allow_no_cuda:
+ # Make sure cuda is available.
+ assert torch.cuda.is_available(), "Megatron requires CUDA."
+
+ # Parse args, build tokenizer, and set adlr-autoresume,
+ # tensorboard-writer, and timers.
+ set_global_variables(
+ extra_args_provider=extra_args_provider,
+ args_defaults=args_defaults,
+ ignore_unknown_args=ignore_unknown_args,
+ )
+
+ # torch.distributed initialization
+ def finish_mpu_init():
+ args = get_args()
+ # Pytorch distributed.
+ _initialize_distributed()
+
+ # Random seeds for reproducibility.
+ if args.rank == 0:
+ print("> setting random seeds to {} ...".format(args.seed))
+ _set_random_seed(args.seed)
+
+ args = get_args()
+ if args.lazy_mpu_init:
+ args.use_cpu_initialization = True
+ # delayed initialization of DDP-related stuff
+ # We only set basic DDP globals
+ set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+ # and return function for external DDP manager
+ # to call when it has DDP initialized
+ set_tensor_model_parallel_rank(args.rank)
+ return finish_mpu_init
+ else:
+ # Megatron's MPU is the master. Complete initialization right away.
+ finish_mpu_init()
+
+ # Initialize memory buffers.
+ _initialize_mem_buffs()
+
+ # Autoresume.
+ _init_autoresume()
+
+ # No continuation function
+ return None
+
+
+def _compile_dependencies():
+
+ args = get_args()
+
+ # =========================
+ # Compile dataset C++ code.
+ # =========================
+ # TODO: move this to ninja
+ if torch.distributed.get_rank() == 0:
+ start_time = time.time()
+ print("> compiling dataset index builder ...")
+ # from megatron.data.dataset_utils import compile_helper
+ # compile_helper()
+ print(
+ ">>> done with dataset index builder. Compilation time: {:.3f} "
+ "seconds".format(time.time() - start_time),
+ flush=True,
+ )
+
+ # Custom kernel constraints check.
+ seq_len = args.seq_length
+ attn_batch_size = (
+ args.num_attention_heads / args.tensor_model_parallel_size
+ ) * args.micro_batch_size
+ # Constraints on sequence length and attn_batch_size to enable warp based
+ # optimization and upper triangular optimization (for causal mask)
+ custom_kernel_constraint = (
+ seq_len > 16
+ and seq_len <= 2048
+ and seq_len % 4 == 0
+ and attn_batch_size % 4 == 0
+ )
+ # Print a warning.
+ if not (
+ (args.fp16 or args.bf16)
+ and custom_kernel_constraint
+ and args.masked_softmax_fusion
+ ):
+ if args.rank == 0:
+ print(
+ "WARNING: constraints for invoking optimized"
+ " fused softmax kernel are not met. We default"
+ " back to unfused kernel invocations.",
+ flush=True,
+ )
+
+ # Always build on rank zero first.
+ if torch.distributed.get_rank() == 0:
+ start_time = time.time()
+ print("> compiling and loading fused kernels ...", flush=True)
+ torch.distributed.barrier()
+ else:
+ torch.distributed.barrier()
+ # Simple barrier to make sure all ranks have passed the
+ # compilation phase successfully before moving on to the
+ # rest of the program. We think this might ensure that
+ # the lock is released.
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print(
+ ">>> done with compiling and loading fused kernels. "
+ "Compilation time: {:.3f} seconds".format(time.time() - start_time),
+ flush=True,
+ )
+
+
+def setup_deepspeed_random_and_activation_checkpointing(args):
+ """Optional DeepSpeed Activation Checkpointing features.
+ Gives access to partition activations, contiguous memory optimizations
+ and cpu checkpointing.
+ Activation checkpoint requires keep track of the random states
+ and setting the random seed for each MP process. Megatron uses
+ mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed
+ for keeping track of the random states and setting the random seeds.
+ Since they are used in places outside of activation checkpointing,
+ we overwrite them to maintain consistency.
+ This must be called before all the calls to mpu.model_parallel_cuda_manual_seed
+ """
+ num_layers = args.num_layers // args.checkpoint_num_layers
+ num_layers = (
+ num_layers
+ if args.num_layers % args.checkpoint_num_layers == 0
+ else num_layers + 1
+ )
+ if args.split_transformers:
+ num_layers *= 2
+
+ deepspeed.checkpointing.configure(
+ mpu,
+ partition_activations=args.partition_activations,
+ contiguous_checkpointing=args.contigious_checkpointing,
+ num_checkpoints=num_layers,
+ checkpoint_in_cpu=args.checkpoint_in_cpu,
+ synchronize=args.synchronize_each_layer,
+ profile=args.profile_backward,
+ )
+
+ mpu.checkpoint = deepspeed.checkpointing.checkpoint
+ mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
+ mpu.model_parallel_cuda_manual_seed = (
+ deepspeed.checkpointing.model_parallel_cuda_manual_seed
+ )
+
+
+def _initialize_distributed():
+ """Initialize torch.distributed and mpu."""
+ args = get_args()
+
+ device_count = torch.cuda.device_count()
+ if torch.distributed.is_initialized():
+
+ if args.rank == 0:
+ print(
+ "torch distributed is already initialized, "
+ "skipping initialization ...",
+ flush=True,
+ )
+ args.rank = torch.distributed.get_rank()
+ args.world_size = torch.distributed.get_world_size()
+
+ else:
+
+ if args.rank == 0:
+ print("> initializing torch distributed ...", flush=True)
+ # Manually set the device ids.
+ if device_count > 0:
+ device = args.rank % device_count
+ if args.local_rank is not None:
+ assert (
+ args.local_rank == device
+ ), "expected local-rank to be the same as rank % device-count."
+ else:
+ args.local_rank = device
+ if args.force_device is not None:
+ print(
+ f" > forcefully set the device to {args.force_device}, originally {device}"
+ )
+ device = args.force_device
+ torch.cuda.set_device(device)
+ # Call the init process
+ init_method = "tcp://"
+ master_ip = os.getenv("MASTER_ADDR", "localhost")
+ master_port = os.getenv("MASTER_PORT", "6000")
+ init_method += master_ip + ":" + master_port
+ print(
+ f" > (rank={args.rank}) initializing process group: "
+ f"world_size={args.world_size} "
+ f"backend={args.distributed_backend} "
+ f"init_method={init_method}",
+ flush=True,
+ )
+ timeout = datetime.timedelta(minutes=args.dist_timeout)
+ torch.distributed.init_process_group(
+ backend=args.distributed_backend,
+ world_size=args.world_size,
+ rank=args.rank,
+ init_method=init_method,
+ timeout=timeout
+ )
+ print(f" > (rank={args.rank}) process group initialized")
+
+ # Set the tensor model-parallel, pipeline model-parallel, and
+ # data-parallel communicators.
+ if device_count > 0:
+ if mpu.model_parallel_is_initialized():
+ print("model parallel is already initialized")
+ else:
+ mpu.initialize_model_parallel(
+ args.tensor_model_parallel_size,
+ args.pipeline_model_parallel_size,
+ args.virtual_pipeline_model_parallel_size,
+ )
+
+ if args.deepspeed and args.deepspeed_activation_checkpointing:
+ setup_deepspeed_random_and_activation_checkpointing(args)
+
+
+def _init_autoresume():
+ """Set autoresume start time."""
+ autoresume = get_adlr_autoresume()
+ if autoresume:
+ torch.distributed.barrier()
+ autoresume.init()
+ torch.distributed.barrier()
+
+
+def _set_random_seed(seed_):
+ """Set random seed for reproducability."""
+ if seed_ is not None and seed_ > 0:
+ # Ensure that different pipeline MP stages get different seeds.
+ seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.device_count() > 0:
+ mpu.model_parallel_cuda_manual_seed(seed)
+ else:
+ raise ValueError("Seed ({}) should be a positive integer.".format(seed))
+
+
+def write_args_to_tensorboard():
+ """Write arguments to tensorboard."""
+ args = get_args()
+ writer = get_tensorboard_writer()
+ if writer:
+ for arg in vars(args):
+ writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)
+
+
+def initialize_wandb_experiment():
+ """Initialize wandb experiment."""
+ assert wandb is not None, "Fail to import wandb"
+
+ args = get_args()
+ config = args.__dict__
+
+ wandb_id_path = os.path.join(args.save, "wandb_id.txt")
+ if os.path.exists(wandb_id_path):
+ wandb_id = open(wandb_id_path, "r").read().strip()
+ else:
+ wandb_id = wandb.util.generate_id()
+ open(wandb_id_path, "w").write(wandb_id)
+
+ wandb.init(id=wandb_id, project="megatron", config=config, resume="allow")
+
+
+def _initialize_mem_buffs():
+ """Initialize manually allocated static memory."""
+ args = get_args()
+
+ # Initialize memory for checkpointed activations.
+ if args.distribute_checkpointed_activations:
+ mpu.init_checkpointed_activations_memory_buffer()
diff --git a/codegeex/megatron/memory.py b/codegeex/megatron/memory.py
new file mode 100644
index 0000000..1b70daa
--- /dev/null
+++ b/codegeex/megatron/memory.py
@@ -0,0 +1,150 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+
+# A dictionary of all the memory buffers allocated.
+_MEM_BUFFS = dict()
+
+
+def allocate_mem_buff(name, numel, dtype, track_usage):
+ """Allocate a memory buffer."""
+ assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name)
+ _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
+ return _MEM_BUFFS[name]
+
+
+def get_mem_buff(name):
+ """Get the memory buffer."""
+ return _MEM_BUFFS[name]
+
+
+class MemoryBuffer:
+ """Contiguous memory buffer.
+ Allocate a contiguous memory of type `dtype` and size `numel`. It is
+ used to reduce memory fragmentation.
+
+ Usage: After the allocation, the `_start` index is set tot the first
+ index of the memory. A memory chunk starting from `_start` index
+ can be `allocated` for an input tensor, with the elements of the
+ tensor being coppied. The buffer can be reused by resetting the
+ `_start` index.
+
+ """
+
+ def __init__(self, name, numel, dtype, track_usage):
+ if torch.distributed.get_rank() == 0:
+ element_size = torch.tensor([], dtype=dtype).element_size()
+ print(
+ "> building the {} memory buffer with {} num elements "
+ "and {} dtype ({:.1f} MB)...".format(
+ name, numel, dtype, numel * element_size / 1024 / 1024
+ ),
+ flush=True,
+ )
+ self.name = name
+ self.numel = numel
+ self.dtype = dtype
+ self.data = torch.empty(
+ self.numel,
+ dtype=self.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+
+ # Index tracking the start of the free memory.
+ self._start = 0
+
+ # Values used for tracking usage.
+ self.track_usage = track_usage
+ if self.track_usage:
+ self.in_use_value = 0.0
+ self.total_value = 0.0
+
+ def reset(self):
+ """Reset the buffer start index to the beginning of the buffer."""
+ self._start = 0
+
+ def is_in_use(self):
+ """Whether the current buffer hold on to any memory."""
+ return self._start > 0
+
+ def numel_in_use(self):
+ """Return number of elements in use."""
+ return self._start
+
+ def add(self, tensor):
+ """Allocate a chunk of memory from the buffer to tensor and copy
+ the values."""
+ assert (
+ tensor.dtype == self.dtype
+ ), "Input tensor type {} different from buffer type {}".format(
+ tensor.dtype, self.dtype
+ )
+ # Number of elements of the input tensor.
+ tensor_numel = torch.numel(tensor)
+ new_start = self._start + tensor_numel
+ assert (
+ new_start <= self.numel
+ ), "Not enough memory left in the buffer ({} > {})".format(
+ tensor_numel, self.numel - self._start
+ )
+ # New tensor is a view into the memory.
+ new_tensor = self.data[self._start : new_start]
+ self._start = new_start
+ new_tensor = new_tensor.view(tensor.shape)
+ new_tensor.copy_(tensor)
+ # Return a pointer to the new tensor.
+ return new_tensor
+
+ def get_data(self):
+ """Return the data currently in use."""
+ if self.track_usage:
+ self.in_use_value += float(self._start)
+ self.total_value += float(self.numel)
+ return self.data[: self._start]
+
+ def print_average_usage(self):
+ """Print memory usage average over time. We would like this value
+ to be as high as possible."""
+ assert self.track_usage, "You need to enable track usage."
+ if torch.distributed.get_rank() == 0:
+ print(
+ " > usage of {} memory buffer: {:.2f} %".format(
+ self.name, self.in_use_value * 100.0 / self.total_value
+ ),
+ flush=True,
+ )
+
+
+class RingMemBuffer:
+ """A ring of memory buffers."""
+
+ def __init__(self, name, num_buffers, numel, dtype, track_usage):
+ self.num_buffers = num_buffers
+ self.buffers = [
+ allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage)
+ for i in range(num_buffers)
+ ]
+ self._index = -1
+
+ def get_next_buffer(self):
+ self._index += 1
+ self._index = self._index % self.num_buffers
+ buff = self.buffers[self._index]
+ assert not buff.is_in_use(), "buffer is already in use."
+ return buff
diff --git a/codegeex/megatron/mindspore_to_megatron.py b/codegeex/megatron/mindspore_to_megatron.py
new file mode 100644
index 0000000..da66c79
--- /dev/null
+++ b/codegeex/megatron/mindspore_to_megatron.py
@@ -0,0 +1,319 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Merge model parallel partitions."""
+
+import os
+import random
+import sys
+
+import numpy as np
+import torch
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+ os.path.pardir)))
+
+from codegeex.megatron import get_args
+from codegeex.megatron.model import CodeGeeXModel
+from codegeex.megatron.initialize import initialize_megatron
+from codegeex.megatron.checkpointing import ensure_directory_exists
+
+
+def get_change_ckpt_args(parser):
+ """Provide extra arguments required for merging."""
+ group = parser.add_argument_group(title='Mindspore to megatron')
+ group.add_argument(
+ '--npy-ckpt-path',
+ type=str,
+ required=True,
+ help='path of npy checkpoint.',
+ )
+ group.add_argument(
+ '--save-ckpt-path',
+ type=str,
+ required=True,
+ help='path to save checkpoint.',
+ )
+
+ return parser
+
+
+def loadModelFromNp(sd, args):
+ num_layers = args.num_layers
+ npCkptPath = args.npy_ckpt_path
+ languageModel = sd['module']['language_model']
+ loadEmbeddingFromNp(npCkptPath, languageModel)
+ transformer = sd['module']['language_model']['transformer']
+ for layerID in range(num_layers):
+ loadAttentionLayerFromNp(npCkptPath, transformer, layerID)
+ loadQueryLayerFromNp(npCkptPath, transformer)
+
+ transformer['final_layernorm.weight'][:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.layernorm.gamma.npy')
+ ).float()
+ transformer['final_layernorm.bias'][:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.layernorm.beta.npy')
+ ).float()
+
+
+def loadEmbeddingFromNp(npCkptPath, languageModel, vocabSize=52224):
+ word_embedding_np = \
+ np.load(npCkptPath + 'backbone.embedding.word_embedding.embedding_table.npy')
+ languageModel['embedding']['word_embeddings']['weight'][:vocabSize, :] = \
+ torch.tensor(word_embedding_np).float()
+
+ position_embeddings_np = \
+ np.load(npCkptPath + 'backbone.embedding.position_embedding.embedding_table.npy')
+ languageModel['embedding']['position_embeddings']['weight'][:, :] = \
+ torch.tensor(position_embeddings_np).float()
+
+ topQueryEmbedding_np = \
+ np.load(npCkptPath + 'backbone.top_query_embedding.embedding_table.npy')
+ languageModel['topQueryEmbedding']['top_query_embeddings']['weight'][:, :] = \
+ torch.tensor(topQueryEmbedding_np).float()
+
+
+def loadAttentionLayerFromNp(npCkptPath, transformer, layerID):
+ attention_dense1_weight_np = \
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense1.weight.npy')
+ attention_dense2_weight_np = \
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense2.weight.npy')
+ attention_dense3_weight_np = \
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense3.weight.npy')
+
+ attention_dense1_bias_np = \
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense1.bias.npy')
+ attention_dense2_bias_np = \
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense2.bias.npy')
+ attention_dense3_bias_np = \
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense3.bias.npy')
+
+ query_weight = transformer[f'layers.{layerID}.attention.query.weight']
+ key_weight = transformer[f'layers.{layerID}.attention.key.weight']
+ value_weight = transformer[f'layers.{layerID}.attention.value.weight']
+
+ query_weight[:] = torch.tensor(attention_dense1_weight_np).float()
+ key_weight[:] = torch.tensor(attention_dense2_weight_np).float()
+ value_weight[:] = torch.tensor(attention_dense3_weight_np).float()
+
+ query_bias = transformer[f'layers.{layerID}.attention.query.bias']
+ key_bias = transformer[f'layers.{layerID}.attention.key.bias']
+ value_bias = transformer[f'layers.{layerID}.attention.value.bias']
+
+ query_bias[:] = torch.tensor(attention_dense1_bias_np).float()
+ key_bias[:] = torch.tensor(attention_dense2_bias_np).float()
+ value_bias[:] = torch.tensor(attention_dense3_bias_np).float()
+
+ att_dense_weight = transformer[f'layers.{layerID}.attention.dense.weight']
+ att_dense_weight[:, :] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.projection.weight.npy').transpose()
+ ).float()
+ att_dense_bias = transformer[f'layers.{layerID}.attention.dense.bias']
+ att_dense_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.projection.bias.npy')
+ ).float()
+
+ mlp_dense_h_to_4h_weight = transformer[f'layers.{layerID}.mlp.dense_h_to_4h.weight']
+ mlp_dense_h_to_4h_weight[:, :] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.output.mapping.weight.npy').transpose()
+ ).float()
+ mlp_dense_h_to_4h_bias = transformer[f'layers.{layerID}.mlp.dense_h_to_4h.bias']
+ mlp_dense_h_to_4h_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.output.mapping.bias.npy')
+ ).float()
+
+ mlp_dense_4h_to_h_weight = transformer[f'layers.{layerID}.mlp.dense_4h_to_h.weight']
+ mlp_dense_4h_to_h_weight[:, :] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.output.projection.weight.npy').transpose()
+ ).float()
+ mlp_dense_4h_to_h_bias = transformer[f'layers.{layerID}.mlp.dense_4h_to_h.bias']
+ mlp_dense_4h_to_h_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.output.projection.bias.npy')
+ ).float()
+
+ input_layernorm_weight = transformer[f'layers.{layerID}.input_layernorm.weight']
+ input_layernorm_weight[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.gamma.npy')
+ ).float()
+ input_layernorm_bias = transformer[f'layers.{layerID}.input_layernorm.bias']
+ input_layernorm_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.beta.npy')
+ ).float()
+
+ post_attention_layernorm_weight = transformer[f'layers.{layerID}.post_attention_layernorm.weight']
+ post_attention_layernorm_weight[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.gamma.npy')
+ ).float()
+ post_attention_layernorm_bias = transformer[f'layers.{layerID}.post_attention_layernorm.bias']
+ post_attention_layernorm_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.beta.npy')
+ ).float()
+
+ input_layernorm_weight = transformer[f'layers.{layerID}.input_layernorm.weight']
+ input_layernorm_weight[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.gamma.npy')
+ ).float()
+ input_layernorm_bias = transformer[f'layers.{layerID}.input_layernorm.bias']
+ input_layernorm_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.beta.npy')
+ ).float()
+
+ post_attention_layernorm_weight = transformer[f'layers.{layerID}.post_attention_layernorm.weight']
+ post_attention_layernorm_weight[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.gamma.npy')
+ ).float()
+ post_attention_layernorm_bias = transformer[f'layers.{layerID}.post_attention_layernorm.bias']
+ post_attention_layernorm_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.beta.npy')
+ ).float()
+
+
+def loadQueryLayerFromNp(npCkptPath, transformer):
+ attention_dense1_weight_np = \
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.dense1.weight.npy')
+ attention_dense1_bias_np = \
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.dense1.bias.npy')
+ attention_dense2_weight_np = \
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.dense2.weight.npy')
+ attention_dense2_bias_np = \
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.dense2.bias.npy')
+ attention_dense3_weight_np = \
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.dense3.weight.npy')
+ attention_dense3_bias_np = \
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.dense3.bias.npy')
+
+ query_weight = transformer[f'topQueryLayer.attention.query.weight']
+ query_weight[:, :] = \
+ torch.tensor(attention_dense1_weight_np).float()
+ query_bias = transformer[f'topQueryLayer.attention.query.bias']
+ query_bias[:] = torch.tensor(attention_dense1_bias_np).float()
+
+ key_weight = transformer[f'topQueryLayer.attention.key.weight']
+ key_weight[:, :] = \
+ torch.tensor(attention_dense2_weight_np).float()
+ key_bias = transformer[f'topQueryLayer.attention.key.bias']
+ key_bias[:] = torch.tensor(attention_dense2_bias_np).float()
+
+ value_weight = transformer[f'topQueryLayer.attention.value.weight']
+ value_weight[:, :] = \
+ torch.tensor(attention_dense3_weight_np).float()
+ value_bias = transformer[f'topQueryLayer.attention.value.bias']
+ value_bias[:] = torch.tensor(attention_dense3_bias_np).float()
+
+ att_dense_weight = transformer[f'topQueryLayer.attention.dense.weight']
+ att_dense_weight[:, :] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.projection.weight.npy')
+ .transpose()
+ ).float()
+ att_dense_bias = transformer[f'topQueryLayer.attention.dense.bias']
+ att_dense_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.attention.projection.bias.npy')
+ ).float()
+
+ mlp_dense_h_to_4h_weight = transformer[f'topQueryLayer.mlp.dense_h_to_4h.weight']
+ mlp_dense_h_to_4h_weight[:, :] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.output.mapping.weight.npy')
+ .transpose()
+ ).float()
+ mlp_dense_h_to_4h_bias = transformer[f'topQueryLayer.mlp.dense_h_to_4h.bias']
+ mlp_dense_h_to_4h_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.output.mapping.bias.npy')
+ ).float()
+
+ mlp_dense_4h_to_h_weight = transformer[f'topQueryLayer.mlp.dense_4h_to_h.weight']
+ mlp_dense_4h_to_h_weight[:, :] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.output.projection.weight.npy')
+ .transpose()
+ ).float()
+ mlp_dense_4h_to_h_bias = transformer[f'topQueryLayer.mlp.dense_4h_to_h.bias']
+ mlp_dense_4h_to_h_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.output.projection.bias.npy')
+ ).float()
+
+ input_layernorm_weight = transformer[f'topQueryLayer.input_layernorm.weight']
+ input_layernorm_weight[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.layernorm1.gamma.npy')
+ ).float()
+ input_layernorm_bias = transformer[f'topQueryLayer.input_layernorm.bias']
+ input_layernorm_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.layernorm1.beta.npy')
+ ).float()
+
+ post_attention_layernorm_weight = transformer[f'topQueryLayer.post_attention_layernorm.weight']
+ post_attention_layernorm_weight[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.layernorm2.gamma.npy')
+ ).float()
+ post_attention_layernorm_bias = transformer[f'topQueryLayer.post_attention_layernorm.bias']
+ post_attention_layernorm_bias[:] = \
+ torch.tensor(
+ np.load(npCkptPath + f'backbone.top_query_layer.layernorm2.beta.npy')
+ ).float()
+
+
+def main():
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
+
+ initialize_megatron(
+ extra_args_provider=get_change_ckpt_args,
+ args_defaults={
+ "tokenizer_type": "GPT2BPETokenizer",
+ "no_load_rng" : True,
+ "no_load_optim" : True,
+ },
+ )
+
+ args = get_args()
+ model = CodeGeeXModel()
+ # print(dir(model))
+ print(model.state_dict)
+
+ # Save the model.
+ sd = {}
+ sd['module'] = model.state_dict_for_save_checkpoint()
+ ensure_directory_exists(args.save_ckpt_path)
+ loadModelFromNp(sd, args)
+ print('> saving merged model to {}'.format(args.save_ckpt_path))
+ torch.save(sd, args.save_ckpt_path)
+ print(f"Converted checkpoint saved in {args.save_ckpt_path}.")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/codegeex/megatron/model/__init__.py b/codegeex/megatron/model/__init__.py
new file mode 100644
index 0000000..ef046df
--- /dev/null
+++ b/codegeex/megatron/model/__init__.py
@@ -0,0 +1,19 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .distributed import DistributedDataParallel
+from .codegeex_model import CodeGeeXModel
+from .language_model import get_language_model
+from .module import Float16Module
\ No newline at end of file
diff --git a/codegeex/megatron/model/codegeex_model.py b/codegeex/megatron/model/codegeex_model.py
new file mode 100644
index 0000000..c8fc78a
--- /dev/null
+++ b/codegeex/megatron/model/codegeex_model.py
@@ -0,0 +1,109 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from codegeex.megatron import get_args
+from codegeex.megatron import mpu
+from .module import MegatronModule
+
+from .language_model import parallel_lm_logits
+from .language_model import get_language_model
+from .utils import init_method_normal
+from .utils import scaled_init_method_normal
+
+
+class CodeGeeXModel(MegatronModule):
+ """Code Generative Model for Multilingual Program Synthesis."""
+
+ def __init__(self, num_tokentypes=0, parallel_output=False):
+ super(CodeGeeXModel, self).__init__()
+ args = get_args()
+
+ self.parallel_output = parallel_output
+ self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+
+ self.language_model, self._language_model_key = get_language_model(
+ num_tokentypes=num_tokentypes,
+ add_pooler=False,
+ init_method=init_method_normal(args.init_method_std),
+ scaled_init_method=scaled_init_method_normal(args.init_method_std,
+ args.num_layers))
+
+ def forward(
+ self,
+ input_ids,
+ position_ids,
+ attention_mask,
+ labels=None,
+ tokentype_ids=None,
+ layer_past=None,
+ get_key_value=False,
+ forward_method_parallel_output=None,
+ prompt_length=None,
+ context_length=None,
+ ):
+
+ # Language model.
+ lm_output = self.language_model(input_ids,
+ position_ids,
+ attention_mask,
+ tokentype_ids=tokentype_ids,
+ layer_past=layer_past,
+ get_key_value=get_key_value,
+ prompt_length=prompt_length,
+ context_length=context_length)
+
+ if get_key_value:
+ lm_output, presents = lm_output
+
+ lm_output = torch.add(lm_output, 0)
+ # Output.
+ parallel_output = self.parallel_output
+ if forward_method_parallel_output is not None:
+ parallel_output = forward_method_parallel_output
+ output = parallel_lm_logits(
+ lm_output,
+ self.language_model.embedding.word_embeddings.weight,
+ parallel_output)
+
+ if get_key_value:
+ output = [output, presents]
+
+ if labels is None:
+ return output
+ else:
+ if self.fp16_lm_cross_entropy:
+ assert output.dtype == torch.half
+ loss = mpu.vocab_parallel_cross_entropy(output, labels)
+ else:
+ loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+
+ return loss
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+
+ state_dict_ = {}
+ state_dict_[self._language_model_key] \
+ = self.language_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ if self._language_model_key in state_dict:
+ state_dict = state_dict[self._language_model_key]
+ self.language_model.load_state_dict(state_dict, strict=strict)
diff --git a/codegeex/megatron/model/distributed.py b/codegeex/megatron/model/distributed.py
new file mode 100644
index 0000000..3c6ec56
--- /dev/null
+++ b/codegeex/megatron/model/distributed.py
@@ -0,0 +1,215 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC
+from abc import abstractmethod
+
+import torch
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from codegeex.megatron import mpu
+from .module import MegatronModule
+
+
+class MemoryBuffer:
+ def __init__(self, numel, dtype):
+ self.numel = numel
+ self.dtype = dtype
+ self.data = torch.zeros(
+ self.numel,
+ dtype=self.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+
+ def zero(self):
+ """Reset the buffer to zero."""
+ self.data.zero_()
+
+ def get(self, shape, start_index):
+ """Return a tensor with the input `shape` as a view into the
+ 1-D data starting at `start_index`."""
+ end_index = start_index + shape.numel()
+ assert end_index <= self.numel, "requested tensor is out of the buffer range."
+ buffer_tensor = self.data[start_index:end_index]
+ buffer_tensor = buffer_tensor.view(shape)
+ return buffer_tensor
+
+
+class DistributedDataParallelBase(MegatronModule, ABC):
+ """Abstract class for DDP."""
+
+ def __init__(self, module):
+ super(DistributedDataParallelBase, self).__init__()
+ # Keep a pointer to the model.
+ self.module = module
+
+ @abstractmethod
+ def allreduce_gradients(self):
+ pass
+
+ def forward(self, *inputs, **kwargs):
+ return self.module(*inputs, **kwargs)
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ return self.module.state_dict(destination, prefix, keep_vars)
+
+ def state_dict_for_save_checkpoint(
+ self, destination=None, prefix="", keep_vars=False
+ ):
+ return self.module.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars
+ )
+
+ def load_state_dict(self, state_dict, strict=True):
+ self.module.load_state_dict(state_dict, strict=strict)
+
+
+class DistributedDataParallel(DistributedDataParallelBase):
+ """DDP with contiguous buffers options to storre and accumulate gradients.
+ This class:
+ - has the potential to reduce memory fragmentation.
+ - provides the option to do the gradient accumulation
+ in a type other than the params type (for example fp32)
+
+ Arguments:
+ module: input model.
+ accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
+ and the gradient all-reduce all in in float32. If this option is
+ true, we require `use_contiguous_buffers` to be true too.
+ use_contiguous_buffers: if true, use a contiguous buffer to store the
+ gradients.
+ """
+
+ def __init__(
+ self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers
+ ):
+
+ super(DistributedDataParallel, self).__init__(module)
+
+ self.accumulate_allreduce_grads_in_fp32 = accumulate_allreduce_grads_in_fp32
+ self.use_contiguous_buffers = use_contiguous_buffers
+ # If we are using fp32-accumulate-allreduce explicitly
+ # this means we need main grads in a continous buffer.
+ if self.accumulate_allreduce_grads_in_fp32:
+ assert self.use_contiguous_buffers
+
+ # ===================================
+ # Rest of this part applies only to
+ # the case we use continuous buffers.
+ # ===================================
+ self._grad_buffers = None
+ if self.use_contiguous_buffers:
+ self._grad_buffers = {}
+
+ # Simple function to define buffer type.
+ def _get_buffer_type(param):
+ return (
+ torch.float
+ if self.accumulate_allreduce_grads_in_fp32
+ else param.dtype
+ )
+
+ # First calculate total number of elements per type.
+ type_num_elements = {}
+ for param in self.module.parameters():
+ if param.requires_grad:
+ dtype = _get_buffer_type(param)
+ type_num_elements[dtype] = (
+ type_num_elements.get(dtype, 0) + param.data.nelement()
+ )
+
+ # Allocate the buffer.
+ for dtype, num_elements in type_num_elements.items():
+ self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+ # Assume the back prop order is reverse the params order,
+ # store the start index for the gradients.
+ for param in self.module.parameters():
+ if param.requires_grad:
+ dtype = _get_buffer_type(param)
+ type_num_elements[dtype] -= param.data.nelement()
+ param.main_grad = self._grad_buffers[dtype].get(
+ param.data.shape, type_num_elements[dtype]
+ )
+
+ # Backward hook.
+ # Accumalation function for the gradients. We need
+ # to store them so they don't go out of scope.
+ self.grad_accs = []
+ # Loop over all the parameters in the model.
+ for param in self.module.parameters():
+ if param.requires_grad:
+ # Expand so we get access to grad_fn.
+ param_tmp = param.expand_as(param)
+ # Get the gradient accumulator functtion.
+ grad_acc = param_tmp.grad_fn.next_functions[0][0]
+ grad_acc.register_hook(self._make_param_hook(param))
+ self.grad_accs.append(grad_acc)
+
+ def _make_param_hook(self, param):
+ """Create the all-reduce hook for backprop."""
+
+ # Hook used for back-prop.
+ def param_hook(*unused):
+ # Add the gradient to the buffer.
+ if param.grad.data is not None:
+ param.main_grad.add_(param.grad.data)
+ # Now we can deallocate grad memory.
+ param.grad = None
+
+ return param_hook
+
+ def zero_grad_buffer(self):
+ """Set the grad buffer data to zero. Needs to be called at the
+ begining of each iteration."""
+ assert self._grad_buffers is not None, "buffers are not initialized."
+ for _, buffer_ in self._grad_buffers.items():
+ buffer_.zero()
+
+ def allreduce_gradients(self):
+ """Reduce gradients across data parallel ranks."""
+ # If we have buffers, simply reduce the data in the buffer.
+ if self._grad_buffers is not None:
+ for _, buffer_ in self._grad_buffers.items():
+ buffer_.data /= mpu.get_data_parallel_world_size()
+ torch.distributed.all_reduce(
+ buffer_.data, group=mpu.get_data_parallel_group()
+ )
+ else:
+ # Otherwise, bucketize and all-reduce
+ buckets = {}
+ # Pack the buckets.
+ for param in self.module.parameters():
+ if param.requires_grad and param.grad is not None:
+ tp = param.data.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(param)
+ param.main_grad = param.grad
+
+ # For each bucket, all-reduce and copy all-reduced grads.
+ for tp in buckets:
+ bucket = buckets[tp]
+ grads = [param.grad.data for param in bucket]
+ coalesced = _flatten_dense_tensors(grads)
+ coalesced /= mpu.get_data_parallel_world_size()
+ torch.distributed.all_reduce(
+ coalesced, group=mpu.get_data_parallel_group()
+ )
+ for buf, synced in zip(
+ grads, _unflatten_dense_tensors(coalesced, grads)
+ ):
+ buf.copy_(synced)
diff --git a/codegeex/megatron/model/language_model.py b/codegeex/megatron/model/language_model.py
new file mode 100644
index 0000000..41561bf
--- /dev/null
+++ b/codegeex/megatron/model/language_model.py
@@ -0,0 +1,503 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer based language model."""
+
+import torch
+import torch.nn.functional as F
+
+from codegeex.megatron import get_args
+from codegeex.megatron import mpu
+from codegeex.megatron.model.module import MegatronModule
+from codegeex.megatron.model.transformer import ParallelTransformer
+from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal
+
+
+def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
+ """LM logits using word embedding weights."""
+ # Parallel logits.
+ input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+ # Matrix multiply.
+ if bias is None:
+ logits_parallel = F.linear(input_parallel, word_embeddings_weight.half())
+ else:
+ logits_parallel = F.linear(input_parallel, word_embeddings_weight.half(), bias)
+ # Gather if needed.
+ if parallel_output:
+ return logits_parallel
+
+ return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
+
+
+def get_language_model(
+ num_tokentypes,
+ add_pooler,
+ init_method=None,
+ scaled_init_method=None,
+):
+ """Build language model and return along with the key to save."""
+ args = get_args()
+
+ if init_method is None:
+ init_method = init_method_normal(args.init_method_std)
+
+ if scaled_init_method is None:
+ scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
+
+ # Language model.
+ language_model = TransformerLanguageModel(
+ init_method=init_method,
+ output_layer_init_method=scaled_init_method,
+ num_tokentypes=num_tokentypes,
+ add_pooler=add_pooler)
+ # key used for checkpoints.
+ language_model_key = 'language_model'
+
+ return language_model, language_model_key
+
+
+class Embedding(MegatronModule):
+ """Language model embeddings.
+
+ Arguments:
+ hidden_size: hidden size
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(self,
+ hidden_size,
+ vocab_size,
+ max_sequence_length,
+ embedding_dropout_prob,
+ init_method,
+ num_tokentypes=0):
+ super(Embedding, self).__init__()
+
+ self.hidden_size = hidden_size
+ self.init_method = init_method
+ self.num_tokentypes = num_tokentypes
+
+ # Word embeddings (parallel).
+ self.word_embeddings = mpu.VocabParallelEmbedding(
+ vocab_size, self.hidden_size, init_method=self.init_method)
+ self._word_embeddings_key = 'word_embeddings'
+ self.vocab_size = vocab_size
+
+ # Position embedding (serial).
+ self.position_embeddings = torch.nn.Embedding(
+ max_sequence_length, self.hidden_size)
+ self.position_embeddings = self.position_embeddings.half()
+ self._position_embeddings_key = 'position_embeddings'
+ # Initialize the position embeddings.
+ self.init_method(self.position_embeddings.weight)
+
+ # Token type embedding.
+ # Add this as an optional field that can be added through
+ # method call so we can load a pretrain model without
+ # token types and add them as needed.
+ self._tokentype_embeddings_key = 'tokentype_embeddings'
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
+ self.hidden_size)
+ # Initialize the token-type embeddings.
+ self.init_method(self.tokentype_embeddings.weight)
+ else:
+ self.tokentype_embeddings = None
+
+ # Embeddings dropout
+ self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+ def add_tokentype_embeddings(self, num_tokentypes):
+ """Add token-type embedding. This function is provided so we can add
+ token-type embeddings in case the pretrained model does not have it.
+ This allows us to load the model normally and then add this embedding.
+ """
+ if self.tokentype_embeddings is not None:
+ raise Exception('tokentype embeddings is already initialized')
+ if torch.distributed.get_rank() == 0:
+ print('adding embedding for {} tokentypes'.format(num_tokentypes),
+ flush=True)
+ self.num_tokentypes = num_tokentypes
+ self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
+ self.hidden_size)
+ # Initialize the token-type embeddings.
+ self.init_method(self.tokentype_embeddings.weight)
+
+ def forward(self, input_ids, position_ids, tokentype_ids=None):
+ # Embeddings.
+ words_embeddings = self.word_embeddings(input_ids)
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = words_embeddings + position_embeddings
+ if tokentype_ids is not None:
+ assert self.tokentype_embeddings is not None
+ embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
+ else:
+ assert self.tokentype_embeddings is None
+
+ # Dropout.
+ embeddings = self.embedding_dropout(embeddings)
+
+ return embeddings
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._word_embeddings_key] \
+ = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+ state_dict_[self._position_embeddings_key] \
+ = self.position_embeddings.state_dict(
+ destination, prefix, keep_vars)
+ if self.num_tokentypes > 0:
+ state_dict_[self._tokentype_embeddings_key] \
+ = self.tokentype_embeddings.state_dict(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Word embedding.
+ if self._word_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._word_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'word_embeddings' in key:
+ state_dict_[key.split('word_embeddings.')[1]] \
+ = state_dict[key]
+ state_dict_["weight"] = state_dict_["weight"][:self.vocab_size]
+ self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Position embedding.
+ if self._position_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._position_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'position_embeddings' in key:
+ state_dict_[key.split('position_embeddings.')[1]] \
+ = state_dict[key]
+ self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Tokentype embedding.
+ if self.num_tokentypes > 0:
+ state_dict_ = {}
+ if self._tokentype_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._tokentype_embeddings_key]
+ else:
+ # for backward compatibility.
+ for key in state_dict.keys():
+ if 'tokentype_embeddings' in key:
+ state_dict_[key.split('tokentype_embeddings.')[1]] \
+ = state_dict[key]
+ if len(state_dict_.keys()) > 0:
+ self.tokentype_embeddings.load_state_dict(state_dict_,
+ strict=strict)
+ else:
+ print('***WARNING*** expected tokentype embeddings in the '
+ 'checkpoint but could not find it', flush=True)
+
+
+class QueryEmbedding(MegatronModule):
+ """Language model embeddings.
+
+ Arguments:
+ hidden_size: hidden size
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(self,
+ hidden_size,
+ vocab_size,
+ max_sequence_length,
+ embedding_dropout_prob,
+ init_method,
+ num_tokentypes=0):
+ super(QueryEmbedding, self).__init__()
+
+ self.hidden_size = hidden_size
+ self.init_method = init_method
+ self.num_tokentypes = num_tokentypes
+
+ # Top query position embedding (serial).
+ self.top_query_embeddings = torch.nn.Embedding(
+ max_sequence_length, self.hidden_size)
+ self.top_query_embeddings = self.top_query_embeddings.half()
+ self._top_query_embeddings_key = 'top_query_embeddings'
+ # Initialize the top query position embeddings.
+ self.init_method(self.top_query_embeddings.weight)
+
+ # Token type embedding.
+ # Add this as an optional field that can be added through
+ # method call so we can load a pretrain model without
+ # token types and add them as needed.
+ self._tokentype_embeddings_key = 'tokentype_embeddings'
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
+ self.hidden_size)
+ # Initialize the token-type embeddings.
+ self.init_method(self.tokentype_embeddings.weight)
+ else:
+ self.tokentype_embeddings = None
+
+ # Embeddings dropout
+ self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+ def add_tokentype_embeddings(self, num_tokentypes):
+ """Add token-type embedding. This function is provided so we can add
+ token-type embeddings in case the pretrained model does not have it.
+ This allows us to load the model normally and then add this embedding.
+ """
+ if self.tokentype_embeddings is not None:
+ raise Exception('tokentype embeddings is already initialized')
+ if torch.distributed.get_rank() == 0:
+ print('adding embedding for {} tokentypes'.format(num_tokentypes),
+ flush=True)
+ self.num_tokentypes = num_tokentypes
+ self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
+ self.hidden_size)
+ # Initialize the token-type embeddings.
+ self.init_method(self.tokentype_embeddings.weight)
+
+ def forward(self, position_ids, tokentype_ids=None):
+ # Embeddings.
+
+ embeddings = self.top_query_embeddings(position_ids)
+ if tokentype_ids is not None:
+ assert self.tokentype_embeddings is not None
+ embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
+ else:
+ assert self.tokentype_embeddings is None
+
+ # Dropout.
+ embeddings = self.embedding_dropout(embeddings)
+
+ return embeddings
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._top_query_embeddings_key] \
+ = self.top_query_embeddings.state_dict(
+ destination, prefix, keep_vars)
+ if self.num_tokentypes > 0:
+ state_dict_[self._tokentype_embeddings_key] \
+ = self.tokentype_embeddings.state_dict(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Position embedding.
+ if self._top_query_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._top_query_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'top_query_embeddings' in key:
+ state_dict_[key.split('top_query_embeddings.')[1]] \
+ = state_dict[key]
+ self.top_query_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Tokentype embedding.
+ if self.num_tokentypes > 0:
+ state_dict_ = {}
+ if self._tokentype_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._tokentype_embeddings_key]
+ else:
+ # for backward compatibility.
+ for key in state_dict.keys():
+ if 'tokentype_embeddings' in key:
+ state_dict_[key.split('tokentype_embeddings.')[1]] \
+ = state_dict[key]
+ if len(state_dict_.keys()) > 0:
+ self.tokentype_embeddings.load_state_dict(state_dict_,
+ strict=strict)
+ else:
+ print('***WARNING*** expected tokentype embeddings in the '
+ 'checkpoint but could not find it', flush=True)
+
+
+class TransformerLanguageModel(MegatronModule):
+ """Transformer language model.
+
+ Arguments:
+ transformer_hparams: transformer hyperparameters
+ attention_mask_func: a function that takes `unmaksed-attention-scores`
+ with size [b, np, s, s] and an `attention-mask` and will apply
+ the masking. The function should return a masked score of the
+ same size [b, np, s, s].
+ masked-attention-scores = attention_mask_func(
+ unmaksed-attention-scores, attention-mask)
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(self,
+ init_method,
+ output_layer_init_method,
+ num_tokentypes=0,
+ add_pooler=False):
+ super(TransformerLanguageModel, self).__init__()
+ args = get_args()
+
+ self.hidden_size = args.hidden_size
+ self.num_tokentypes = num_tokentypes
+ self.init_method = init_method
+ self.add_pooler = add_pooler
+
+ # Embeddings
+ self.embedding = Embedding(self.hidden_size,
+ args.padded_vocab_size,
+ args.max_position_embeddings,
+ args.hidden_dropout,
+ self.init_method,
+ self.num_tokentypes)
+ self._embedding_key = 'embedding'
+
+ # Query embeddings
+ self.topQueryEmbedding = QueryEmbedding(self.hidden_size,
+ args.padded_vocab_size,
+ args.max_position_embeddings,
+ args.hidden_dropout,
+ self.init_method,
+ self.num_tokentypes)
+ self._topQueryEmbedding_key = 'topQueryEmbedding'
+
+ # Transformer
+ self.transformer = ParallelTransformer(
+ self.init_method,
+ output_layer_init_method)
+ self._transformer_key = 'transformer'
+
+ def forward(
+ self,
+ input_ids,
+ position_ids,
+ attention_mask,
+ tokentype_ids=None,
+ layer_past=None,
+ get_key_value=False,
+ pooling_sequence_index=0,
+ prompt_length=None,
+ context_length=None,
+ ):
+
+ # Embeddings.
+ embedding_output = self.embedding(input_ids, position_ids,
+ tokentype_ids=tokentype_ids)
+ query_position_ids = position_ids
+ queryEmbedding_out = self.topQueryEmbedding(query_position_ids,
+ tokentype_ids=tokentype_ids)
+
+ # Transformer.
+ transformer_output = self.transformer(embedding_output,
+ queryEmbedding_out,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=get_key_value,
+ prompt_length=prompt_length,
+ context_length=context_length, )
+
+ return transformer_output
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._embedding_key] \
+ = self.embedding.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ state_dict_[self._topQueryEmbedding_key] \
+ = self.topQueryEmbedding.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ state_dict_[self._transformer_key] \
+ = self.transformer.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ if self.add_pooler:
+ state_dict_[self._pooler_key] \
+ = self.pooler.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Embedding.
+ if self._embedding_key in state_dict:
+ state_dict_ = state_dict[self._embedding_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if '_embeddings' in key:
+ state_dict_[key] = state_dict[key]
+ self.embedding.load_state_dict(state_dict_, strict=strict)
+
+ if self._topQueryEmbedding_key in state_dict:
+ state_dict_ = state_dict[self._topQueryEmbedding_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if '_embeddings' in key:
+ state_dict_[key] = state_dict[key]
+ self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict)
+
+ # Transformer.
+ if self._transformer_key in state_dict:
+ state_dict_ = state_dict[self._transformer_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'transformer.' in key:
+ state_dict_[key.split('transformer.')[1]] = state_dict[key]
+ self.transformer.load_state_dict(state_dict_, strict=strict)
+
+ # Pooler.
+ if self.add_pooler:
+ assert 'pooler' in state_dict, \
+ 'could not find data for pooler in the checkpoint'
+ self.pooler.load_state_dict(state_dict[self._pooler_key],
+ strict=strict)
diff --git a/codegeex/megatron/model/module.py b/codegeex/megatron/model/module.py
new file mode 100644
index 0000000..f9203fd
--- /dev/null
+++ b/codegeex/megatron/model/module.py
@@ -0,0 +1,199 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron Module"""
+
+import torch
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+
+from codegeex.megatron import get_args
+from codegeex.megatron import mpu
+
+
+_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
+_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
+
+
+def param_is_not_shared(param):
+ return not hasattr(param, "shared") or not param.shared
+
+
+class MegatronModule(torch.nn.Module):
+ """Megatron specific extensions of torch Module with support
+ for pipelining."""
+
+ def __init__(self, share_word_embeddings=True):
+ super(MegatronModule, self).__init__()
+ self.share_word_embeddings = share_word_embeddings
+
+ def state_dict_for_save_checkpoint(
+ self, destination=None, prefix="", keep_vars=False
+ ):
+ """Use this function to override the state dict for
+ saving checkpoints."""
+ return self.state_dict(destination, prefix, keep_vars)
+
+ def word_embeddings_weight(self):
+ if mpu.is_pipeline_first_stage(ignore_virtual=True):
+ return self.language_model.embedding.word_embeddings.weight
+ if mpu.is_pipeline_last_stage(ignore_virtual=True):
+ if not self.share_word_embeddings:
+ raise Exception(
+ "word_embeddings_weight() called for last "
+ "stage, but share_word_embeddings is false"
+ )
+ return self.word_embeddings.weight
+ raise Exception(
+ "word_embeddings_weight() should be " "called for first and last stage only"
+ )
+
+ def initialize_word_embeddings(self, init_method_normal):
+ args = get_args()
+ if not self.share_word_embeddings:
+ raise Exception(
+ "initialize_word_embeddings() was called but "
+ "share_word_embeddings is false"
+ )
+
+ # This function just initializes the word embeddings in the final stage
+ # when we are using pipeline parallelism. If we aren't using pipeline
+ # parallelism there is nothing to do.
+ if args.pipeline_model_parallel_size == 1:
+ return
+
+ # Parameters are shared between the word embeddings layer, and the
+ # heads at the end of the model. In a pipelined setup with more than
+ # one stage, the initial embedding layer and the head are on different
+ # workers, so we do the following:
+ # 1. Create a second copy of word_embeddings on the last stage, with
+ # initial parameters of 0.0.
+ # 2. Do an all-reduce between the first and last stage to ensure that
+ # the two copies of word_embeddings start off with the same
+ # parameter values.
+ # 3. In the training loop, before an all-reduce between the grads of
+ # the two word_embeddings layers to ensure that every applied weight
+ # update is the same on both stages.
+ if mpu.is_pipeline_last_stage():
+ assert not mpu.is_pipeline_first_stage()
+ self._word_embeddings_for_head_key = "word_embeddings_for_head"
+ # set word_embeddings weights to 0 here, then copy first
+ # stage's weights using all_reduce below.
+ self.word_embeddings = mpu.VocabParallelEmbedding(
+ args.padded_vocab_size,
+ args.hidden_size,
+ init_method=init_method_normal(args.init_method_std),
+ )
+ self.word_embeddings.weight.data.fill_(0)
+ self.word_embeddings.weight.shared = True
+
+ # Ensure that first and last stages have the same initial parameter
+ # values.
+ if torch.distributed.is_initialized():
+ if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
+ torch.distributed.all_reduce(
+ self.word_embeddings_weight().data, group=mpu.get_embedding_group()
+ )
+ else:
+ print(
+ "WARNING! Distributed processes aren't initialized, so "
+ "word embeddings in the last layer are not initialized. "
+ "If you are just manipulating a model this is fine, but "
+ "this needs to be handled manually. If you are training "
+ "something is definitely wrong."
+ )
+
+
+def conversion_helper(val, conversion):
+ """Apply conversion to val. Recursively apply conversion if `val`
+ #is a nested tuple/list structure."""
+ if not isinstance(val, (tuple, list)):
+ return conversion(val)
+ rtn = [conversion_helper(v, conversion) for v in val]
+ if isinstance(val, tuple):
+ rtn = tuple(rtn)
+ return rtn
+
+
+def fp32_to_float16(val, float16_convertor):
+ """Convert fp32 `val` to fp16/bf16"""
+
+ def half_conversion(val):
+ val_typecheck = val
+ if isinstance(val_typecheck, (Parameter, Variable)):
+ val_typecheck = val.data
+ if isinstance(val_typecheck, _FLOAT_TYPES):
+ val = float16_convertor(val)
+ return val
+
+ return conversion_helper(val, half_conversion)
+
+
+def float16_to_fp32(val):
+ """Convert fp16/bf16 `val` to fp32"""
+
+ def float_conversion(val):
+ val_typecheck = val
+ if isinstance(val_typecheck, (Parameter, Variable)):
+ val_typecheck = val.data
+ if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
+ val = val.float()
+ return val
+
+ return conversion_helper(val, float_conversion)
+
+
+class Float16Module(MegatronModule):
+ def __init__(self, module, args):
+ super(Float16Module, self).__init__()
+
+ if args.fp16:
+ self.add_module("module", module.half())
+
+ def float16_convertor(val):
+ return val.half()
+
+ elif args.bf16:
+ self.add_module("module", module.bfloat16())
+
+ def float16_convertor(val):
+ return val.bfloat16()
+
+ else:
+ raise Exception("should not be here")
+
+ self.float16_convertor = float16_convertor
+
+ def forward(self, *inputs, **kwargs):
+ if mpu.is_pipeline_first_stage():
+ inputs = fp32_to_float16(inputs, self.float16_convertor)
+ outputs = self.module(*inputs, **kwargs)
+ if mpu.is_pipeline_last_stage():
+ outputs = float16_to_fp32(outputs)
+ return outputs
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ return self.module.state_dict(destination, prefix, keep_vars)
+
+ def state_dict_for_save_checkpoint(
+ self, destination=None, prefix="", keep_vars=False
+ ):
+ return self.module.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars
+ )
+
+ def load_state_dict(self, state_dict, strict=True):
+ self.module.load_state_dict(state_dict, strict=strict)
diff --git a/codegeex/megatron/model/transformer.py b/codegeex/megatron/model/transformer.py
new file mode 100644
index 0000000..a5cee5e
--- /dev/null
+++ b/codegeex/megatron/model/transformer.py
@@ -0,0 +1,970 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer."""
+
+import math
+import torch
+from torch.nn import LayerNorm
+
+from codegeex.megatron import get_args
+from codegeex.megatron import mpu
+from codegeex.megatron.model.module import MegatronModule
+from codegeex.megatron.model.utils import fast_gelu
+
+# flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+""" We use the following notation throughout this file:
+ h: hidden size
+ n: number of attention heads
+ p: number of model parallel partitions
+ np: n/p
+ hp: h/p
+ hn: h/n
+ b: batch size
+ s: sequence length
+ l: number of layers
+ Transformer takes input of size [s, b, h] and returns a
+ tensor of the same size. We use the following arguments:
+ hyperparameters: transformer hyperparameters
+ attention_mask_func: a function that takes `unmaksed-attention-scores`
+ with size [b, np, s, s] and an `attention-mask` and will apply
+ the masking. The function should return a masked score of the
+ same size [b, np, s, s].
+ masked-attention-scores = attention_mask_func(
+ unmaksed-attention-scores, attention-mask)
+"""
+
+
+class ParallelMLP(MegatronModule):
+ """MLP.
+
+ MLP will take the input with h hidden state, project it to 4*h
+ hidden dimension, perform nonlinear transformation, and project the
+ state back into h hidden dimension. At the end, dropout is also
+ applied.
+ """
+
+ def __init__(self, init_method, output_layer_init_method):
+ super(ParallelMLP, self).__init__()
+ args = get_args()
+
+ # Project to 4h.
+ self.dense_h_to_4h = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ 4 * args.hidden_size,
+ gather_output=False,
+ init_method=init_method,
+ # skip_bias_add=True,
+ )
+
+ self.activation_func = fast_gelu
+
+ # Project back to h.
+ self.dense_4h_to_h = mpu.RowParallelLinear(
+ 4 * args.hidden_size,
+ args.hidden_size,
+ input_is_parallel=False,
+ init_method=output_layer_init_method,
+ # skip_bias_add=True,
+ )
+
+ def forward(self, hidden_states):
+ # [s, b, 4hp]
+ intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+ # [s, b, h]
+ output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+
+ return output, output_bias
+
+
+class ParallelSelfAttention(MegatronModule):
+ """Parallel self-attention layer abstract class.
+
+ Self-attention layer takes input with size [b, s, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, init_method,
+ output_layer_init_method, layer_number):
+ super(ParallelSelfAttention, self).__init__()
+ args = get_args()
+ self.fp16 = args.fp16
+ self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
+ self.layer_number = max(1, layer_number)
+
+ # Per attention head and per partition values.
+ world_size = mpu.get_model_parallel_world_size()
+ self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+ world_size)
+ self.hidden_size_per_attention_head = mpu.divide(
+ args.hidden_size, args.num_attention_heads)
+ self.num_attention_heads_per_partition = mpu.divide(
+ args.num_attention_heads, world_size)
+ if hasattr(args, 'attention_upweight'):
+ self.attention_upweight = args.attention_upweight
+ else:
+ self.attention_upweight = None
+ # Strided linear layer.
+ self.query = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ gather_output=False,
+ init_method=init_method)
+ self.key = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ gather_output=False,
+ init_method=init_method)
+ self.value = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ gather_output=False,
+ init_method=init_method)
+
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+ self.softmax = torch.nn.Softmax(dim=-1)
+
+ # Dropout. Note that for a single iteration, this layer will generate
+ # different outputs on different number of parallel partitions but
+ # on average it should not be partition dependent.
+ self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+
+ # Output.
+ self.dense = mpu.RowParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ input_is_parallel=False,
+ init_method=output_layer_init_method,
+ skip_bias_add=True)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ layer_past=None,
+ get_key_value=False,
+ prompt_length=None,
+ context_length=None,
+ ):
+ # hidden_states: [sq, b, h]
+
+ # =====================
+ # Query, Key, and Value
+ # =====================
+
+ query_layer, _ = self.query(hidden_states)
+ key_layer, _ = self.key(hidden_states)
+ value_layer, _ = self.value(hidden_states)
+
+ new_query_layer_shape = query_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ query_layer = query_layer.view(*new_query_layer_shape)
+
+ new_query_layer_shape = key_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ key_layer = key_layer.view(*new_query_layer_shape)
+
+ new_query_layer_shape = value_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ value_layer = value_layer.view(*new_query_layer_shape)
+
+ # ==================================
+ # Adjust key and value for inference
+ # ==================================
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key_layer = torch.cat((past_key.type_as(key_layer),
+ key_layer), dim=0)
+ value_layer = torch.cat((past_value.type_as(value_layer),
+ value_layer), dim=0)
+ if get_key_value:
+ present = (key_layer, value_layer)
+
+ # ===================================
+ # Raw attention scores. [b, np, sq, sk]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1),
+ query_layer.size(2),
+ query_layer.size(0),
+ key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
+ key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)
+
+ # Raw attention scores. [b * np, sq, sk]
+ matmul_result = torch.matmul(query_layer.transpose(0, 1),
+ key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ if self.attention_upweight is not None and layer_past is None:
+ log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
+ device=torch.cuda.current_device(),
+ dtype=torch.half if self.fp16 else torch.float32)
+ if prompt_length is None:
+ log_attention_weights = self.attention_upweight
+ else:
+ log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
+ attention_scores += log_attention_weights
+
+ # ==================================================
+ # Update attention mask for inference. [b, np, sq, sk]
+ # ==================================================
+
+ if get_key_value:
+ with torch.no_grad():
+ if layer_past is not None:
+ attention_mask = attention_mask[
+ ...,
+ attention_scores.size(3) - 1,
+ :attention_scores.size(3)].unsqueeze(2)
+ else:
+ attention_mask = attention_mask[
+ ...,
+ :attention_scores.size(3),
+ :attention_scores.size(3)]
+
+ # ===========================
+ # Attention probs and dropout
+ # ===========================
+
+ if context_length is not None:
+ attention_mask = torch.clone(attention_mask)
+ attention_mask[:, :, context_length:, :] = True
+
+ # attention scores and attention mask [b, np, sq, sk]
+ # attention_scores = attention_mask_func(attention_scores, attention_mask)
+ attention_scores = attention_scores - attention_mask * 10000.0
+ if self.attention_softmax_in_fp32:
+ attention_probs = self.softmax(attention_scores.float()).half()
+ else:
+ attention_probs = self.softmax(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ with mpu.get_cuda_rng_tracker().fork():
+ attention_probs = self.attention_dropout(attention_probs)
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sq, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1),
+ value_layer.size(2),
+ query_layer.size(0),
+ value_layer.size(3))
+
+ # change view [sq, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1],
+ output_size[2], -1)
+
+ context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + \
+ (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output, bias = self.dense(context_layer)
+
+ if get_key_value:
+ output = [output, present]
+
+ return output, bias
+
+
+class ParallelTopQuerySelfAttention(MegatronModule):
+ """Parallel top query self-attention layer abstract class.
+
+ Self-attention layer takes input with size [b, s, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, init_method,
+ output_layer_init_method, layer_number):
+ super(ParallelTopQuerySelfAttention, self).__init__()
+ args = get_args()
+ self.fp16 = args.fp16
+ self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
+ self.layer_number = max(1, layer_number)
+
+ if hasattr(args, 'attention_upweight_top'):
+ self.attention_upweight = args.attention_upweight_top
+ else:
+ self.attention_upweight = None
+ # Per attention head and per partition values.
+ world_size = mpu.get_model_parallel_world_size()
+ self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+ world_size)
+ self.hidden_size_per_attention_head = mpu.divide(
+ args.hidden_size, args.num_attention_heads)
+ self.num_attention_heads_per_partition = mpu.divide(
+ args.num_attention_heads, world_size)
+
+ self.query = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ gather_output=False,
+ init_method=init_method)
+
+ self.key = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ gather_output=False,
+ init_method=init_method)
+
+ self.value = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ gather_output=False,
+ init_method=init_method)
+
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+ self.softmax = torch.nn.Softmax(dim=-1)
+
+ # Dropout. Note that for a single iteration, this layer will generate
+ # different outputs on different number of parallel partitions but
+ # on average it should not be partition dependent.
+ self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+
+ # Output.
+ self.dense = mpu.RowParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ input_is_parallel=False,
+ init_method=output_layer_init_method,
+ skip_bias_add=True)
+
+ def forward(
+ self,
+ hidden_states,
+ query_hidden_state,
+ attention_mask,
+ layer_past=None,
+ get_key_value=False,
+ prompt_length=None,
+ context_length=None,
+ ):
+
+ # hidden_states: [sq, b, h]
+
+ query_layer, _ = self.query(query_hidden_state)
+ key_layer, _ = self.key(hidden_states)
+ value_layer, _ = self.value(hidden_states)
+
+ new_query_layer_shape = query_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ query_layer = query_layer.view(*new_query_layer_shape)
+
+ new_query_layer_shape = key_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ key_layer = key_layer.view(*new_query_layer_shape)
+
+ new_query_layer_shape = value_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ value_layer = value_layer.view(*new_query_layer_shape)
+
+ # ==================================
+ # Adjust key and value for inference
+ # ==================================
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key_layer = torch.cat((past_key.type_as(key_layer),
+ key_layer), dim=0)
+ value_layer = torch.cat((past_value.type_as(value_layer),
+ value_layer), dim=0)
+ if get_key_value:
+ present = (key_layer, value_layer)
+
+ # ===================================
+ # Raw attention scores. [b, np, sq, sk]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1),
+ query_layer.size(2),
+ query_layer.size(0),
+ key_layer.size(0))
+
+ # [s, b, np, hn] -> [s, b * np, hn]
+ query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
+ key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)
+
+ # Raw attention scores. [b * np, sq, sk]
+ matmul_result = torch.matmul(query_layer.transpose(0, 1),
+ key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
+
+ # change view to [b, np, s, s]
+ attention_scores = matmul_result.view(*output_size)
+
+ if self.attention_upweight is not None and layer_past is None:
+ log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
+ device=torch.cuda.current_device(),
+ dtype=torch.half if self.fp16 else torch.float32)
+ if prompt_length is None:
+ log_attention_weights = self.attention_upweight
+ else:
+ log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
+ attention_scores += log_attention_weights
+
+ # ==================================================
+ # Update attention mask for inference. [b, np, sq, sk]
+ # ==================================================
+
+ if get_key_value:
+ with torch.no_grad():
+ if layer_past is not None:
+ attention_mask = attention_mask[
+ ...,
+ attention_scores.size(3) - 1,
+ :attention_scores.size(3)].unsqueeze(2)
+ else:
+ attention_mask = attention_mask[
+ ...,
+ :attention_scores.size(3),
+ :attention_scores.size(3)]
+
+ # ===========================
+ # Attention probs and dropout
+ # ===========================
+
+ if context_length is not None:
+ attention_mask = torch.clone(attention_mask)
+ attention_mask[:, :, context_length:, :] = True
+
+ # attention scores and attention mask [b, np, sq, sk]
+ # attention_scores = attention_mask_func(attention_scores, attention_mask)
+ attention_scores = attention_scores - attention_mask * 10000.0
+ if self.attention_softmax_in_fp32:
+ attention_probs = self.softmax(attention_scores.float()).half()
+ else:
+ attention_probs = self.softmax(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ with mpu.get_cuda_rng_tracker().fork():
+ attention_probs = self.attention_dropout(attention_probs)
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sq, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1),
+ value_layer.size(2),
+ query_layer.size(0),
+ value_layer.size(3))
+
+ # change view [sq, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1],
+ output_size[2], -1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + \
+ (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output, bias = self.dense(context_layer)
+
+ if get_key_value:
+ output = [output, present]
+
+ return output, bias
+
+
+def bias_dropout_add(x, bias, residual, prob, training):
+ # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
+ out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+def get_bias_dropout_add(training):
+ def _bias_dropout_add(x, bias, residual, prob):
+ return bias_dropout_add(x, bias, residual, prob, training)
+
+ return _bias_dropout_add
+
+
+@torch.jit.script
+def bias_dropout_add_fused_train(x, bias, residual, prob):
+ # type: (Tensor, Tensor, Tensor, float) -> Tensor
+ return bias_dropout_add(x, bias, residual, prob, True)
+
+
+@torch.jit.script
+def bias_dropout_add_fused_inference(x, bias, residual, prob):
+ # type: (Tensor, Tensor, Tensor, float) -> Tensor
+ return bias_dropout_add(x, bias, residual, prob, False)
+
+
+class ParallelTransformerLayer(MegatronModule):
+ """A single transformer layer.
+
+ Transformore layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, init_method,
+ output_layer_init_method, layer_number):
+ args = get_args()
+
+ super(ParallelTransformerLayer, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm \
+ = args.apply_residual_connection_post_layernorm
+
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ # Self attention.
+ self.attention = ParallelSelfAttention(init_method,
+ output_layer_init_method,
+ layer_number)
+ self.hidden_dropout = args.hidden_dropout
+ self.bias_dropout_fusion = args.bias_dropout_fusion
+
+ # Layernorm on the input data.
+ self.post_attention_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+ if hasattr(args, 'attention_upweight'):
+ self.attention_upweight = args.attention_upweight
+ else:
+ self.attention_upweight = None
+ if hasattr(args, 'ln_fp16'):
+ self.ln_fp16 = args.ln_fp16
+ else:
+ self.ln_fp16 = False
+ # MLP
+ self.mlp = ParallelMLP(init_method,
+ output_layer_init_method)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ layer_past=None,
+ get_key_value=False,
+ prompt_length=None,
+ context_length=None,
+ ):
+ # hidden_states: [b, s, h]
+ if self.ln_fp16:
+ layernorm_output = self.input_layernorm(hidden_states)
+ else:
+ layernorm_output = self.input_layernorm(hidden_states.float()).half()
+
+ # Self attention.
+ attention_output, attention_bias = \
+ self.attention(layernorm_output,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=get_key_value,
+ prompt_length=prompt_length,
+ context_length=context_length)
+
+ if get_key_value:
+ attention_output, presents = attention_output
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # jit scripting for a nn.module (with dropout) is not
+ # trigerring the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+ # re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ layernorm_input = bias_dropout_add_func(
+ attention_output,
+ attention_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ # Layer norm post the self attention.
+ if self.ln_fp16:
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+ else:
+ layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
+
+ mlp_output, _ = self.mlp(layernorm_output)
+
+ # MLP.
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = mlp_output + residual
+
+ if get_key_value:
+ output = [output, presents]
+
+ return output
+
+
+class ParallelTopQueryLayer(MegatronModule):
+ """A single top query layer.
+
+ Top query layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, init_method,
+ output_layer_init_method, layer_number):
+ args = get_args()
+
+ super(ParallelTopQueryLayer, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm \
+ = args.apply_residual_connection_post_layernorm
+
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ # Self attention.
+ self.attention = ParallelTopQuerySelfAttention(init_method,
+ output_layer_init_method,
+ layer_number)
+
+ self.hidden_dropout = args.hidden_dropout
+ self.bias_dropout_fusion = args.bias_dropout_fusion
+
+ # Layernorm on the input data.
+ self.post_attention_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ if hasattr(args, 'ln_fp16'):
+ self.ln_fp16 = args.ln_fp16
+ else:
+ self.ln_fp16 = False
+
+ # MLP
+ self.mlp = ParallelMLP(init_method,
+ output_layer_init_method)
+
+ def forward(
+ self,
+ hidden_states,
+ query_hidden_state,
+ attention_mask,
+ layer_past=None,
+ get_key_value=False,
+ prompt_length=None,
+ context_length=None,
+ ):
+ # hidden_states: [b, s, h]
+ assert query_hidden_state != None
+
+ # Layer norm at the begining of the transformer layer.
+ if self.ln_fp16:
+ layernorm_output = self.input_layernorm(hidden_states)
+ else:
+ layernorm_output = self.input_layernorm(hidden_states.float()).half()
+
+ # Self attention.
+ attention_output, attention_bias = \
+ self.attention(layernorm_output,
+ query_hidden_state,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=get_key_value,
+ prompt_length=prompt_length,
+ context_length=context_length)
+
+ if get_key_value:
+ attention_output, presents = attention_output
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # jit scripting for a nn.module (with dropout) is not
+ # trigerring the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+ # re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ layernorm_input = bias_dropout_add_func(
+ attention_output,
+ attention_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ # Layer norm post the self attention.
+ if self.ln_fp16:
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+ else:
+ layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
+
+ # MLP.
+ mlp_output, _ = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = mlp_output + residual
+
+ if get_key_value:
+ output = [output, presents]
+
+ return output
+
+
+class ParallelTransformer(MegatronModule):
+ """Transformer class."""
+
+ def __init__(self, init_method, output_layer_init_method):
+ super(ParallelTransformer, self).__init__()
+ args = get_args()
+
+ # Store activation checkpoiting flag.
+ self.checkpoint_activations = args.checkpoint_activations
+ self.checkpoint_num_layers = args.checkpoint_num_layers
+
+ # Number of layers:
+ self.num_layers = args.num_layers
+ self.num_unique_layers = None
+
+ #################
+ assert self.num_unique_layers is None
+ #################
+
+ if self.num_unique_layers is None:
+ self.num_unique_layers = self.num_layers
+ assert self.num_layers % self.num_unique_layers == 0, \
+ 'number of layers should be divisible by number of unique layers'
+ self.param_sharing_style = 'grouped'
+
+ # Transformer layers.
+ def build_layer(layer_number):
+ return ParallelTransformerLayer(
+ init_method,
+ output_layer_init_method, layer_number)
+
+ self.layers = torch.nn.ModuleList(
+ [build_layer(i + 1) for i in range(self.num_unique_layers)])
+
+ self.topQueryLayer = ParallelTopQueryLayer(
+ init_method,
+ output_layer_init_method, self.num_unique_layers)
+
+ # Final layer norm before output.
+ if hasattr(args, 'ln_fp16'):
+ self.ln_fp16 = args.ln_fp16
+ else:
+ self.ln_fp16 = False
+
+ self.final_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ def _get_layer_index(self, layer_number):
+ if self.param_sharing_style == 'grouped':
+ return layer_number % self.num_unique_layers
+ if self.param_sharing_style == 'spaced':
+ return layer_number // (self.num_layers // self.num_unique_layers)
+ assert False, 'should not be here'
+
+ def _get_layer(self, layer_number):
+ return self.layers[self._get_layer_index(layer_number)]
+
+ def _checkpointed_forward(self, hidden_states, attention_mask):
+ """Forward method with activation checkpointing."""
+
+ def custom(start, end):
+ def custom_forward(*inputs):
+ x_ = inputs[0]
+ for index in range(start, end):
+ layer = self._get_layer(index)
+ x_ = layer(x_, inputs[1])
+ return x_
+
+ return custom_forward
+
+ # Make sure memory is freed.
+ mpu.reset_checkpointed_activations_memory_buffer()
+ l = 0
+ while l < self.num_layers:
+ hidden_states = mpu.checkpoint(
+ custom(l, l + self.checkpoint_num_layers),
+ hidden_states, attention_mask)
+ l += self.checkpoint_num_layers
+
+ return hidden_states
+
+ def forward(
+ self,
+ hidden_states,
+ query_hidden_state,
+ attention_mask,
+ layer_past=None,
+ get_key_value=False,
+ prompt_length=None,
+ context_length=None,
+ ):
+
+ # Checks
+ if layer_past is not None:
+ assert get_key_value, \
+ 'for not None values in layer_past, ' \
+ 'expected get_key_value to be set'
+ if get_key_value:
+ assert not self.checkpoint_activations, \
+ 'get_key_value does not work with ' \
+ 'activation checkpointing'
+
+ # data format change to avoid explicit tranposes : [b s h] --> [s b h]
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+ query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()
+
+ if self.checkpoint_activations:
+ hidden_states = self._checkpointed_forward(hidden_states,
+ attention_mask)
+ else:
+ if get_key_value:
+ presents = []
+ for index in range(self.num_layers):
+ layer = self._get_layer(index)
+ past = None
+ if layer_past is not None:
+ past = layer_past[index]
+ hidden_states = layer(hidden_states,
+ attention_mask,
+ layer_past=past,
+ get_key_value=get_key_value,
+ prompt_length=prompt_length,
+ context_length=context_length)
+ if get_key_value:
+ hidden_states, present = hidden_states
+ presents.append(present)
+
+ if self.ln_fp16:
+ hidden_states_ = self.final_layernorm(hidden_states)
+ else:
+ hidden_states_ = self.final_layernorm(hidden_states.float()).half()
+
+ #################################
+ # top query layer
+ #################################
+ past = None
+ if layer_past is not None:
+ past = layer_past[self.num_layers]
+ hidden_states = self.topQueryLayer(hidden_states_,
+ query_hidden_state,
+ attention_mask,
+ layer_past=past,
+ get_key_value=get_key_value,
+ prompt_length=prompt_length,
+ context_length=context_length)
+
+ if get_key_value:
+ hidden_states, present = hidden_states
+ presents.append(present)
+
+ # reverting data format change [s b h] --> [b s h]
+ output = hidden_states.transpose(0, 1).contiguous()
+
+ if get_key_value:
+ output = [output, presents]
+
+ return output
diff --git a/codegeex/megatron/model/utils.py b/codegeex/megatron/model/utils.py
new file mode 100644
index 0000000..6543f7b
--- /dev/null
+++ b/codegeex/megatron/model/utils.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for models."""
+
+import math
+import torch
+
+
+def init_method_normal(sigma):
+ """Init method based on N(0, sigma)."""
+
+ def init_(tensor):
+ return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
+
+ return init_
+
+
+def scaled_init_method_normal(sigma, num_layers):
+ """Init method based on N(0, sigma/sqrt(2*num_layers)."""
+ std = sigma / math.sqrt(2.0 * num_layers)
+
+ def init_(tensor):
+ return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+ return init_
+
+
+def attention_mask_func(attention_scores, attention_mask):
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+
+ return attention_scores
+
+
+def get_linear_layer(rows, columns, init_method):
+ """Simple linear layer with weight initialization."""
+ layer = torch.nn.Linear(rows, columns)
+ init_method(layer.weight)
+ with torch.no_grad():
+ layer.bias.zero_()
+ return layer
+
+
+def fast_gelu(x):
+ """Mindspore's fast gelu implementation."""
+ return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x)))
+
+
+@torch.jit.script
+def gelu_impl(x):
+ """OpenAI's gelu implementation."""
+ return (
+ 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
+ )
+
+
+def openai_gelu(x):
+ return gelu_impl(x)
+
+
+# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
+@torch.jit.script
+def erf_gelu(x):
+ return (
+ x
+ * 0.5
+ * (
+ torch.erf(x / 1.41421).to(dtype=x.dtype)
+ + torch.ones_like(x).to(dtype=x.dtype)
+ )
+ )
diff --git a/codegeex/megatron/mpu/__init__.py b/codegeex/megatron/mpu/__init__.py
new file mode 100644
index 0000000..cf73a10
--- /dev/null
+++ b/codegeex/megatron/mpu/__init__.py
@@ -0,0 +1,81 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model parallel utility interface."""
+
+from .cross_entropy import vocab_parallel_cross_entropy
+
+from .data import broadcast_data
+
+from .initialize import is_unitialized
+from .initialize import destroy_model_parallel
+from .initialize import get_data_parallel_group
+from .initialize import get_data_parallel_rank
+from .initialize import get_data_parallel_world_size
+from .initialize import get_embedding_group
+from .initialize import get_model_parallel_group
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_pipeline_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
+from .initialize import (
+ get_pipeline_model_parallel_rank,
+ set_pipeline_model_parallel_rank,
+)
+from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
+from .initialize import get_tensor_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_first_rank
+from .initialize import get_pipeline_model_parallel_last_rank
+from .initialize import get_pipeline_model_parallel_next_rank
+from .initialize import get_pipeline_model_parallel_prev_rank
+from .initialize import (
+ get_tensor_model_parallel_world_size,
+ set_tensor_model_parallel_world_size,
+)
+from .initialize import (
+ get_pipeline_model_parallel_world_size,
+ set_pipeline_model_parallel_world_size,
+)
+from .initialize import (
+ get_virtual_pipeline_model_parallel_rank,
+ set_virtual_pipeline_model_parallel_rank,
+)
+from .initialize import initialize_model_parallel
+from .initialize import model_parallel_is_initialized
+from .initialize import get_model_parallel_world_size, get_model_parallel_rank
+
+from .layers import ColumnParallelLinear
+from .layers import RowParallelLinear
+from .layers import VocabParallelEmbedding
+from .layers import (
+ set_tensor_model_parallel_attributes,
+ set_defaults_if_not_set_tensor_model_parallel_attributes,
+ copy_tensor_model_parallel_attributes,
+)
+
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
+
+from .random import checkpoint
+from .random import get_cuda_rng_tracker
+from .random import init_checkpointed_activations_memory_buffer
+from .random import model_parallel_cuda_manual_seed
+from .random import reset_checkpointed_activations_memory_buffer
+from .random import gather_split_1d_tensor
+from .random import split_tensor_into_1d_equal_chunks
+
+from .utils import divide
+from .utils import split_tensor_along_last_dim
diff --git a/codegeex/megatron/mpu/cross_entropy.py b/codegeex/megatron/mpu/cross_entropy.py
new file mode 100644
index 0000000..cb3c6aa
--- /dev/null
+++ b/codegeex/megatron/mpu/cross_entropy.py
@@ -0,0 +1,115 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
+from .utils import VocabUtility
+
+
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, vocab_parallel_logits, target):
+
+ # Maximum value along vocab dimension across all GPUs.
+ logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+ torch.distributed.all_reduce(
+ logits_max,
+ op=torch.distributed.ReduceOp.MAX,
+ group=get_tensor_model_parallel_group(),
+ )
+ # Subtract the maximum value.
+ vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+ # Get the partition's vocab indecies
+ get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ rank = get_tensor_model_parallel_rank()
+ world_size = get_tensor_model_parallel_world_size()
+ vocab_start_index, vocab_end_index = get_vocab_range(
+ partition_vocab_size, rank, world_size
+ )
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ masked_target = target.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ # Get predicted-logits = logits[target].
+ # For Simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(
+ start=0, end=logits_2d.size()[0], device=logits_2d.device
+ )
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits[target_mask] = 0.0
+ # All reduce is needed to get the chunks from other GPUs.
+ torch.distributed.all_reduce(
+ predicted_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=get_tensor_model_parallel_group(),
+ )
+
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+ exp_logits = vocab_parallel_logits
+ torch.exp(vocab_parallel_logits, out=exp_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+ torch.distributed.all_reduce(
+ sum_exp_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=get_tensor_model_parallel_group(),
+ )
+
+ # Loss = log(sum(exp(logits))) - predicted-logit.
+ loss = torch.log(sum_exp_logits) - predicted_logits
+
+ # Store softmax, target-mask and masked-target for backward pass.
+ exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+ ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ # Retreive tensors from the forward path.
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+ # All the inputs have softmax as thier gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+ grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
+
+ # Finally elementwise multiplication with the output gradients.
+ grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+ return grad_input, None
+
+
+def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
+ """Helper function for the cross entropy."""
+ return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
diff --git a/codegeex/megatron/mpu/data.py b/codegeex/megatron/mpu/data.py
new file mode 100644
index 0000000..7357250
--- /dev/null
+++ b/codegeex/megatron/mpu/data.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_src_rank
+
+
+_MAX_DATA_DIM = 5
+
+
+def _check_data_types(keys, data, target_dtype):
+ """Check that all the keys have the same target data type."""
+ for key in keys:
+ assert (
+ data[key].dtype == target_dtype
+ ), "{} has data type {} which " "is different than {}".format(
+ key, data[key].dtype, target_dtype
+ )
+
+
+def _build_key_size_numel_dictionaries(keys, data):
+ """Build the size on rank 0 and broadcast."""
+ max_dim = _MAX_DATA_DIM
+ sizes = [0 for _ in range(max_dim) for _ in keys]
+
+ # Pack the sizes on rank zero.
+ if get_tensor_model_parallel_rank() == 0:
+ offset = 0
+ for key in keys:
+ assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
+ size = data[key].size()
+ for i, s in enumerate(size):
+ sizes[i + offset] = s
+ offset += max_dim
+
+ # Move to GPU and broadcast.
+ sizes_cuda = torch.cuda.LongTensor(sizes)
+ torch.distributed.broadcast(
+ sizes_cuda,
+ get_tensor_model_parallel_src_rank(),
+ group=get_tensor_model_parallel_group(),
+ )
+
+ # Move back to cpu and unpack.
+ sizes_cpu = sizes_cuda.cpu()
+ key_size = {}
+ key_numel = {}
+ total_numel = 0
+ offset = 0
+ for key in keys:
+ i = 0
+ size = []
+ numel = 1
+ while sizes_cpu[offset + i] > 0:
+ this_size = sizes_cpu[offset + i]
+ size.append(this_size)
+ numel *= this_size
+ i += 1
+ key_size[key] = size
+ key_numel[key] = numel
+ total_numel += numel
+ offset += max_dim
+
+ return key_size, key_numel, total_numel
+
+
+def broadcast_data(keys, data, datatype):
+ """Broadcast data from rank zero of each model parallel group to the
+ members of the same model parallel group.
+
+ Arguments:
+ keys: list of keys in the data disctionary to be broadcasted
+ data: data dictionary of string keys and cpu tensor values.
+ datatype: torch data type of all tensors in data associated
+ with keys.
+ """
+ # Build (key, size) and (key, number of elements) dictionaries along
+ # with the total number of elements on all ranks.
+ key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
+
+ # Pack on rank zero.
+ if get_tensor_model_parallel_rank() == 0:
+ # Check that all keys have the same data type.
+ _check_data_types(keys, data, datatype)
+ # Flatten the data associated with the keys
+ flatten_data = torch.cat(
+ [data[key].contiguous().view(-1) for key in keys], dim=0
+ ).cuda()
+ else:
+ flatten_data = torch.empty(
+ total_numel, device=torch.cuda.current_device(), dtype=datatype
+ )
+
+ # Broadcast
+ torch.distributed.broadcast(
+ flatten_data,
+ get_tensor_model_parallel_src_rank(),
+ group=get_tensor_model_parallel_group(),
+ )
+
+ # Unpack
+ output = {}
+ offset = 0
+ for key in keys:
+ size = key_size[key]
+ numel = key_numel[key]
+ output[key] = flatten_data.narrow(0, offset, numel).view(size)
+ offset += numel
+
+ return output
diff --git a/codegeex/megatron/mpu/initialize.py b/codegeex/megatron/mpu/initialize.py
new file mode 100644
index 0000000..0e5523d
--- /dev/null
+++ b/codegeex/megatron/mpu/initialize.py
@@ -0,0 +1,403 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Model and data parallel groups."""
+
+import torch
+
+from .utils import ensure_divisibility
+
+
+# Intra-layer model parallel group that the current rank belongs to.
+_TENSOR_MODEL_PARALLEL_GROUP = None
+# Inter-layer model parallel group that the current rank belongs to.
+_PIPELINE_MODEL_PARALLEL_GROUP = None
+# Model parallel group (both intra- and pipeline) that the current rank belongs to.
+_MODEL_PARALLEL_GROUP = None
+# Embedding group.
+_EMBEDDING_GROUP = None
+# Data parallel group that the current rank belongs to.
+_DATA_PARALLEL_GROUP = None
+
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+
+# These values enable us to change the mpu sizes on the fly.
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_TENSOR_MODEL_PARALLEL_RANK = None
+_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+
+# A list of global ranks for each pipeline group to ease calculation of the source
+# rank when broadcasting from the first or last pipeline stage
+_PIPELINE_GLOBAL_RANKS = None
+
+
+def is_unitialized():
+ """Useful for code segments that may be accessed with or without mpu initialization"""
+ return _DATA_PARALLEL_GROUP is None
+
+
+def initialize_model_parallel(
+ tensor_model_parallel_size_=1,
+ pipeline_model_parallel_size_=1,
+ virtual_pipeline_model_parallel_size_=None,
+):
+ """
+ Initialize model data parallel groups.
+
+ Arguments:
+ tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
+ pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
+
+ Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
+ use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
+ the model pipeline. The present function will
+ create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
+ and 8 data-parallel groups as:
+ 8 data_parallel groups:
+ [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
+ 8 tensor model-parallel groups:
+ [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
+ 4 pipeline model-parallel groups:
+ [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
+ Note that for efficiency, the caller should make sure adjacent ranks
+ are on the same DGX box. For example if we are using 2 DGX-1 boxes
+ with a total of 16 GPUs, rank 0 to 7 belong to the first box and
+ ranks 8 to 15 belong to the second box.
+ """
+ if torch.distributed.get_rank() == 0:
+ print(
+ "> initializing tensor model parallel with size {}".format(
+ tensor_model_parallel_size_
+ )
+ )
+ print(
+ "> initializing pipeline model parallel with size {}".format(
+ pipeline_model_parallel_size_
+ )
+ )
+ # Get world size and rank. Ensure some consistencies.
+ assert torch.distributed.is_initialized()
+ world_size = torch.distributed.get_world_size()
+ tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
+ pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
+ ensure_divisibility(
+ world_size, tensor_model_parallel_size * pipeline_model_parallel_size
+ )
+ data_parallel_size = world_size // (
+ tensor_model_parallel_size * pipeline_model_parallel_size
+ )
+
+ num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
+ num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
+ num_data_parallel_groups = world_size // data_parallel_size
+
+ if virtual_pipeline_model_parallel_size_ is not None:
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
+ _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
+ virtual_pipeline_model_parallel_size_
+ )
+
+ rank = torch.distributed.get_rank()
+
+ # Build the data-parallel groups.
+ global _DATA_PARALLEL_GROUP
+ assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
+ all_data_parallel_group_ranks = []
+ for i in range(pipeline_model_parallel_size):
+ start_rank = i * num_pipeline_model_parallel_groups
+ end_rank = (i + 1) * num_pipeline_model_parallel_groups
+ for j in range(tensor_model_parallel_size):
+ ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
+ all_data_parallel_group_ranks.append(list(ranks))
+ group = torch.distributed.new_group(ranks)
+ if rank in ranks:
+ _DATA_PARALLEL_GROUP = group
+
+ # Build the model-parallel groups.
+ global _MODEL_PARALLEL_GROUP
+ assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+ for i in range(data_parallel_size):
+ ranks = [
+ data_parallel_group_ranks[i]
+ for data_parallel_group_ranks in all_data_parallel_group_ranks
+ ]
+ group = torch.distributed.new_group(ranks)
+ if rank in ranks:
+ _MODEL_PARALLEL_GROUP = group
+
+ # Build the tensor model-parallel groups.
+ global _TENSOR_MODEL_PARALLEL_GROUP
+ assert (
+ _TENSOR_MODEL_PARALLEL_GROUP is None
+ ), "tensor model parallel group is already initialized"
+ for i in range(num_tensor_model_parallel_groups):
+ ranks = range(
+ i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size
+ )
+ group = torch.distributed.new_group(ranks)
+ if rank in ranks:
+ _TENSOR_MODEL_PARALLEL_GROUP = group
+
+ # Build the pipeline model-parallel groups and embedding groups
+ # (first and last rank in each pipeline model-parallel group).
+ global _PIPELINE_MODEL_PARALLEL_GROUP
+ global _PIPELINE_GLOBAL_RANKS
+ assert (
+ _PIPELINE_MODEL_PARALLEL_GROUP is None
+ ), "pipeline model parallel group is already initialized"
+ global _EMBEDDING_GROUP
+ assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
+ for i in range(num_pipeline_model_parallel_groups):
+ ranks = range(i, world_size, num_pipeline_model_parallel_groups)
+ group = torch.distributed.new_group(ranks)
+ if rank in ranks:
+ _PIPELINE_MODEL_PARALLEL_GROUP = group
+ _PIPELINE_GLOBAL_RANKS = ranks
+ # Setup embedding group (to exchange gradients between
+ # first and last stages).
+ if len(ranks) > 1:
+ embedding_ranks = [ranks[0], ranks[-1]]
+ else:
+ embedding_ranks = ranks
+ group = torch.distributed.new_group(embedding_ranks)
+ if rank in embedding_ranks:
+ _EMBEDDING_GROUP = group
+
+
+def model_parallel_is_initialized():
+ """Check if model and data parallel groups are initialized."""
+ if (
+ _TENSOR_MODEL_PARALLEL_GROUP is None
+ or _PIPELINE_MODEL_PARALLEL_GROUP is None
+ or _DATA_PARALLEL_GROUP is None
+ ):
+ return False
+ return True
+
+
+def get_model_parallel_group():
+ """Get the model parallel group the caller rank belongs to."""
+ assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
+ return _MODEL_PARALLEL_GROUP
+
+
+def get_tensor_model_parallel_group():
+ """Get the tensor model parallel group the caller rank belongs to."""
+ assert (
+ _TENSOR_MODEL_PARALLEL_GROUP is not None
+ ), "intra_layer_model parallel group is not initialized"
+ return _TENSOR_MODEL_PARALLEL_GROUP
+
+
+def get_pipeline_model_parallel_group():
+ """Get the pipeline model parallel group the caller rank belongs to."""
+ assert (
+ _PIPELINE_MODEL_PARALLEL_GROUP is not None
+ ), "pipeline_model parallel group is not initialized"
+ return _PIPELINE_MODEL_PARALLEL_GROUP
+
+
+def get_data_parallel_group():
+ """Get the data parallel group the caller rank belongs to."""
+ assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
+ return _DATA_PARALLEL_GROUP
+
+
+def get_embedding_group():
+ """Get the embedding group the caller rank belongs to."""
+ assert _EMBEDDING_GROUP is not None, "embedding group is not initialized"
+ return _EMBEDDING_GROUP
+
+
+def set_tensor_model_parallel_world_size(world_size):
+ """Set the tensor model parallel size"""
+ global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+ _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
+
+
+def set_pipeline_model_parallel_world_size(world_size):
+ """Set the pipeline model parallel size"""
+ global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
+
+
+def get_tensor_model_parallel_world_size():
+ """Return world size for the tensor model parallel group."""
+ global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+ if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
+ return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+ return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
+
+
+def get_model_parallel_world_size():
+ assert (
+ get_pipeline_model_parallel_world_size() == 1
+ ), "legacy get_model_parallel_world_size is only supported if PP is disabled"
+ return get_tensor_model_parallel_world_size()
+
+
+def get_pipeline_model_parallel_world_size():
+ """Return world size for the pipeline model parallel group."""
+ global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
+ return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
+
+
+def set_tensor_model_parallel_rank(rank):
+ """Set tensor model parallel rank."""
+ global _MPU_TENSOR_MODEL_PARALLEL_RANK
+ _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
+
+
+def set_pipeline_model_parallel_rank(rank):
+ """Set pipeline model parallel rank."""
+ global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+ _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+
+def get_tensor_model_parallel_rank():
+ """Return my rank for the tensor model parallel group."""
+ global _MPU_TENSOR_MODEL_PARALLEL_RANK
+ if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
+ return _MPU_TENSOR_MODEL_PARALLEL_RANK
+ return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
+
+
+def get_model_parallel_rank():
+ assert (
+ get_pipeline_model_parallel_world_size() == 1
+ ), "legacy get_model_parallel_rank is only supported if PP is disabled"
+ return get_tensor_model_parallel_rank()
+
+
+def get_pipeline_model_parallel_rank():
+ """Return my rank for the pipeline model parallel group."""
+ global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+ if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
+ return _MPU_PIPELINE_MODEL_PARALLEL_RANK
+ return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
+
+
+def is_pipeline_first_stage(ignore_virtual=False):
+ """Return True if in the first pipeline model-parallel stage, False otherwise."""
+ if not ignore_virtual:
+ if (
+ get_virtual_pipeline_model_parallel_world_size() is not None
+ and get_virtual_pipeline_model_parallel_rank() != 0
+ ):
+ return False
+ return get_pipeline_model_parallel_rank() == 0
+
+
+def is_pipeline_last_stage(ignore_virtual=False):
+ """Return True if in the last pipeline model-parallel stage, False otherwise."""
+ if not ignore_virtual:
+ virtual_pipeline_model_parallel_world_size = (
+ get_virtual_pipeline_model_parallel_world_size()
+ )
+ if (
+ virtual_pipeline_model_parallel_world_size is not None
+ and get_virtual_pipeline_model_parallel_rank()
+ != (virtual_pipeline_model_parallel_world_size - 1)
+ ):
+ return False
+ return get_pipeline_model_parallel_rank() == (
+ get_pipeline_model_parallel_world_size() - 1
+ )
+
+
+def get_virtual_pipeline_model_parallel_rank():
+ """Return the virtual pipeline-parallel rank."""
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+ return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+
+def set_virtual_pipeline_model_parallel_rank(rank):
+ """Set the virtual pipeline-parallel rank."""
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+ _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+
+def get_virtual_pipeline_model_parallel_world_size():
+ """Return the virtual pipeline-parallel world size."""
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
+def get_tensor_model_parallel_src_rank():
+ """Calculate the global rank corresponding to the first local rank
+ in the tensor model parallel group."""
+ global_rank = torch.distributed.get_rank()
+ local_world_size = get_tensor_model_parallel_world_size()
+ return (global_rank // local_world_size) * local_world_size
+
+
+def get_pipeline_model_parallel_first_rank():
+ assert (
+ _PIPELINE_GLOBAL_RANKS is not None
+ ), "Pipeline parallel group is not initialized"
+ return _PIPELINE_GLOBAL_RANKS[0]
+
+
+def get_pipeline_model_parallel_last_rank():
+ assert (
+ _PIPELINE_GLOBAL_RANKS is not None
+ ), "Pipeline parallel group is not initialized"
+ last_rank_local = get_pipeline_model_parallel_world_size() - 1
+ return _PIPELINE_GLOBAL_RANKS[last_rank_local]
+
+
+def get_pipeline_model_parallel_next_rank():
+ assert (
+ _PIPELINE_GLOBAL_RANKS is not None
+ ), "Pipeline parallel group is not initialized"
+ rank_in_pipeline = get_pipeline_model_parallel_rank()
+ world_size = get_pipeline_model_parallel_world_size()
+ return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+
+
+def get_pipeline_model_parallel_prev_rank():
+ assert (
+ _PIPELINE_GLOBAL_RANKS is not None
+ ), "Pipeline parallel group is not initialized"
+ rank_in_pipeline = get_pipeline_model_parallel_rank()
+ world_size = get_pipeline_model_parallel_world_size()
+ return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
+
+
+def get_data_parallel_world_size():
+ """Return world size for the data parallel group."""
+ return torch.distributed.get_world_size(group=get_data_parallel_group())
+
+
+def get_data_parallel_rank():
+ """Return my rank for the data parallel group."""
+ return torch.distributed.get_rank(group=get_data_parallel_group())
+
+
+def destroy_model_parallel():
+ """Set the groups to none."""
+ global _TENSOR_MODEL_PARALLEL_GROUP
+ _TENSOR_MODEL_PARALLEL_GROUP = None
+ global _PIPELINE_MODEL_PARALLEL_GROUP
+ _PIPELINE_MODEL_PARALLEL_GROUP = None
+ global _DATA_PARALLEL_GROUP
+ _DATA_PARALLEL_GROUP = None
diff --git a/codegeex/megatron/mpu/layers.py b/codegeex/megatron/mpu/layers.py
new file mode 100644
index 0000000..331df98
--- /dev/null
+++ b/codegeex/megatron/mpu/layers.py
@@ -0,0 +1,480 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parts of the code here are adapted from PyTorch
+# repo: https://github.com/pytorch/pytorch
+
+
+import math
+
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+from torch.nn.parameter import Parameter
+
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
+from .random import get_cuda_rng_tracker
+from .utils import divide
+from .utils import split_tensor_along_last_dim
+from .utils import VocabUtility
+from codegeex.megatron import get_args
+import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing
+
+
+_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
+ "tensor_model_parallel": False,
+ "partition_dim": -1,
+ "partition_stride": 1,
+}
+
+
+def param_is_not_tensor_parallel_duplicate(param):
+ return (
+ hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
+ ) or (get_tensor_model_parallel_rank() == 0)
+
+
+def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
+ # Make sure the attributes are not set.
+ for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
+ assert not hasattr(tensor, attribute)
+ # Set the attributes.
+ setattr(tensor, "tensor_model_parallel", is_parallel)
+ setattr(tensor, "partition_dim", dim)
+ setattr(tensor, "partition_stride", stride)
+
+
+def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
+ def maybe_set(attribute, value):
+ if not hasattr(tensor, attribute):
+ setattr(tensor, attribute, value)
+
+ for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
+ maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
+
+
+def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
+ def maybe_copy(attribute):
+ if hasattr(source_tensor, attribute):
+ setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
+
+ for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
+ maybe_copy(attribute)
+
+
+def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
+ """Initialize affine weight for model parallel on GPU."""
+
+ set_tensor_model_parallel_attributes(
+ tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
+ )
+
+ if ds_checkpointing.is_configured():
+ global get_cuda_rng_tracker
+ get_cuda_rng_tracker = ds_checkpointing.get_cuda_rng_tracker
+
+ with get_cuda_rng_tracker().fork():
+ init_method(weight)
+
+
+def _initialize_affine_weight_cpu(
+ weight,
+ output_size,
+ input_size,
+ per_partition_size,
+ partition_dim,
+ init_method,
+ stride=1,
+ return_master_weight=False,
+):
+ """Initialize affine weight for model parallel.
+
+ Build the master weight on all processes and scatter
+ the relevant chunk."""
+
+ set_tensor_model_parallel_attributes(
+ tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
+ )
+
+ # Initialize master weight
+ master_weight = torch.empty(
+ output_size, input_size, dtype=torch.float, requires_grad=False
+ )
+ init_method(master_weight)
+ args = get_args()
+ master_weight = master_weight.to(dtype=args.params_dtype)
+
+ # Split and copy
+ per_partition_per_stride_size = divide(per_partition_size, stride)
+ weight_list = torch.split(
+ master_weight, per_partition_per_stride_size, dim=partition_dim
+ )
+ rank = get_tensor_model_parallel_rank()
+ world_size = get_tensor_model_parallel_world_size()
+ my_weight_list = weight_list[rank::world_size]
+
+ with torch.no_grad():
+ torch.cat(my_weight_list, dim=partition_dim, out=weight)
+ if return_master_weight:
+ return master_weight
+ return None
+
+
+class VocabParallelEmbedding(torch.nn.Module):
+ """Embedding parallelized in the vocabulary dimension.
+
+ This is mainly adapted from torch.nn.Embedding and all the default
+ values are kept.
+ Arguments:
+ num_embeddings: vocabulary size.
+ embedding_dim: size of hidden state.
+ init_method: method to initialize weights.
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, init_method=init.xavier_normal_):
+ super(VocabParallelEmbedding, self).__init__()
+ # Keep the input dimensions.
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ # Set the detauls for compatibility.
+ self.padding_idx = None
+ self.max_norm = None
+ self.norm_type = 2.0
+ self.scale_grad_by_freq = False
+ self.sparse = False
+ self._weight = None
+ self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
+ # Divide the weight matrix along the vocaburaly dimension.
+ (
+ self.vocab_start_index,
+ self.vocab_end_index,
+ ) = VocabUtility.vocab_range_from_global_vocab_size(
+ self.num_embeddings,
+ get_tensor_model_parallel_rank(),
+ self.tensor_model_parallel_size,
+ )
+ self.num_embeddings_per_partition = (
+ self.vocab_end_index - self.vocab_start_index
+ )
+
+ # Allocate weights and initialize.
+ args = get_args()
+ if args.use_cpu_initialization:
+ self.weight = Parameter(
+ torch.empty(
+ self.num_embeddings_per_partition,
+ self.embedding_dim,
+ dtype=args.params_dtype,
+ # dtype=torch.float32,
+ )
+ )
+ _initialize_affine_weight_cpu(
+ self.weight,
+ self.num_embeddings,
+ self.embedding_dim,
+ self.num_embeddings_per_partition,
+ 0,
+ init_method,
+ )
+ else:
+ self.weight = Parameter(
+ torch.empty(
+ self.num_embeddings_per_partition,
+ self.embedding_dim,
+ device=torch.cuda.current_device(),
+ dtype=args.params_dtype,
+ # dtype=torch.float32,
+ )
+ )
+ _initialize_affine_weight_gpu(
+ self.weight, init_method, partition_dim=0, stride=1
+ )
+
+ def forward(self, input_):
+ if self.tensor_model_parallel_size > 1:
+ # Build the mask.
+ input_mask = (input_ < self.vocab_start_index) | (
+ input_ >= self.vocab_end_index
+ )
+ # Mask the input.
+ masked_input = input_.clone() - self.vocab_start_index
+ masked_input[input_mask] = 0
+ else:
+ masked_input = input_
+ # Get the embeddings.
+ output_parallel = F.embedding(
+ masked_input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
+ # Mask the output embedding.
+ if self.tensor_model_parallel_size > 1:
+ output_parallel[input_mask, :] = 0.0
+ # Reduce across all the model parallel GPUs.
+ output = reduce_from_tensor_model_parallel_region(output_parallel)
+ return output
+
+
+class ColumnParallelLinear(torch.nn.Module):
+ """Linear layer with column parallelism.
+
+ The linear layer is defined as Y = XA + b. A is parallelized along
+ its second dimension as A = [A_1, ..., A_p].
+
+ Arguments:
+ input_size: first dimension of matrix A.
+ output_size: second dimension of matrix A.
+ bias: If true, add bias
+ gather_output: If true, call all-gether on output and make Y avaiable
+ to all GPUs, otherwise, every GPU will have its output
+ which is Y_i = XA_i
+ init_method: method to initialize weights. Note that bias is always set
+ to zero.
+ stride: For the strided linear layers.
+ keep_master_weight_for_test: This was added for testing and should be
+ set to False. It returns the master weights
+ used for initialization.
+ skip_bias_add: This was added to enable performance optimations where bias
+ can be fused with other elementwise operations. we skip
+ adding bias but instead return it.
+ """
+
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ bias=True,
+ gather_output=True,
+ init_method=init.xavier_normal_,
+ stride=1,
+ keep_master_weight_for_test=False,
+ skip_bias_add=False,
+ ):
+ super(ColumnParallelLinear, self).__init__()
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.gather_output = gather_output
+ # Divide the weight matrix along the last dimension.
+ world_size = get_tensor_model_parallel_world_size()
+ self.output_size_per_partition = divide(output_size, world_size)
+ self.skip_bias_add = skip_bias_add
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ args = get_args()
+ if args.use_cpu_initialization:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ self.input_size,
+ dtype=args.params_dtype,
+ )
+ )
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight,
+ self.output_size,
+ self.input_size,
+ self.output_size_per_partition,
+ 0,
+ init_method,
+ stride=stride,
+ return_master_weight=keep_master_weight_for_test,
+ )
+ else:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ self.input_size,
+ device=torch.cuda.current_device(),
+ dtype=args.params_dtype,
+ )
+ )
+ _initialize_affine_weight_gpu(
+ self.weight, init_method, partition_dim=0, stride=stride
+ )
+
+ if bias:
+ if args.use_cpu_initialization:
+ self.bias = Parameter(
+ torch.empty(self.output_size_per_partition, dtype=args.params_dtype)
+ )
+ else:
+ self.bias = Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=args.params_dtype,
+ )
+ )
+ set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
+ # Always initialize bias to zero.
+ with torch.no_grad():
+ self.bias.zero_()
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input_):
+ # Set up backprop all-reduce.
+ input_parallel = copy_to_tensor_model_parallel_region(input_)
+ # Matrix multiply.
+
+ bias = self.bias if not self.skip_bias_add else None
+ output_parallel = F.linear(input_parallel, self.weight, bias)
+ if self.gather_output:
+ # All-gather across the partitions.
+ output = gather_from_tensor_model_parallel_region(output_parallel)
+ else:
+ output = output_parallel
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
+
+class RowParallelLinear(torch.nn.Module):
+ """Linear layer with row parallelism.
+
+ The linear layer is defined as Y = XA + b. A is parallelized along
+ its first dimension and X along its second dimension as:
+ - -
+ | A_1 |
+ | . |
+ A = | . | X = [X_1, ..., X_p]
+ | . |
+ | A_p |
+ - -
+ Arguments:
+ input_size: first dimension of matrix A.
+ output_size: second dimension of matrix A.
+ bias: If true, add bias. Note that bias is not parallelized.
+ input_is_parallel: If true, we assume that the input is already
+ split across the GPUs and we do not split
+ again.
+ init_method: method to initialize weights. Note that bias is always set
+ to zero.
+ stride: For the strided linear layers.
+ keep_master_weight_for_test: This was added for testing and should be
+ set to False. It returns the master weights
+ used for initialization.
+ skip_bias_add: This was added to enable performance optimations where bias
+ can be fused with other elementwise operations. we skip
+ adding bias but instead return it.
+ """
+
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ bias=True,
+ input_is_parallel=False,
+ init_method=init.xavier_normal_,
+ stride=1,
+ keep_master_weight_for_test=False,
+ skip_bias_add=False,
+ ):
+ super(RowParallelLinear, self).__init__()
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.input_is_parallel = input_is_parallel
+ # Divide the weight matrix along the last dimension.
+ world_size = get_tensor_model_parallel_world_size()
+ self.input_size_per_partition = divide(input_size, world_size)
+ self.skip_bias_add = skip_bias_add
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ args = get_args()
+ if args.use_cpu_initialization:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size,
+ self.input_size_per_partition,
+ dtype=args.params_dtype,
+ )
+ )
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight,
+ self.output_size,
+ self.input_size,
+ self.input_size_per_partition,
+ 1,
+ init_method,
+ stride=stride,
+ return_master_weight=keep_master_weight_for_test,
+ )
+ else:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size,
+ self.input_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=args.params_dtype,
+ )
+ )
+ _initialize_affine_weight_gpu(
+ self.weight, init_method, partition_dim=1, stride=stride
+ )
+ if bias:
+ if args.use_cpu_initialization:
+ self.bias = Parameter(
+ torch.empty(self.output_size, dtype=args.params_dtype)
+ )
+ else:
+ self.bias = Parameter(
+ torch.empty(
+ self.output_size,
+ device=torch.cuda.current_device(),
+ dtype=args.params_dtype,
+ )
+ )
+ # Always initialize bias to zero.
+ with torch.no_grad():
+ self.bias.zero_()
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input_):
+ # Set up backprop all-reduce.
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ input_parallel = scatter_to_tensor_model_parallel_region(input_)
+ # Matrix multiply.
+ output_parallel = F.linear(input_parallel, self.weight)
+ # All-reduce across all the partitions.
+ output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+ if not self.skip_bias_add:
+ output = output_ + self.bias if self.bias is not None else output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.bias
+ return output, output_bias
diff --git a/codegeex/megatron/mpu/mappings.py b/codegeex/megatron/mpu/mappings.py
new file mode 100644
index 0000000..259fe9b
--- /dev/null
+++ b/codegeex/megatron/mpu/mappings.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from .initialize import (
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_world_size,
+ get_tensor_model_parallel_rank,
+)
+from .utils import split_tensor_along_last_dim
+
+
+def _reduce(input_):
+ """All-reduce the the input tensor across model parallel group."""
+
+ # Bypass the function if we are using only 1 GPU.
+ if get_tensor_model_parallel_world_size() == 1:
+ return input_
+
+ # All-reduce.
+ torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
+
+ return input_
+
+
+def _split(input_):
+ """Split the tensor along its last dimension and keep the
+ corresponding slice."""
+
+ world_size = get_tensor_model_parallel_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ # Split along last dimension.
+ input_list = split_tensor_along_last_dim(input_, world_size)
+
+ # Note: torch.split does not create contiguous tensors by default.
+ rank = get_tensor_model_parallel_rank()
+ output = input_list[rank].contiguous()
+
+ return output
+
+
+def _gather(input_):
+ """Gather tensors and concatinate along the last dimension."""
+
+ world_size = get_tensor_model_parallel_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ # Size and dimension.
+ last_dim = input_.dim() - 1
+ rank = get_tensor_model_parallel_rank()
+
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+ tensor_list[rank] = input_
+ torch.distributed.all_gather(
+ tensor_list, input_, group=get_tensor_model_parallel_group()
+ )
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=last_dim).contiguous()
+
+ return output
+
+
+class _CopyToModelParallelRegion(torch.autograd.Function):
+ """Pass the input to the model parallel region."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return input_
+
+ @staticmethod
+ def forward(ctx, input_):
+ return input_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _reduce(grad_output)
+
+
+class _ReduceFromModelParallelRegion(torch.autograd.Function):
+ """All-reduce the input from the model parallel region."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _reduce(input_)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _reduce(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output
+
+
+class _ScatterToModelParallelRegion(torch.autograd.Function):
+ """Split the input and keep only the corresponding chuck to the rank."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _split(input_)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _split(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _gather(grad_output)
+
+
+class _GatherFromModelParallelRegion(torch.autograd.Function):
+ """Gather the input from model parallel region and concatinate."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _gather(input_)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _gather(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split(grad_output)
+
+
+# -----------------
+# Helper functions.
+# -----------------
+
+
+def copy_to_tensor_model_parallel_region(input_):
+ return _CopyToModelParallelRegion.apply(input_)
+
+
+def reduce_from_tensor_model_parallel_region(input_):
+ return _ReduceFromModelParallelRegion.apply(input_)
+
+
+def scatter_to_tensor_model_parallel_region(input_):
+ return _ScatterToModelParallelRegion.apply(input_)
+
+
+def gather_from_tensor_model_parallel_region(input_):
+ return _GatherFromModelParallelRegion.apply(input_)
diff --git a/codegeex/megatron/mpu/random.py b/codegeex/megatron/mpu/random.py
new file mode 100644
index 0000000..3837f89
--- /dev/null
+++ b/codegeex/megatron/mpu/random.py
@@ -0,0 +1,342 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parts of the code here are adapted from PyTorch
+# repo: https://github.com/pytorch/pytorch
+
+import contextlib
+
+import torch
+from torch import _C
+from torch.cuda import _lazy_call, device as device_ctx_manager
+from torch.utils.checkpoint import detach_variable
+
+from codegeex.megatron import get_args
+from codegeex.megatron.memory import allocate_mem_buff
+
+from .initialize import get_data_parallel_rank
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
+
+
+# Default name for the model parallel rng tracker.
+_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
+
+
+# Whether apply model parallelsim to checkpointed hidden states.
+_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
+
+
+def init_checkpointed_activations_memory_buffer():
+ """Initializ the memory buffer for the checkpointed activations."""
+ args = get_args()
+
+ per_layer = (
+ args.micro_batch_size
+ * args.max_position_embeddings
+ * args.hidden_size
+ // args.tensor_model_parallel_size
+ )
+ assert (
+ args.num_layers % args.checkpoint_num_layers == 0
+ ), "number of layers is not divisible by checkpoint-num-layers"
+ num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+ numel = per_layer * num_checkpointer_layers
+ dtype = torch.half
+ if not args.fp16:
+ dtype = torch.float
+
+ global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
+ assert (
+ _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None
+ ), "checkpointed activations memory buffer is already allocated."
+ _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
+ "checkpointed activations", numel, dtype, track_usage=False
+ )
+
+
+def reset_checkpointed_activations_memory_buffer():
+ """Reset the memory used for checkpointing."""
+ if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+ _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
+
+
+def _set_cuda_rng_state(new_state, device=-1):
+ """Sets the random number generator state of the current GPU.
+
+ Argumentss:
+ new_state (torch.ByteTensor): The desired state
+ This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
+ with a single change: the input state is not cloned. Cloning caused
+ major performance issues for +4 GPU cases.
+ """
+ if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState):
+ # older PyTorch
+ def cb():
+ with device_ctx_manager(device):
+ _C._cuda_setRNGState(new_state)
+
+ else:
+ # newer PyTorch
+ if device == -1:
+ device = torch.device("cuda")
+ elif isinstance(device, str):
+ device = torch.device(device)
+ elif isinstance(device, int):
+ device = torch.device("cuda", device)
+
+ def cb():
+ idx = device.index
+ if idx is None:
+ idx = torch.cuda.current_device()
+ default_generator = torch.cuda.default_generators[idx]
+ default_generator.set_state(new_state)
+
+ _lazy_call(cb)
+
+
+def split_tensor_into_1d_equal_chunks(tensor):
+ """Break a tensor into equal 1D chunks."""
+ data = tensor.view(-1)
+ partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+ start_index = partition_size * get_tensor_model_parallel_rank()
+ end_index = start_index + partition_size
+ return data[start_index:end_index]
+
+
+def gather_split_1d_tensor(tensor):
+ """Opposite of above function, gather values from model parallel ranks."""
+ world_size = get_tensor_model_parallel_world_size()
+ numel = torch.numel(tensor)
+ numel_gathered = world_size * numel
+ gathered = torch.empty(
+ numel_gathered,
+ dtype=tensor.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
+ torch.distributed.all_gather(
+ chunks, tensor, group=get_tensor_model_parallel_group()
+ )
+ return gathered
+
+
+class CudaRNGStatesTracker:
+ """Tracker for the cuda RNG states.
+
+ Using the `add` method, a cuda rng state is initialized based on
+ the input `seed` and is assigned to `name`. Later, by forking the
+ rng state, we can perform operations and return to our starting
+ cuda state.
+ """
+
+ def __init__(self):
+ # Map from a string name to the cuda rng state.
+ self.states_ = {}
+ # Seeds are just for book keeping and ensure no seed is set twice.
+ self.seeds_ = set()
+
+ def reset(self):
+ """Set to the initial state (no tracker)."""
+ self.states_ = {}
+ self.seeds_ = set()
+
+ def get_states(self):
+ """Get rng states. Copy the dictionary so we have direct
+ pointers to the states, not just a pointer to the dictionary."""
+ states = {}
+ for name in self.states_:
+ states[name] = self.states_[name]
+ return states
+
+ def set_states(self, states):
+ """Set the rng states. For efficiency purposes, we do not check
+ the size of seed for compatibility."""
+ self.states_ = states
+
+ def add(self, name, seed):
+ """Track the rng state."""
+ # Check seed is not already used.
+ if seed in self.seeds_:
+ raise Exception("seed {} already exists".format(seed))
+ self.seeds_.add(seed)
+ # Check that state is not already defined.
+ if name in self.states_:
+ raise Exception("cuda rng state {} already exists".format(name))
+ # Get the current rng state.
+ orig_rng_state = torch.cuda.get_rng_state()
+ # Set the new state and store it.
+ torch.cuda.manual_seed(seed)
+ self.states_[name] = torch.cuda.get_rng_state()
+ # Reset rng state to what it was.
+ _set_cuda_rng_state(orig_rng_state)
+
+ @contextlib.contextmanager
+ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
+ """Fork the cuda rng state, perform operations, and exit with
+ the original state."""
+ # Check if we have added the state
+ if name not in self.states_:
+ print(name, self.states_)
+ raise Exception("cuda rng state {} is not added".format(name))
+ # Store current rng state.
+ orig_cuda_rng_state = torch.cuda.get_rng_state()
+ # Set rng state to the desired one
+ _set_cuda_rng_state(self.states_[name])
+ # Do the stuff we wanted to do.
+ try:
+ yield
+ finally:
+ # Update the current rng state for later use.
+ self.states_[name] = torch.cuda.get_rng_state()
+ # And set the state to the original state we started with.
+ _set_cuda_rng_state(orig_cuda_rng_state)
+
+
+# RNG tracker object.
+_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+
+
+def get_cuda_rng_tracker():
+ """Get cuda rng tracker."""
+ return _CUDA_RNG_STATE_TRACKER
+
+
+def model_parallel_cuda_manual_seed(seed):
+ """Initialize model parallel cuda seed.
+
+ This function should be called after the model parallel is
+ initialized. Also, no torch.cuda.manual_seed should be called
+ after this function. Basically, this is replacement for that
+ function.
+ Two set of RNG states are tracked:
+ default state: This is for data parallelism and is the same among a
+ set of model parallel GPUs but different across
+ different model paralle groups. This is used for
+ example for dropout in the non-tensor-model-parallel regions.
+ tensor-model-parallel state: This state is different among a set of model
+ parallel GPUs, but the same across data parallel
+ groups. This is used for example for dropout in
+ model parallel regions.
+ """
+ # 2718 is just for fun and any POSITIVE value will work.
+ offset = seed + 2718
+ tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
+ # Data parallel gets the original seed.
+ data_parallel_seed = seed
+
+ if torch.distributed.get_rank() == 0:
+ print(
+ "> initializing model parallel cuda seeds on global rank {}, "
+ "model parallel rank {}, and data parallel rank {} with "
+ "model parallel seed: {} and data parallel seed: {}".format(
+ torch.distributed.get_rank(),
+ get_tensor_model_parallel_rank(),
+ get_data_parallel_rank(),
+ tensor_model_parallel_seed,
+ data_parallel_seed,
+ ),
+ flush=True,
+ )
+ _CUDA_RNG_STATE_TRACKER.reset()
+ # Set the default state.
+ torch.cuda.manual_seed(data_parallel_seed)
+ # and model parallel state.
+ _CUDA_RNG_STATE_TRACKER.add(
+ _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed
+ )
+
+
+class CheckpointFunction(torch.autograd.Function):
+ """This function is adapted from torch.utils.checkpoint with
+ two main changes:
+ 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
+ 2) the states in the model parallel tracker are also properly
+ tracked/set/reset.
+ """
+
+ @staticmethod
+ def forward(ctx, run_function, *args):
+ ctx.run_function = run_function
+
+ # Copy the rng states.
+ ctx.fwd_cpu_rng_state = torch.get_rng_state()
+ ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+ ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ with torch.no_grad():
+ outputs = run_function(*args)
+
+ # Divide hidden states across model parallel group and only keep
+ # the chunk corresponding to the current rank.
+ if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+ ctx.input_0_shape = args[0].data.shape
+ args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
+ args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
+
+ # Store everything.
+ ctx.save_for_backward(*args)
+
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError(
+ "Checkpointing is not compatible with .grad(), "
+ "please use .backward() if possible"
+ )
+ inputs = ctx.saved_tensors
+ if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+ inputs[0].data = gather_split_1d_tensor(inputs[0].data)
+ inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
+
+ # Store the current states.
+ bwd_cpu_rng_state = torch.get_rng_state()
+ bwd_cuda_rng_state = torch.cuda.get_rng_state()
+ bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ # Set the states to what it used to be before the forward pass.
+ torch.set_rng_state(ctx.fwd_cpu_rng_state)
+ _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
+
+ # Compute the forward pass.
+ detached_inputs = detach_variable(inputs)
+ with torch.enable_grad():
+ outputs = ctx.run_function(*detached_inputs)
+
+ # Set the states back to what it was at the start of this function.
+ torch.set_rng_state(bwd_cpu_rng_state)
+ _set_cuda_rng_state(bwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
+
+ if isinstance(outputs, torch.Tensor):
+ outputs = (outputs,)
+ torch.autograd.backward(outputs, args)
+ grads = tuple(
+ inp.grad if isinstance(inp, torch.Tensor) else inp
+ for inp in detached_inputs
+ )
+ return (None,) + grads
+
+
+def checkpoint(function, *args):
+ """Checkpoint a model or part of the model.
+ This has been directly copied from torch.utils.checkpoint."""
+ return CheckpointFunction.apply(function, *args)
diff --git a/codegeex/megatron/mpu/utils.py b/codegeex/megatron/mpu/utils.py
new file mode 100644
index 0000000..e8f52dd
--- /dev/null
+++ b/codegeex/megatron/mpu/utils.py
@@ -0,0 +1,72 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+
+def ensure_divisibility(numerator, denominator):
+ """Ensure that numerator is divisible by the denominator."""
+ assert numerator % denominator == 0, "{} is not divisible by {}".format(
+ numerator, denominator
+ )
+
+
+def divide(numerator, denominator):
+ """Ensure that numerator is divisible by the denominator and return
+ the division value."""
+ ensure_divisibility(numerator, denominator)
+ return numerator // denominator
+
+
+def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
+ """Split a tensor along its last dimension.
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+
+class VocabUtility:
+ """Split the vocabulary into `world_size` chunks amd return the
+ first and last index of the vocabulary belonging to the `rank`
+ partition: Note that indecies in [fist, last)"""
+
+ @staticmethod
+ def vocab_range_from_per_partition_vocab_size(
+ per_partition_vocab_size, rank, world_size
+ ):
+ index_f = rank * per_partition_vocab_size
+ index_l = index_f + per_partition_vocab_size
+ return index_f, index_l
+
+ @staticmethod
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(
+ per_partition_vocab_size, rank, world_size
+ )
diff --git a/codegeex/megatron/p2p_communication.py b/codegeex/megatron/p2p_communication.py
new file mode 100644
index 0000000..6d22add
--- /dev/null
+++ b/codegeex/megatron/p2p_communication.py
@@ -0,0 +1,296 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import operator
+import torch
+
+from codegeex.megatron import get_args
+from codegeex.megatron import mpu
+
+
+def _communicate(
+ tensor_send_next, tensor_send_prev, recv_prev, recv_next, use_ring_exchange=False
+):
+ """Communicate tensors between stages. Used as helper method in other
+ communication methods that are used in megatron/schedules.py.
+
+ Takes the following arguments:
+ tensor_send_next: tensor to send to next rank (no tensor sent if
+ set to None).
+ tensor_send_prev: tensor to send to prev rank (no tensor sent if
+ set to None).
+ recv_prev: boolean for whether tensor should be received from
+ previous rank.
+ recv_next: boolean for whether tensor should be received from
+ next rank.
+ use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+ API should be used.
+
+ Returns:
+ (tensor_recv_prev, tensor_recv_next)
+ """
+ args = get_args()
+
+ # Create placeholder tensors for receive in forward and backward directions
+ # if needed.
+ tensor_recv_prev = None
+ tensor_recv_next = None
+ tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+ if args.scatter_gather_tensors_in_pipeline:
+ tensor_chunk_shape = (
+ reduce(operator.mul, tensor_shape, 1)
+ // mpu.get_tensor_model_parallel_world_size()
+ )
+ else:
+ tensor_chunk_shape = tensor_shape
+ dtype = args.params_dtype
+ if args.fp32_residual_connection:
+ dtype = torch.float
+ if recv_prev:
+ tensor_recv_prev = torch.empty(
+ tensor_chunk_shape,
+ requires_grad=True,
+ device=torch.cuda.current_device(),
+ dtype=dtype,
+ )
+ if recv_next:
+ tensor_recv_next = torch.empty(
+ tensor_chunk_shape,
+ requires_grad=True,
+ device=torch.cuda.current_device(),
+ dtype=dtype,
+ )
+
+ # Split tensor into smaller chunks if using scatter-gather optimization.
+ if args.scatter_gather_tensors_in_pipeline:
+ if tensor_send_next is not None:
+ tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
+
+ if tensor_send_prev is not None:
+ tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
+
+ # Send tensors in both the forward and backward directions as appropriate.
+ if use_ring_exchange:
+ torch.distributed.ring_exchange(
+ tensor_send_prev=tensor_send_prev,
+ tensor_recv_prev=tensor_recv_prev,
+ tensor_send_next=tensor_send_next,
+ tensor_recv_next=tensor_recv_next,
+ group=mpu.get_pipeline_model_parallel_group(),
+ )
+ else:
+ ops = []
+ if tensor_send_prev is not None:
+ send_prev_op = torch.distributed.P2POp(
+ torch.distributed.isend,
+ tensor_send_prev,
+ mpu.get_pipeline_model_parallel_prev_rank(),
+ )
+ ops.append(send_prev_op)
+ if tensor_recv_prev is not None:
+ recv_prev_op = torch.distributed.P2POp(
+ torch.distributed.irecv,
+ tensor_recv_prev,
+ mpu.get_pipeline_model_parallel_prev_rank(),
+ )
+ ops.append(recv_prev_op)
+ if tensor_send_next is not None:
+ send_next_op = torch.distributed.P2POp(
+ torch.distributed.isend,
+ tensor_send_next,
+ mpu.get_pipeline_model_parallel_next_rank(),
+ )
+ ops.append(send_next_op)
+ if tensor_recv_next is not None:
+ recv_next_op = torch.distributed.P2POp(
+ torch.distributed.irecv,
+ tensor_recv_next,
+ mpu.get_pipeline_model_parallel_next_rank(),
+ )
+ ops.append(recv_next_op)
+ if len(ops) > 0:
+ reqs = torch.distributed.batch_isend_irecv(ops)
+ for req in reqs:
+ req.wait()
+ # To protect against race condition when using batch_isend_irecv().
+ torch.cuda.synchronize()
+
+ # If using scatter-gather optimization, gather smaller chunks.
+ if args.scatter_gather_tensors_in_pipeline:
+ if recv_prev:
+ tensor_recv_prev = (
+ mpu.gather_split_1d_tensor(tensor_recv_prev)
+ .view(tensor_shape)
+ .requires_grad_()
+ )
+
+ if recv_next:
+ tensor_recv_next = (
+ mpu.gather_split_1d_tensor(tensor_recv_next)
+ .view(tensor_shape)
+ .requires_grad_()
+ )
+
+ return tensor_recv_prev, tensor_recv_next
+
+
+def recv_forward(timers=None):
+ """Receive tensor from previous rank in pipeline (forward receive)."""
+ if mpu.is_pipeline_first_stage():
+ input_tensor = None
+ else:
+ if timers is not None:
+ timers("forward-recv").start()
+ input_tensor, _ = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ recv_prev=True,
+ recv_next=False,
+ )
+ if timers is not None:
+ timers("forward-recv").stop()
+ return input_tensor
+
+
+def recv_backward(timers=None):
+ """Receive tensor from next rank in pipeline (backward receive)."""
+ if mpu.is_pipeline_last_stage():
+ output_tensor_grad = None
+ else:
+ if timers is not None:
+ timers("backward-recv").start()
+ _, output_tensor_grad = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ recv_prev=False,
+ recv_next=True,
+ )
+ if timers is not None:
+ timers("backward-recv").stop()
+ return output_tensor_grad
+
+
+def send_forward(output_tensor, timers=None):
+ """Send tensor to next rank in pipeline (forward send)."""
+ if not mpu.is_pipeline_last_stage():
+ if timers is not None:
+ timers("forward-send").start()
+ _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=None,
+ recv_prev=False,
+ recv_next=False,
+ )
+ if timers is not None:
+ timers("forward-send").stop()
+
+
+def send_backward(input_tensor_grad, timers=None):
+ """Send tensor to previous rank in pipeline (backward send)."""
+ if not mpu.is_pipeline_first_stage():
+ if timers is not None:
+ timers("backward-send").start()
+ _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=input_tensor_grad,
+ recv_prev=False,
+ recv_next=False,
+ )
+ if timers is not None:
+ timers("backward-send").stop()
+
+
+def send_forward_recv_backward(output_tensor, timers=None):
+ """Batched send and recv with next rank in pipeline."""
+ if mpu.is_pipeline_last_stage():
+ output_tensor_grad = None
+ else:
+ if timers is not None:
+ timers("forward-send-backward-recv").start()
+ _, output_tensor_grad = _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=None,
+ recv_prev=False,
+ recv_next=True,
+ )
+ if timers is not None:
+ timers("forward-send-backward-recv").stop()
+ return output_tensor_grad
+
+
+def send_backward_recv_forward(input_tensor_grad, timers=None):
+ """Batched send and recv with previous rank in pipeline."""
+ if mpu.is_pipeline_first_stage():
+ input_tensor = None
+ else:
+ if timers is not None:
+ timers("backward-send-forward-recv").start()
+ input_tensor, _ = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=input_tensor_grad,
+ recv_prev=True,
+ recv_next=False,
+ )
+ if timers is not None:
+ timers("backward-send-forward-recv").stop()
+ return input_tensor
+
+
+def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
+ """Batched recv from previous rank and send to next rank in pipeline."""
+ if timers is not None:
+ timers("forward-send-forward-recv").start()
+ input_tensor, _ = _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=None,
+ recv_prev=recv_prev,
+ recv_next=False,
+ )
+ if timers is not None:
+ timers("forward-send-forward-recv").stop()
+ return input_tensor
+
+
+def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
+ """Batched recv from next rank and send to previous rank in pipeline."""
+ if timers is not None:
+ timers("backward-send-backward-recv").start()
+ _, output_tensor_grad = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=input_tensor_grad,
+ recv_prev=False,
+ recv_next=recv_next,
+ )
+ if timers is not None:
+ timers("backward-send-backward-recv").stop()
+ return output_tensor_grad
+
+
+def send_forward_backward_recv_forward_backward(
+ output_tensor, input_tensor_grad, recv_prev, recv_next, timers=None
+):
+ """Batched send and recv with previous and next ranks in pipeline."""
+ if timers is not None:
+ timers("forward-backward-send-forward-backward-recv").start()
+ input_tensor, output_tensor_grad = _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ )
+ if timers is not None:
+ timers("forward-backward-send-forward-backward-recv").stop()
+ return input_tensor, output_tensor_grad
diff --git a/codegeex/megatron/tokenizer/__init__.py b/codegeex/megatron/tokenizer/__init__.py
new file mode 100644
index 0000000..311f2fd
--- /dev/null
+++ b/codegeex/megatron/tokenizer/__init__.py
@@ -0,0 +1,17 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .tokenizer import build_tokenizer
diff --git a/codegeex/megatron/tokenizer/gpt2_tokenization.py b/codegeex/megatron/tokenizer/gpt2_tokenization.py
new file mode 100644
index 0000000..bcf8cf6
--- /dev/null
+++ b/codegeex/megatron/tokenizer/gpt2_tokenization.py
@@ -0,0 +1,375 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for OpenAI GPT."""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import sys
+import json
+import logging
+import os
+import regex as re
+from io import open
+
+try:
+ from functools import lru_cache
+except ImportError:
+ # Just a dummy decorator to get the checks to run on python2
+ # because honestly I don't want to support a byte-level unicode BPE
+ # tokenizer on python 2 right now.
+ def lru_cache():
+ return lambda func: func
+
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+}
+PRETRAINED_MERGES_ARCHIVE_MAP = {
+ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+ "gpt2": 1024,
+}
+VOCAB_NAME = "vocab.json"
+MERGES_NAME = "merges.txt"
+SPECIAL_TOKENS_NAME = "special_tokens.txt"
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ _chr = unichr if sys.version_info[0] == 2 else chr
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2 ** 8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2 ** 8 + n)
+ n += 1
+ cs = [_chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class GPT2Tokenizer(object):
+ """
+ GPT-2 BPE tokenizer. Peculiarities:
+ - Byte-level BPE
+ """
+
+ @classmethod
+ def from_pretrained(
+ cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs
+ ):
+ """
+ Instantiate a PreTrainedBertModel from a pre-trained model file.
+ Download and cache the pre-trained model file if needed.
+ """
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+ vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+ merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+ special_tokens_file = None
+ else:
+ vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+ merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+ special_tokens_file = os.path.join(
+ pretrained_model_name_or_path, SPECIAL_TOKENS_NAME
+ )
+ if not os.path.exists(special_tokens_file):
+ special_tokens_file = None
+ else:
+ logger.info(
+ "loading special tokens file {}".format(special_tokens_file)
+ )
+ # redirect to the cache, if necessary
+ try:
+ from .file_utils import cached_path
+
+ resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+ resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
+ except EnvironmentError:
+ logger.error(
+ "Model name '{}' was not found in model name list ({}). "
+ "We assumed '{}' was a path or url but couldn't find files {} and {} "
+ "at this path or url.".format(
+ pretrained_model_name_or_path,
+ ", ".join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+ pretrained_model_name_or_path,
+ vocab_file,
+ merges_file,
+ )
+ )
+ return None
+ if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+ logger.info("loading vocabulary file {}".format(vocab_file))
+ logger.info("loading merges file {}".format(merges_file))
+ else:
+ logger.info(
+ "loading vocabulary file {} from cache at {}".format(
+ vocab_file, resolved_vocab_file
+ )
+ )
+ logger.info(
+ "loading merges file {} from cache at {}".format(
+ merges_file, resolved_merges_file
+ )
+ )
+ if (
+ pretrained_model_name_or_path
+ in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP
+ ):
+ # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
+ # than the number of positional embeddings
+ max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[
+ pretrained_model_name_or_path
+ ]
+ kwargs["max_len"] = min(kwargs.get("max_len", int(1e12)), max_len)
+ # Instantiate tokenizer.
+ if special_tokens_file and "special_tokens" not in kwargs:
+ special_tokens = (
+ open(special_tokens_file, encoding="utf-8").read().split("\n")[:-1]
+ )
+ else:
+ special_tokens = kwargs.pop("special_tokens", [])
+ tokenizer = cls(
+ resolved_vocab_file,
+ resolved_merges_file,
+ special_tokens=special_tokens,
+ *inputs,
+ **kwargs
+ )
+ return tokenizer
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ special_tokens=None,
+ max_len=None,
+ ):
+ self.max_len = max_len if max_len is not None else int(1e12)
+ self.encoder = json.load(open(vocab_file))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ bpe_data = open(merges_file, encoding="utf-8").read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+
+ # Should haved added re.IGNORECASE so BPE merges can happen for
+ # capitalized versions of contractions
+ self.pat = re.compile(
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+ )
+
+ self.special_tokens = {}
+ self.special_tokens_decoder = {}
+ self.set_special_tokens(special_tokens)
+
+ def __len__(self):
+ return len(self.encoder) + len(self.special_tokens)
+
+ def set_special_tokens(self, special_tokens):
+ """Add a list of additional tokens to the encoder.
+ The additional tokens are indexed starting from the last index of the
+ current vocabulary in the order of the `special_tokens` list.
+ """
+ if not special_tokens:
+ self.special_tokens = {}
+ self.special_tokens_decoder = {}
+ return
+ self.special_tokens = dict(
+ (tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)
+ )
+ self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
+ logger.info("Special tokens {}".format(self.special_tokens))
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except BaseException:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ if sys.version_info[0] == 2:
+ token = "".join(self.byte_encoder[ord(b)] for b in token)
+ else:
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ """Converts a sequence of tokens into ids using the vocab."""
+ ids = []
+ if isinstance(tokens, str) or (
+ sys.version_info[0] == 2 and isinstance(tokens, unicode)
+ ):
+ if tokens in self.special_tokens:
+ return self.special_tokens[tokens]
+ else:
+ return self.encoder.get(tokens, 0)
+ for token in tokens:
+ if token in self.special_tokens:
+ ids.append(self.special_tokens[token])
+ else:
+ ids.append(self.encoder.get(token, 0))
+ if len(ids) > self.max_len:
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum "
+ " sequence length for this OpenAI GPT model ({} > {}). Running this"
+ " sequence through the model will result in indexing errors".format(
+ len(ids), self.max_len
+ )
+ )
+ return ids
+
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+ """Converts a sequence of ids in BPE tokens using the vocab."""
+ tokens = []
+ for i in ids:
+ if i in self.special_tokens_decoder:
+ if not skip_special_tokens:
+ tokens.append(self.special_tokens_decoder[i])
+ else:
+ tokens.append(self.decoder[i])
+ return tokens
+
+ def encode(self, text):
+ return self.convert_tokens_to_ids(self.tokenize(text))
+
+ def decode(self, tokens):
+ text = "".join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+ "utf-8", errors=self.errors
+ )
+ return text
+
+ def save_vocabulary(self, vocab_path):
+ """Save the tokenizer vocabulary and merge files to a directory."""
+ if not os.path.isdir(vocab_path):
+ logger.error(
+ "Vocabulary path ({}) should be a directory".format(vocab_path)
+ )
+ return
+ vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+ merge_file = os.path.join(vocab_path, MERGES_NAME)
+ special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(
+ self.bpe_ranks.items(), key=lambda kv: kv[1]
+ ):
+ if index != token_index:
+ logger.warning(
+ "Saving vocabulary to {}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(
+ merge_file
+ )
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ index = len(self.encoder)
+ with open(special_tokens_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(
+ self.special_tokens.items(), key=lambda kv: kv[1]
+ ):
+ if index != token_index:
+ logger.warning(
+ "Saving special tokens vocabulary to {}: BPE indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(
+ special_tokens_file
+ )
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+
+ return vocab_file, merge_file, special_tokens_file
diff --git a/codegeex/megatron/tokenizer/tokenizer.py b/codegeex/megatron/tokenizer/tokenizer.py
new file mode 100644
index 0000000..92ef816
--- /dev/null
+++ b/codegeex/megatron/tokenizer/tokenizer.py
@@ -0,0 +1,280 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CodeGeeX tokenizers."""
+
+from abc import ABC
+from abc import abstractmethod
+
+from .gpt2_tokenization import GPT2Tokenizer
+from transformers import AutoTokenizer
+
+
+def encode_whitespaces(text, start_extra_id: int, max_len: int):
+ """Encode whitespaces to extra tokens in GPT-J.
+
+ >>> encode_whitespaces('a\\n b\\n c', 10, 10)
+ 'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
+ """
+
+ def push_acc_space(acc_len: int, text: str):
+ if acc_len == 0:
+ return text
+ if acc_len == 1:
+ return text + " "
+ assert (
+ acc_len <= max_len
+ ), f"Max whitespace run length {max_len}, but found {acc_len}"
+ extra_id = start_extra_id - 2 + acc_len
+ extra_token = f"<|extratoken_{extra_id}|>"
+ return text + extra_token
+
+ acc_len = 0
+ res = ""
+ for ch in text:
+ if ch == " ":
+ acc_len += 1
+ if acc_len == max_len:
+ res = push_acc_space(acc_len, res)
+ acc_len = 0
+ else:
+ res = push_acc_space(acc_len, res)
+ acc_len = 0
+ res = res + ch
+
+ res = push_acc_space(acc_len, res)
+
+ return res
+
+
+def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
+ """Decode the whitespace-encoded strings produced by encode_whitespace.
+
+ >>> text = 'a\\n b\\n c'
+ >>> s, l = 10, 10
+ >>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
+ True
+ """
+ for l in range(2, max_len + 1):
+ token_id = start_extra_id - 2 + l
+ token = f"<|extratoken_{token_id}|>"
+ text = text.replace(token, " " * l)
+ return text
+
+
+def build_hgf_tokenizer(args):
+ """Initialize tokenizer."""
+ tokenizer_path = args.tokenizer_path
+ if args.rank == 0:
+ print(f"> building huggingface tokenizer from {tokenizer_path} ...", flush=True)
+ assert tokenizer_path is not None, "Tokenizer path must be provided."
+
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ if args.rank == 0:
+ print(f" > eos_token = {tokenizer.eos_token}", flush=True)
+
+ ws_start_id = args.ws_encoding_start_id if "ws_encoding_start_id" in args else None
+ ws_len = args.ws_encoding_length if "ws_encoding_length" in args else None
+
+ return HgfTokenizerWrapper(
+ tokenizer, ws_start=ws_start_id, ws_len=ws_len
+ )
+
+
+def build_tokenizer(args):
+ """Initialize tokenizer."""
+ if "tokenizer_path" in args and args.tokenizer_path is not None:
+ # build huggingface tokenizer
+ tokenizer = build_hgf_tokenizer(args)
+ args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
+ return tokenizer
+
+ if args.rank == 0:
+ print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True)
+
+ # Select and instantiate the tokenizer.
+ assert args.vocab_file is not None
+ if args.tokenizer_type == "GPT2BPETokenizer":
+ assert args.merge_file is not None
+ tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
+ else:
+ raise NotImplementedError(
+ "{} tokenizer is not " "implemented.".format(args.tokenizer_type)
+ )
+
+ # Add vocab size.
+ args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
+
+ return tokenizer
+
+
+def _vocab_size_with_padding(orig_vocab_size, args):
+ """Pad vocab size so it is divisible by model parallel size and
+ still having GPU friendly size."""
+
+ after = orig_vocab_size
+ multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size
+ while (after % multiple) != 0:
+ after += 1
+ if args.rank == 0:
+ print(
+ " > padded vocab (size: {}) with {} dummy tokens "
+ "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
+ flush=True,
+ )
+ return after
+
+
+class AbstractTokenizer(ABC):
+ """Abstract class for tokenizer."""
+
+ def __init__(self, name):
+ self.name = name
+ super().__init__()
+
+ @property
+ @abstractmethod
+ def vocab_size(self):
+ pass
+
+ @property
+ @abstractmethod
+ def vocab(self):
+ """Dictionary from vocab text token to id token."""
+ pass
+
+ @property
+ @abstractmethod
+ def inv_vocab(self):
+ """Dictionary from vocab id token to text token."""
+ pass
+
+ @abstractmethod
+ def tokenize(self, text):
+ pass
+
+ def detokenize(self, token_ids):
+ raise NotImplementedError(
+ "detokenizer is not implemented for {} " "tokenizer".format(self.name)
+ )
+
+ @property
+ def cls(self):
+ raise NotImplementedError(
+ "CLS is not provided for {} " "tokenizer".format(self.name)
+ )
+
+ @property
+ def sep(self):
+ raise NotImplementedError(
+ "SEP is not provided for {} " "tokenizer".format(self.name)
+ )
+
+ @property
+ def pad(self):
+ raise NotImplementedError(
+ "PAD is not provided for {} " "tokenizer".format(self.name)
+ )
+
+ @property
+ def eod(self):
+ raise NotImplementedError(
+ "EOD is not provided for {} " "tokenizer".format(self.name)
+ )
+
+ @property
+ def mask(self):
+ raise NotImplementedError(
+ "MASK is not provided for {} " "tokenizer".format(self.name)
+ )
+
+
+class _GPT2BPETokenizer(AbstractTokenizer):
+ """Original GPT2 BPE tokenizer."""
+
+ def __init__(self, vocab_file, merge_file):
+ name = "GPT2 BPE"
+ super().__init__(name)
+
+ self.tokenizer = GPT2Tokenizer(
+ vocab_file, merge_file, errors="replace", special_tokens=[], max_len=None
+ )
+ self.eod_id = self.tokenizer.encoder["<|endoftext|>"]
+
+ @property
+ def vocab_size(self):
+ return len(self.tokenizer.encoder)
+
+ @property
+ def vocab(self):
+ return self.tokenizer.encoder
+
+ @property
+ def inv_vocab(self):
+ return self.tokenizer.decoder
+
+ def tokenize(self, text):
+ return self.tokenizer.encode(text)
+
+ def detokenize(self, token_ids):
+ return self.tokenizer.decode(token_ids)
+
+ @property
+ def eod(self):
+ return self.eod_id
+
+
+class HgfTokenizerWrapper(AbstractTokenizer):
+ """Wrapper for Hugging Face tokenizer."""
+
+ def __init__(
+ self,
+ tokenizer,
+ ws_start: int = None,
+ ws_len: int = None,
+ ):
+ super(HgfTokenizerWrapper, self).__init__(tokenizer.__class__.__name__)
+ self.tokenizer = tokenizer
+ self.ws_start = ws_start
+ self.ws_len = ws_len
+
+ def tokenize(self, text):
+ if self.ws_start:
+ text = encode_whitespaces(text, self.ws_start, self.ws_len)
+ input_ids = self.tokenizer(text, is_split_into_words=False).input_ids
+
+ return input_ids
+
+ def detokenize(self, token_ids):
+ text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
+ if self.ws_start:
+ text = decode_whitespaces(text, self.ws_start, self.ws_len)
+ return text
+
+ @property
+ def eod(self):
+ return self.tokenizer.eos_token_id
+
+ @property
+ def inv_vocab(self):
+ return len(self.tokenizer.decoder)
+
+ @property
+ def vocab(self):
+ return self.tokenizer.vocab
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
diff --git a/codegeex/megatron/utils.py b/codegeex/megatron/utils.py
new file mode 100644
index 0000000..2b5ba0b
--- /dev/null
+++ b/codegeex/megatron/utils.py
@@ -0,0 +1,219 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""General utilities."""
+
+import sys
+
+import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
+
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
+from codegeex.megatron import get_args
+from codegeex.megatron import print_rank_0
+from codegeex.megatron import get_adlr_autoresume
+from codegeex.megatron import mpu
+from codegeex.megatron.model.module import param_is_not_shared
+from codegeex.megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+
+
+def unwrap_model(model, module_instances=(torchDDP)):
+ return_list = True
+ if not isinstance(model, list):
+ model = [model]
+ return_list = False
+ unwrapped_model = []
+ for model_module in model:
+ while isinstance(model_module, module_instances):
+ model_module = model_module.module
+ unwrapped_model.append(model_module)
+ if not return_list:
+ return unwrapped_model[0]
+ return unwrapped_model
+
+
+def calc_params_l2_norm(model):
+ """Calculate l2 norm of parameters"""
+ args = get_args()
+ if not isinstance(model, list):
+ model = [model]
+ # Remove duplicate params.
+ params_data = []
+ for model_ in model:
+ for param in model_.parameters():
+ is_not_shared = param_is_not_shared(param)
+ is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+ if is_not_shared and is_not_tp_duplicate:
+ if args.bf16:
+ params_data.append(param.data.float())
+ else:
+ params_data.append(param.data)
+ # Calculate norm
+ dummy_overflow_buf = torch.cuda.IntTensor([0])
+ norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [params_data],
+ False, # no per-parameter norm
+ )
+ norm_2 = norm * norm
+ # Sum across all model-parallel GPUs.
+ torch.distributed.all_reduce(
+ norm_2, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group()
+ )
+ return norm_2.item() ** 0.5
+
+
+def average_losses_across_data_parallel_group(losses):
+ """Reduce a tensor of losses across all GPUs."""
+ averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
+ torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group())
+ averaged_losses = averaged_losses / torch.distributed.get_world_size(
+ group=mpu.get_data_parallel_group()
+ )
+
+ return averaged_losses
+
+
+def report_memory(name):
+ """Simple GPU memory report."""
+ mega_bytes = 1024.0 * 1024.0
+ string = name + " memory (MB)"
+ string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
+ string += " | max allocated: {}".format(
+ torch.cuda.max_memory_allocated() / mega_bytes
+ )
+ string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
+ string += " | max reserved: {}".format(
+ torch.cuda.max_memory_reserved() / mega_bytes
+ )
+ if mpu.get_data_parallel_rank() == 0:
+ print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
+
+
+def print_params_min_max_norm(optimizer, iteration):
+ """Print min, max, and norm of all parameters."""
+ index = 0
+ rank = torch.distributed.get_rank()
+ string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n"
+ optimizer_ = optimizer.optimizer
+ for param_group in optimizer_.param_groups:
+ for param in param_group["params"]:
+ index += 1
+ min_ = param.data.min()
+ max_ = param.data.max()
+ norm = torch.linalg.norm(param.data)
+ string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format(
+ iteration, rank, index, int(param.tensor_model_parallel)
+ )
+ string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm)
+ print(string, flush=True)
+
+
+def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler):
+ """Check for autoresume signal and exit if it is received."""
+ from codegeex.megatron.checkpointing import save_checkpoint
+
+ args = get_args()
+ autoresume = get_adlr_autoresume()
+ # Add barrier to ensure consistnecy.
+ torch.distributed.barrier()
+ if autoresume.termination_requested():
+ if args.save:
+ save_checkpoint(iteration, model, optimizer, lr_scheduler)
+ print_rank_0(">>> autoresume termination request found!")
+ if torch.distributed.get_rank() == 0:
+ autoresume.request_resume()
+ print_rank_0(">>> training terminated. Returning")
+ sys.exit(0)
+
+
+def get_ltor_masks_and_position_ids(
+ data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss
+):
+ """Build masks and position id for left to right model."""
+
+ # Extract batch size and sequence length.
+ micro_batch_size, seq_length = data.size()
+
+ # Attention mask (lower triangular).
+ if reset_attention_mask:
+ att_mask_batch = micro_batch_size
+ else:
+ att_mask_batch = 1
+ attention_mask = torch.tril(
+ torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
+ ).view(att_mask_batch, 1, seq_length, seq_length)
+
+ # Loss mask.
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
+ if eod_mask_loss:
+ loss_mask[data == eod_token] = 0.0
+
+ # Position ids.
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(data)
+ # We need to clone as the ids will be modifed based on batch index.
+ if reset_position_ids:
+ position_ids = position_ids.clone()
+
+ if reset_position_ids or reset_attention_mask:
+ # Loop through the batches:
+ for b in range(micro_batch_size):
+
+ # Find indecies where EOD token is.
+ eod_index = position_ids[b, data[b] == eod_token]
+ # Detach indecies from positions if going to modify positions.
+ if reset_position_ids:
+ eod_index = eod_index.clone()
+
+ # Loop through EOD indecies:
+ prev_index = 0
+ for j in range(eod_index.size()[0]):
+ i = eod_index[j]
+ # Mask attention loss.
+ if reset_attention_mask:
+ attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
+ # Reset positions.
+ if reset_position_ids:
+ position_ids[b, (i + 1) :] -= i + 1 - prev_index
+ prev_index = i + 1
+
+ # Convert attention mask to binary:
+ attention_mask = attention_mask < 0.5
+
+ return attention_mask, loss_mask, position_ids
+
+
+def get_parameters_in_billions(model):
+ gpus_per_model = torch.distributed.get_world_size(
+ group=mpu.get_model_parallel_group()
+ )
+
+ approx_parameters_in_billions = sum(
+ [
+ sum(
+ [
+ p.ds_numel if hasattr(p, "ds_id") else p.nelement()
+ for p in model_module.parameters()
+ ]
+ )
+ for model_module in model
+ ]
+ )
+
+ return approx_parameters_in_billions * gpus_per_model / (1e9)
diff --git a/codegeex/mindspore/generation_humaneval.py b/codegeex/mindspore/generation_humaneval.py
index 05ce020..030ff47 100644
--- a/codegeex/mindspore/generation_humaneval.py
+++ b/codegeex/mindspore/generation_humaneval.py
@@ -209,7 +209,7 @@ def run_predict(model_predict, config, args_opt, rank):
part = int(args_opt.part)
gen_times = 12 # TODO: set as generation times of current task
print(f"gen times: {gen_times}, part: {part}")
- save_path = f'/home/work/sfs/xx/pangu_alpha_code/generation_humanevalx/cpp/epoch_6_7375_temp_{args_opt.temperature}/samples_{args_opt.load_ckpt_epoch}_part_{part}.jsonl' # TODO: set as current save path
+ save_path = f'/home/work/sfs/xx/pangu_alpha_code/generation_humanevalx/cpp/temp_{args_opt.temperature}/samples_{args_opt.load_ckpt_epoch}_part_{part}.jsonl' # TODO: set as current save path
if rank == 0 and not os.path.exists(save_path):
os.makedirs(os.path.split(save_path)[0], exist_ok=True)
f = open(save_path, 'w')
diff --git a/configs/codegeex_13b.sh b/configs/codegeex_13b.sh
new file mode 100644
index 0000000..6332c88
--- /dev/null
+++ b/configs/codegeex_13b.sh
@@ -0,0 +1,16 @@
+# CodeGeeX-13B configuration
+
+CHECKPOINT_PATH=""
+
+MODEL_ARGS="--num-layers 39 \
+ --hidden-size 5120 \
+ --num-attention-heads 40 \
+ --max-position-embeddings 2048 \
+ --attention-softmax-in-fp32 \
+ --load "$CHECKPOINT_PATH" \
+ --layernorm-epsilon 1e-5 \
+ --fp16 \
+ --ws-encoding-start-id 10 \
+ --ws-encoding-length 10 \
+ --make-vocab-size-divisible-by 52224 \
+ --seq-length 2048"
\ No newline at end of file
diff --git a/resources/api/api_step_1.png b/resources/api/api_step_1.png
new file mode 100644
index 0000000..2ecba88
Binary files /dev/null and b/resources/api/api_step_1.png differ
diff --git a/resources/api/api_step_2.png b/resources/api/api_step_2.png
new file mode 100644
index 0000000..ea7b0e8
Binary files /dev/null and b/resources/api/api_step_2.png differ
diff --git a/resources/api/api_step_3.png b/resources/api/api_step_3.png
new file mode 100644
index 0000000..098cadf
Binary files /dev/null and b/resources/api/api_step_3.png differ
diff --git a/resources/api/api_step_4.png b/resources/api/api_step_4.png
new file mode 100644
index 0000000..7a4f4db
Binary files /dev/null and b/resources/api/api_step_4.png differ
diff --git a/resources/api/api_step_5.png b/resources/api/api_step_5.png
new file mode 100644
index 0000000..85705d2
Binary files /dev/null and b/resources/api/api_step_5.png differ
diff --git a/resources/en/hx_budget_assignment.png b/resources/en/hx_budget_assignment.png
deleted file mode 100644
index 98e7e9e..0000000
Binary files a/resources/en/hx_budget_assignment.png and /dev/null differ
diff --git a/resources/en/hx_generattion_radar_horizon.png b/resources/en/hx_generattion_radar_horizon.png
index 180405b..e048ce4 100644
Binary files a/resources/en/hx_generattion_radar_horizon.png and b/resources/en/hx_generattion_radar_horizon.png differ
diff --git a/resources/logo/codegeex_logo.png b/resources/logo/codegeex_logo.png
index 0259298..c909a84 100644
Binary files a/resources/logo/codegeex_logo.png and b/resources/logo/codegeex_logo.png differ
diff --git a/resources/zh/hx_budget_assignment_zh.png b/resources/zh/hx_budget_assignment_zh.png
deleted file mode 100644
index e3a3422..0000000
Binary files a/resources/zh/hx_budget_assignment_zh.png and /dev/null differ
diff --git a/resources/zh/hx_generattion_radar_horizon_zh.png b/resources/zh/hx_generattion_radar_horizon_zh.png
index 7c03e4d..5ea68b1 100644
Binary files a/resources/zh/hx_generattion_radar_horizon_zh.png and b/resources/zh/hx_generattion_radar_horizon_zh.png differ
diff --git a/scripts/convert_mindspore_to_megatron.sh b/scripts/convert_mindspore_to_megatron.sh
new file mode 100644
index 0000000..d221d82
--- /dev/null
+++ b/scripts/convert_mindspore_to_megatron.sh
@@ -0,0 +1,28 @@
+# This script is used to convert mindspore checkpoint to the megatron format.
+
+NPY_CKPT_PATH=$1 # Path to Mindspore exported weights in .npy format.
+SAVE_CKPT_PATH=$2 # Path to save the output .pt checkpoint.
+GPU=$3
+
+SCRIPT_PATH=$(realpath "$0")
+SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
+MAIN_DIR=$(dirname "$SCRIPT_DIR")
+TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
+
+# export CUDA settings
+if [ -z "$GPU" ]; then
+ GPU=0
+fi
+
+export CUDA_HOME=/usr/local/cuda-11.1/
+export CUDA_VISIBLE_DEVICES=$GPU
+
+
+CMD="python $MAIN_DIR/codegeex/megatron/mindspore_to_megatron.py \
+ --npy-ckpt-path $NPY_CKPT_PATH \
+ --save-ckpt-path $SAVE_CKPT_PATH \
+ --tokenizer-path $TOKENIZER_PATH \
+ $MODEL_ARGS"
+
+echo "$CMD"
+eval "$CMD"
\ No newline at end of file
diff --git a/scripts/generate_humaneval_x.sh b/scripts/generate_humaneval_x.sh
new file mode 100644
index 0000000..8e91bf7
--- /dev/null
+++ b/scripts/generate_humaneval_x.sh
@@ -0,0 +1,95 @@
+# This script is used to generate solutions of HumanEval-X.
+
+LANGUAGE=$1 # Target programming language, currently support one of ["python", "java", "cpp", "js", "go"]
+OUTPUT_PATH=$2 # Output path of the generated programs.
+HOSTLIST=$3 # Provide hostfile if generating distributedly
+
+SCRIPT_PATH=$(realpath "$0")
+SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
+MAIN_DIR=$(dirname "$SCRIPT_DIR")
+TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
+
+# export CUDA settings
+export CUDA_HOME=/usr/local/cuda-11.1/
+
+# import model configuration
+source "$MAIN_DIR/configs/codegeex_13b.sh"
+
+# nccl options
+OPTIONS_NCCL="export NCCL_DEBUG=warn; export NCCL_IB_DISABLE=0; export NCCL_IB_GID_INDEX=3"
+OPTIONS_PATH="export PATH=$PATH; export LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
+CWD=$(pwd)
+
+# set master ip for zmq server
+if [ -z "$HOSTLIST" ]; then
+ ZMQ_ADDR=$(hostname -i)
+ echo "$ZMQ_ADDR" > "./hostfile"
+ HOSTLIST="./hostfile"
+else
+ ZMQ_ADDR=$(cat $HOSTLIST | head -n 1)
+fi
+echo "master_ip: $ZMQ_ADDR"
+
+NUM_SAMPLES=1
+MICRO_BSZ=1
+WORLD_SIZE=1
+TEMP=0.8
+TOPP=0.95
+SEED=42
+DATASET=humaneval
+TODAY=$(date +%y%m%d)
+CHANNEL_PORT=$(expr $RANDOM + 5000)
+MASTER_PORT=$(expr $RANDOM + 8000)
+
+# save log file
+LOG_DIR=$MAIN_DIR/log
+mkdir -p "$LOG_DIR"
+LOG_PATH="$LOG_DIR/$TODAY-generation.log"
+
+if [ -z "$LANGUAGE" ]; then
+ LANGUAGE=python
+fi
+
+if [ -z "$INPUT_PATH" ]; then
+ INPUT_PATH=$MAIN_DIR/codegeex/benchmark/humaneval-x/$LANGUAGE/data/humaneval_$LANGUAGE.jsonl.gz
+fi
+
+if [ -z "$OUTPUT_PATH" ]; then
+ OUTPUT_PATH=$MAIN_DIR/codegeex/benchmark/output/humaneval-x/codegeex/
+ mkdir -p "$OUTPUT_PATH"
+fi
+
+JOB_ID=codegeex-ns$NUM_SAMPLES-t$TEMP-topp$TOPP-seed$SEED-$LANGUAGE
+
+RUN_CMD="python \
+ $MAIN_DIR/codegeex/benchmark/humaneval-x/generate_humaneval_x.py \
+ --hostfile $HOSTLIST \
+ --channel-ip $ZMQ_ADDR \
+ --channel-port $CHANNEL_PORT \
+ --master-port $MASTER_PORT \
+ --tokenizer-path $TOKENIZER_PATH \
+ --load-deepspeed \
+ --temperature $TEMP \
+ --top-p $TOPP \
+ --out-seq-length 1024 \
+ --micro-batch-size $MICRO_BSZ \
+ --samples-per-problem $NUM_SAMPLES \
+ --language-type $LANGUAGE \
+ --dataset $DATASET \
+ --input-path $INPUT_PATH \
+ --output-prefix $OUTPUT_PATH/$JOB_ID \
+ --gen-node-world-size $WORLD_SIZE \
+ --seed $SEED \
+ $MODEL_ARGS"
+
+RUN_CMD="$OPTIONS_NCCL; $OPTIONS_PATH; $RUN_CMD"
+RUN_CMD="cd $CWD; $RUN_CMD"
+
+if (( WORLD_SIZE != 1 )); then
+ RUN_CMD="pdsh -R ssh -w ^$HOSTLIST \"$RUN_CMD\""
+fi
+
+echo "$RUN_CMD"
+echo "Writing log to $LOG_PATH"
+eval "$RUN_CMD" > "$LOG_PATH"
+bash $MAIN_DIR/scripts/gather_output.sh $OUTPUT_PATH $JOB_ID 1
diff --git a/scripts/test_inference.sh b/scripts/test_inference.sh
new file mode 100644
index 0000000..ef46051
--- /dev/null
+++ b/scripts/test_inference.sh
@@ -0,0 +1,39 @@
+# This script is used to test the inference of CodeGeeX.
+
+GPU=$1
+PROMPT_FILE=$2
+
+SCRIPT_PATH=$(realpath "$0")
+SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
+MAIN_DIR=$(dirname "$SCRIPT_DIR")
+TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
+
+# import model configuration
+source "$MAIN_DIR/configs/codegeex_13b.sh"
+
+# export CUDA settings
+if [ -z "$GPU" ]; then
+ GPU=0
+fi
+
+export CUDA_HOME=/usr/local/cuda-11.1/
+export CUDA_VISIBLE_DEVICES=$GPU
+
+if [ -z "$PROMPT_FILE" ]; then
+ PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
+fi
+
+# remove --greedy if using sampling
+CMD="python $MAIN_DIR/tests/test_inference.py \
+ --prompt-file $PROMPT_FILE \
+ --tokenizer-path $TOKENIZER_PATH \
+ --micro-batch-size 1 \
+ --out-seq-length 1024 \
+ --temperature 0.8 \
+ --top-p 0.95 \
+ --top-k 100 \
+ --greedy \
+ $MODEL_ARGS"
+
+echo "$CMD"
+eval "$CMD"
diff --git a/scripts/translate_humaneval_x.sh b/scripts/translate_humaneval_x.sh
new file mode 100644
index 0000000..1f5281f
--- /dev/null
+++ b/scripts/translate_humaneval_x.sh
@@ -0,0 +1,110 @@
+# This script is used to translate solutions of HumanEval-X.
+
+LANG_SRC_TYPE=$1 # Source programming language, currently support one of ["python", "java", "cpp", "js", "go"]
+LANG_TGT_TYPE=$2 # Target programming language, currently support one of ["python", "java", "cpp", "js", "go"]
+OUTPUT_PATH=$3 # Output path of the generated programs.
+HOSTLIST=$4 # Provide hostfile if generating distributedly
+
+SCRIPT_PATH=$(realpath "$0")
+SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
+MAIN_DIR=$(dirname "$SCRIPT_DIR")
+TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
+
+# export CUDA settings
+export CUDA_HOME=/usr/local/cuda-11.1/
+
+# import model configuration
+source "$MAIN_DIR/configs/codegeex_13b.sh"
+
+# nccl options
+OPTIONS_NCCL="export NCCL_DEBUG=warn; export NCCL_IB_DISABLE=0; export NCCL_IB_GID_INDEX=3"
+OPTIONS_PATH="export PATH=$PATH; export LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
+CWD=$(pwd)
+
+# set master ip for zmq server
+if [ -z "$HOSTLIST" ]; then
+ ZMQ_ADDR=$(hostname -i)
+ echo "$ZMQ_ADDR" > "./hostfile"
+ HOSTLIST="./hostfile"
+else
+ ZMQ_ADDR=$(cat $HOSTLIST | head -n 1)
+fi
+echo "master_ip: $ZMQ_ADDR"
+
+NUM_SAMPLES=1
+MICRO_BSZ=1
+WORLD_SIZE=1
+TEMP=0.8
+TOPP=0.95
+SEED=42
+DATASET=humaneval
+TODAY=$(date +%y%m%d)
+CHANNEL_PORT=$(expr $RANDOM + 5000)
+MASTER_PORT=$(expr $RANDOM + 8000)
+
+# save log file
+LOG_DIR=$MAIN_DIR/log
+mkdir -p "$LOG_DIR"
+LOG_PATH="$LOG_DIR/$TODAY-translation.log"
+
+if [ -z "$LANG_SRC_TYPE" ]
+then
+ LANG_SRC_TYPE=python
+fi
+
+if [ -z "$LANG_TGT_TYPE" ]
+then
+ LANG_TGT_TYPE=java
+fi
+
+if [ -z "$INPUT_SRC_PATH" ]
+then
+ INPUT_SRC_PATH=$MAIN_DIR/codegeex/benchmark/humaneval-x/$LANG_SRC_TYPE/data/humaneval_$LANG_SRC_TYPE.jsonl.gz
+fi
+
+if [ -z "$INPUT_TGT_PATH" ]
+then
+ INPUT_TGT_PATH=$MAIN_DIR/codegeex/benchmark/humaneval-x/$LANG_TGT_TYPE/data/humaneval_$LANG_TGT_TYPE.jsonl.gz
+fi
+
+if [ -z "$OUTPUT_PATH" ]; then
+ OUTPUT_PATH=$MAIN_DIR/codegeex/benchmark/output/humaneval-x/codegeex/
+ mkdir -p "$OUTPUT_PATH"
+fi
+
+JOB_ID=codegeex-ns$NUM_SAMPLES-t$TEMP-topp$TOPP-seed$SEED-$LANGUAGE
+
+RUN_CMD="python \
+ $MAIN_DIR/codegeex/benchmark/humaneval-x/translate_humaneval_x.py \
+ --hostfile $HOSTLIST \
+ --channel-ip $ZMQ_ADDR \
+ --channel-port $CHANNEL_PORT \
+ --master-port $MASTER_PORT \
+ --tokenizer-path $TOKENIZER_PATH \
+ --load-deepspeed \
+ --temperature $TEMP \
+ --top-p $TOPP \
+ --out-seq-length 1024 \
+ --micro-batch-size $MICRO_BSZ \
+ --samples-per-problem $NUM_SAMPLES \
+ --language-src-type $LANG_SRC_TYPE \
+ --language-tgt-type $LANG_TGT_TYPE \
+ --src-path $INPUT_SRC_PATH \
+ --tgt-path $INPUT_TGT_PATH \
+ --dataset $DATASET \
+ --output-prefix $OUTPUT_PATH/$JOB_ID \
+ --gen-node-world-size $WORLD_SIZE \
+ --seed $SEED \
+ $MODEL_ARGS"
+
+RUN_CMD="$OPTIONS_NCCL; $OPTIONS_PATH; $RUN_CMD"
+RUN_CMD="cd $CWD; $RUN_CMD"
+
+if (( WORLD_SIZE != 1 )); then
+ RUN_CMD="pdsh -R ssh -w ^$HOSTLIST \"$RUN_CMD\""
+fi
+
+echo "$RUN_CMD"
+echo "Writing log to $LOG_PATH"
+eval "$RUN_CMD" > "$LOG_PATH"
+bash $MAIN_DIR/scripts/gather_output.sh $OUTPUT_PATH $JOB_ID 1
diff --git a/tests/test_inference.py b/tests/test_inference.py
new file mode 100644
index 0000000..021da4c
--- /dev/null
+++ b/tests/test_inference.py
@@ -0,0 +1,199 @@
+import os
+import copy
+import time
+import torch
+import random
+import numpy as np
+
+from codegeex.megatron import get_tokenizer, get_args
+from codegeex.megatron.initialize import initialize_megatron
+from codegeex.megatron.model import CodeGeeXModel
+from codegeex.megatron.code_generation_utils import get_token_stream
+
+torch.set_printoptions(precision=8)
+
+
+def set_random_seed(seed):
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def model_provider():
+ """Build the model."""
+
+ model = CodeGeeXModel(num_tokentypes=0,
+ parallel_output=False)
+
+ return model
+
+
+def add_code_generation_args(parser):
+ """Code generation arguments."""
+ group = parser.add_argument_group(title="code generation")
+
+ group.add_argument(
+ "--temperature",
+ type=float,
+ default=1.0,
+ help="Sampling temperature.",
+ )
+ group.add_argument(
+ "--greedy",
+ action="store_true",
+ default=False,
+ help="Use greedy sampling.",
+ )
+ group.add_argument(
+ "--top-p",
+ type=float,
+ default=0.0,
+ help="Top p sampling.",
+ )
+ group.add_argument(
+ "--top-k",
+ type=int,
+ default=0,
+ help="Top k sampling.",
+ )
+ group.add_argument(
+ "--out-seq-length",
+ type=int,
+ default=2048,
+ help="Size of the output generated text.",
+ )
+ group.add_argument(
+ "--recompute",
+ action="store_true",
+ help="During generation recompute all attention "
+ "instead of using previously computed keys/values.",
+ )
+ group.add_argument(
+ "--ws-encoding-start-id",
+ type=int,
+ default=10,
+ help="Start id for whitespace encoding",
+ )
+ group.add_argument(
+ "--ws-encoding-length",
+ type=int,
+ default=80,
+ help="Length of whitespace encoding",
+ )
+ group.add_argument(
+ "--n-generation",
+ type=int,
+ default=10,
+ )
+ group.add_argument(
+ "--eos-id",
+ type=int,
+ default=50256,
+ )
+ group.add_argument(
+ "--prompt-file",
+ type=str,
+ default="./test_prompt.txt",
+ )
+ group.add_argument(
+ "--perf-file",
+ type=str,
+ default="./perf_out.txt",
+ )
+ group.add_argument(
+ "--perf-trace",
+ type=str,
+ default="./perf_out.txt",
+ )
+ group.add_argument(
+ "--use-torch-profile",
+ action="store_true",
+ )
+ group.add_argument(
+ "--ln-fp32",
+ action="store_true",
+ )
+ group.add_argument(
+ '--bad-ids',
+ nargs="*",
+ type=int,
+ default=None,
+ help='Identify the type of programming language to generate',
+ )
+
+ return parser
+
+
+def main():
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
+
+ initialize_megatron(
+ extra_args_provider=add_code_generation_args,
+ )
+
+ args = get_args()
+ set_random_seed(args.seed)
+
+ print("Loading tokenizer ...")
+ tokenizer = get_tokenizer()
+
+ print("Loading state dict ...")
+ state_dict = torch.load(args.load, map_location="cpu")
+ state_dict = state_dict["module"]
+
+ print("Building CodeGeeX model ...")
+ model = model_provider()
+ model.load_state_dict(state_dict)
+ model.eval()
+ if args.fp16 and args.ln_fp16:
+ model.half()
+ model.cuda()
+
+ with open(args.prompt_file, "r") as f:
+ prompt = f.readlines()
+ prompt = "".join(prompt)
+
+ print("Generating ...")
+ t0 = time.perf_counter()
+ for prompt in [prompt]:
+ tokens = tokenizer.tokenize(prompt)
+ print(tokens)
+ print("Current prompt:")
+ print(prompt)
+ n_token_prompt = len(tokens)
+ print("N_token_prompt:", n_token_prompt)
+ token_stream = get_token_stream(
+ model,
+ [copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
+ micro_batch_size=args.micro_batch_size,
+ bad_ids=args.bad_ids,
+ )
+ is_finished = [False for _ in range(args.micro_batch_size)]
+ for i, generated in enumerate(token_stream):
+ generated_tokens = generated[0]
+ for j in range(args.micro_batch_size):
+ if is_finished[j]:
+ continue
+ if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
+ generated_tokens[j]) >= args.out_seq_length:
+ is_finished[j] = True
+ generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
+ generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
+ t1 = time.perf_counter()
+ print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
+ print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
+ print("================================= Generated code:")
+ print(generated_code)
+ t0 = time.perf_counter()
+ if all(is_finished):
+ break
+
+ print("Generation finished.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/test_prompt.txt b/tests/test_prompt.txt
new file mode 100644
index 0000000..e8cc681
--- /dev/null
+++ b/tests/test_prompt.txt
@@ -0,0 +1,15 @@
+code translation
+Java:
+public class Solution {
+ public static boolean hasCloseElements(int[] nums, int threshold) {
+ for (int i = 0; i < nums.length - 1; i++) {
+ for (int j = i + 1; j < nums.length; j++) {
+ if (Math.abs(nums[i] - nums[j]) < threshold) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+}
+Python:
diff --git a/vscode-extension/README.md b/vscode-extension/README.md
index ce76487..51b9acd 100644
--- a/vscode-extension/README.md
+++ b/vscode-extension/README.md
@@ -1,8 +1,14 @@
-
+
🌐 中文
-We introduce CodeGeeX, a large-scale multilingual code generative model with 13 billion parameters, pretrained on a large code corpus of more than 20 programming languages. With CodeGeeX, we can generate codes by only providing natural language descriptions, complete any code snippet, or translate codes to other programming languages, etc. CodeGeeX also provides customizable features (**Prompt Mode**) to help you configure your own programming assistant. Happy coding!
+
+
+
+
+
+
+We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pretrained on a large code corpus of more than 20 programming languages. With CodeGeeX, we can generate codes by only providing natural language descriptions, complete any code snippet, or translate codes to other programming languages, etc. CodeGeeX also provides customizable features (**Prompt Mode**) to help you configure your own programming assistant. Happy coding!
For more information, please check out our [Homepage](https://models.aminer.cn/codegeex/) and [GitHub repo](https://github.com/THUDM/CodeGeeX).
diff --git a/vscode-extension/README_zh.md b/vscode-extension/README_zh.md
index d097958..c78f2dc 100644
--- a/vscode-extension/README_zh.md
+++ b/vscode-extension/README_zh.md
@@ -1,7 +1,13 @@
-
+
🌐 English
+
+
+
+
+
+
CodeGeeX是一个具有130亿参数的多编程语言代码生成预训练模型,使用超过二十种编程语言训练得到。基于CodeGeeX开发的插件可以实现通过描述生成代码、补全代码、代码翻译等一系列功能。CodeGeeX同样提供可以定制的**提示模式(Prompt Mode)**,构建专属的编程助手。Happy Coding!
VS Code插件市场搜索"codegeex"即可免费使用,更多关于CodeGeeX信息请见我们的[主页](https://models.aminer.cn/codegeex/) and [GitHub仓库](https://github.com/THUDM/CodeGeeX)。