mirror of https://github.com/THUDM/CodeGeeX.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
4.7 KiB
Python
131 lines
4.7 KiB
Python
import os
|
|
import gzip
|
|
import json
|
|
|
|
from typing import *
|
|
|
|
# Maps a lower-case language name (or alias) to the one-line "language tag"
# comment associated with that language. Several aliases share one tag
# (e.g. "c++"/"cpp", "c#"/"csharp", "js"/"javascript").
# NOTE(review): the tag strings are runtime data and are kept byte-for-byte;
# some of them (e.g. "lua" using `//`, "swift" in lower case) look inconsistent
# with the rest of the table -- confirm with the consumers before changing them.
LANGUAGE_TAG = {
    "c": "// language: C",
    "c++": "// language: C++",
    "cpp": "// language: C++",
    "c#": "// language: C#",
    "csharp": "// language: C#",
    "css": "/* language: CSS */",
    "cuda": "// language: Cuda",
    "dart": "// language: Dart",
    "lua": "// language: Lua",
    "objectivec": "// language: Objective-C",
    "objective-c": "// language: Objective-C",
    "objective-c++": "// language: Objective-C++",
    "python": "# language: Python",
    "perl": "# language: Perl",
    "prolog": "% language: Prolog",
    "swift": "// language: swift",
    "lisp": "; language: Lisp",
    "java": "// language: Java",
    "scala": "// language: Scala",
    "tex": "% language: TeX",
    "vue": "<!--language: Vue-->",
    "markdown": "<!--language: Markdown-->",
    "html": "<!--language: HTML-->",
    "php": "// language: PHP",
    "js": "// language: JavaScript",
    "javascript": "// language: JavaScript",
    "typescript": "// language: TypeScript",
    "go": "// language: Go",
    "shell": "# language: Shell",
    "rust": "// language: Rust",
    "sql": "-- language: SQL",
    "kotlin": "// language: Kotlin",
    "vb": "' language: Visual Basic",
    "ruby": "# language: Ruby",
    "pascal": "// language: Pascal",
    "r": "# language: R",
    "fortran": "!language: Fortran",
    "lean": "-- language: Lean",
    "matlab": "% language: Matlab",
    "delphi": "{language: Delphi}",
    "scheme": "; language: Scheme",
    "basic": "' language: Basic",
    "assembly": "; language: Assembly",
    "groovy": "// language: Groovy",
    "abap": "* language: Abap",
    "gdscript": "# language: GDScript",
    "haskell": "-- language: Haskell",
    "julia": "# language: Julia",
    "elixir": "# language: Elixir",
    "excel": "' language: Excel",
    "clojure": "; language: Clojure",
    "actionscript": "// language: ActionScript",
    "solidity": "// language: Solidity",
    "powershell": "# language: PowerShell",
    "erlang": "% language: Erlang",
    "cobol": "// language: Cobol",
}
|
|
|
|
|
|
def stream_jsonl(filename: str) -> Iterable[Dict]:
    """
    Lazily parse a JSON Lines file, yielding one decoded object per line.

    Args:
        filename: Path of the file to read. A leading ``~`` is expanded, and
            a path ending in ``.gz`` is transparently decompressed.

    Yields:
        The result of ``json.loads`` for every line that contains at least
        one non-whitespace character; blank lines are skipped.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Mirror write_jsonl, which also expands "~" before touching the path.
    filename = os.path.expanduser(filename)

    # write_jsonl always emits UTF-8, so decode explicitly as UTF-8 instead of
    # relying on the platform's default text encoding.
    if filename.endswith(".gz"):
        fp = gzip.open(filename, "rt", encoding="utf-8")
    else:
        fp = open(filename, "r", encoding="utf-8")

    with fp:
        for line in fp:
            # Skip empty / whitespace-only lines.
            if line.strip():
                yield json.loads(line)
|
|
|
|
|
|
def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
    """
    Serialize every dictionary in `data` as one JSON line of `filename`.

    `filename` has a leading ``~`` expanded. The file is opened in binary
    append mode when `append` is true and in binary write mode otherwise.
    If the expanded path ends in ``.gz`` the lines are written through a
    gzip stream layered on top of that file handle. Each record is encoded
    as UTF-8 and terminated with a newline.
    """
    target = os.path.expanduser(filename)
    file_mode = "ab" if append else "wb"

    def encoded_lines():
        # One UTF-8 encoded, newline-terminated JSON document per record.
        for record in data:
            yield (json.dumps(record) + "\n").encode("utf-8")

    with open(target, file_mode) as raw:
        if target.endswith(".gz"):
            # The gzip layer is always opened with "wb"; appending is decided
            # solely by how the underlying file handle was opened.
            with gzip.GzipFile(fileobj=raw, mode="wb") as packed:
                for chunk in encoded_lines():
                    packed.write(chunk)
        else:
            for chunk in encoded_lines():
                raw.write(chunk)
|
|
|
|
|
|
def sliding_window(
    prompt_tokens: list,
    code_tokens: list,
    seq_len: int,
    sliding_stride: int,
    minimum_code_len: int = 1,
) -> Iterable[Tuple[list, list]]:
    """
    Yield (prompt, code) slices obtained by moving a `seq_len`-wide window
    over the concatenation of `prompt_tokens` and `code_tokens`.

    The window starts at the earliest offset that still leaves room for
    `minimum_code_len` code tokens, advances by `sliding_stride`, and never
    starts past the offset at which it would end exactly on the last token.
    If the stride does not land on that final offset, one extra window
    anchored there is yielded last.
    """
    n_prompt = len(prompt_tokens)
    n_code = len(code_tokens)

    # Last admissible window start: the window ends on the final token.
    last = max(0, n_prompt + n_code - seq_len)
    # First window start: at least `minimum_code_len` code tokens should be in
    # the window, but never start beyond `last`.
    first = min(max(0, n_prompt - seq_len + minimum_code_len), last)

    def window_at(offset):
        # Offsets are expressed in the concatenated sequence, so the code part
        # is addressed relative to the end of the prompt.
        code_lo = max(offset - n_prompt, 0)
        code_hi = offset - n_prompt + seq_len
        return prompt_tokens[offset:offset + seq_len], code_tokens[code_lo:code_hi]

    yield from map(window_at, range(first, last + 1, sliding_stride))

    # The stride skipped the final offset: emit the tail window explicitly.
    if (last - first) % sliding_stride != 0:
        yield window_at(last)
|