diff --git a/codegeex/benchmark/evaluate_humaneval_x.py b/codegeex/benchmark/evaluate_humaneval_x.py
new file mode 100644
index 0000000..076e0e8
--- /dev/null
+++ b/codegeex/benchmark/evaluate_humaneval_x.py
@@ -0,0 +1,239 @@
+import os
+import sys
+import fire
+import json
+import gzip
+import regex
+import numpy as np
+
+from typing import *
+from tqdm.auto import tqdm
+from collections import Counter, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from codegeex.benchmark.utils import read_dataset, IMPORT_HELPER
+from codegeex.benchmark.metric import estimate_pass_at_k
+from codegeex.benchmark.execution import check_correctness
+
+LANGUAGE_NAME = {
+    "cpp"   : "CPP",
+    "go"    : "Go",
+    "java"  : "Java",
+    "js"    : "JavaScript",
+    "python": "Python",
+}
+
+
+def process_humaneval_test(sample, problems, example_test=False):
+    task_id = sample["task_id"]
+    language = task_id.split("/")[0].lower()
+
+    prompt = sample["prompt"]
+    if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
+        test = problems[task_id]["example_test"]
+    else:
+        test = problems[task_id]["test"]
+    code = sample["generation"]
+
+    # Pre-process for different languages
+    if language == "python":
+        code_ = []
+        for line in code.split("\n"):
+            if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
+                break
+            code_.append(line)
+        code = "\n".join(code_)
+        test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
+        test_string = test_setup + prompt + code + "\n" + test + "\n"
+    elif language == "cpp":
+        test_set_up = ""
+        for s in IMPORT_HELPER["cpp"]:
+            if s not in prompt:
+                test_set_up += s + "\n"
+        test_string = test_set_up + "\n" + prompt + code + "\n" + test
+    elif language == "java":
+        test_string = prompt + code + "\n" + test
+    elif language == "js" or language == "javascript":
+        test_string = prompt + code + "\n" + test
+    elif language == "go":
+        import_string = problems[task_id]["import"]
+        prompt = prompt.replace(import_string, "")
+        if example_test and "example_test" in problems[task_id]:
+            test = problems[task_id]["example_test"]
+        else:
+            test = problems[task_id]["test"]
+        test_setup = problems[task_id]["test_setup"]
+        other_pkgs = []
+        for pkg in IMPORT_HELPER["go"]:
+            if pkg not in test_setup:
+                p = pkg.split("/")[-1]
+                if p + "." in code:
+                    other_pkgs.append(f"\"{pkg}\"")
+        if other_pkgs:
+            import_other_pkgs = "import (\n" + "    ".join([p + "\n" for p in other_pkgs]) + ")"
+            test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
+        else:
+            test_string = test_setup + "\n" + prompt + code + "\n" + test
+    elif language == "rust":
+        main = "\nfn main(){ \n } \n"
+        declaration = problems[task_id]["declaration"]
+        test_string = main + declaration + prompt + code + test
+
+    return test_string
+
+
+def stream_jsonl_all(filename: str) -> Iterable[Dict]:
+    results = []
+    if filename.endswith(".gz"):
+        fp = gzip.open(open(filename, "rb"), "rt")
+    else:
+        fp = open(filename, "r")
+    for line in fp:
+        if any(not x.isspace() for x in line):
+            results.append(json.loads(line))
+    fp.close()
+
+    return results
+
+
+def evaluate_functional_correctness(
+        input_file: str = None,
+        tmp_dir: str = "./",
+        n_workers: int = 32,
+        timeout: float = 500.0,
+        problem_file: str = "../data/humaneval_python.jsonl.gz",
+        out_dir: str = None,
+        k: List[int] = [1, 10, 100],
+        test_groundtruth: bool = False,
+        example_test: bool = False,
+):
+    if example_test:
+        print("Example test...")
+
+    problems = read_dataset(problem_file,
+                            dataset_type="humaneval")
+    sample_jsonl = stream_jsonl_all(input_file)
+
+    if example_test:
+        suffix = "_example_test.jsonl"
+    else:
+        suffix = "_results.jsonl"
+    if out_dir is not None:
+        if not os.path.exists(out_dir):
+            os.makedirs(out_dir)
+        out_file = os.path.join(out_dir, input_file.split('/')[-1].replace(".jsonl", suffix))
+    else:
+        out_file = os.path.join(input_file.replace(".jsonl", suffix))
+
+    if "/codegeex/benchmark/humaneval-x/" in input_file:
+        test_groundtruth = True
+
+    if "-to-" in input_file:
+        translation_mode = True
+    else:
+        translation_mode = False
+
+    with ThreadPoolExecutor(max_workers=n_workers) as executor:
+
+        futures = []
+        completion_id = Counter()
+        n_samples = 0
+        results = defaultdict(list)
+
+        if test_groundtruth:
+            print("Testing ground truth...")
+            for sample in tqdm(problems.values()):
+                task_id = sample["task_id"]
+                lang = task_id.split("/")[0].lower()
+                if lang == "javascript":
+                    lang = "js"
+                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
+                sample["generation"] = sample["canonical_solution"]
+                sample["test_code"] = process_humaneval_test(sample, problems, example_test)
+                if sample["test_code"] is None:
+                    continue
+                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id])
+                future = executor.submit(check_correctness, *args)
+                futures.append(future)
+                completion_id[task_id] += 1
+                n_samples += 1
+        else:
+            print("Reading samples...")
+            for sample in tqdm(sample_jsonl):
+                task_id = sample["task_id"]
+                lang = task_id.split("/")[0].lower()
+                if translation_mode:
+                    task_id = sample["task_id"].split("/")[-1]
+                    lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
+                    for l in LANGUAGE_NAME:
+                        if l in lang:
+                            lang = l
+                            break
+                    task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
+                if lang == "javascript":
+                    lang = "js"
+                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
+                sample["task_id"] = task_id
+                sample["test_code"] = process_humaneval_test(sample, problems, example_test)
+                if sample["test_code"] is None:
+                    continue
+                if "completion_id" in sample:
+                    completion_id_ = sample["completion_id"]
+                else:
+                    completion_id_ = completion_id[task_id]
+                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
+                future = executor.submit(check_correctness, *args)
+                futures.append(future)
+                completion_id[task_id] += 1
+                n_samples += 1
+
+        print(completion_id)
+        if len(completion_id) == len(problems):
+            evaluate_pass_at_k = True
+        else:
+            evaluate_pass_at_k = False
+
+        print("Running test suites...")
+        for future in tqdm(as_completed(futures), total=len(futures)):
+            result = future.result()
+            results[result["task_id"]].append((result["completion_id"], result))
+
+    # Calculate pass@k.
+    total, correct = [], []
+    for result in results.values():
+        passed = [r[1]["passed"] for r in result]
+        total.append(len(passed))
+        correct.append(sum(passed))
+    total = np.array(total)
+    correct = np.array(correct)
+    if evaluate_pass_at_k:
+        ks = k
+        pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
+                     for k in ks if (total >= k).all()}
+        print(pass_at_k)
+    else:
+        print("Total:", np.sum(total))
+        print("Correct:", np.sum(correct))
+
+    print("Writing to: ", out_file)
+    if out_file.endswith(".gz"):
+        fp = gzip.GzipFile(fileobj=open(out_file, "wb"), mode="wb")
+        for res in results.values():
+            for r in res:
+                fp.write((json.dumps(r[1]) + "\n").encode("utf-8"))
+    else:
+        fp = open(out_file, 'w')
+        for res in results.values():
+            for r in res:
+                fp.write(json.dumps(r[1]) + "\n")
+    fp.close()
+
+    print("Evaluation finished.")
+
+
+def main():
+    fire.Fire(evaluate_functional_correctness)
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/codegeex/benchmark/execution.py b/codegeex/benchmark/execution.py
index 9de9ec8..cbdf14f 100644
--- a/codegeex/benchmark/execution.py
+++ b/codegeex/benchmark/execution.py
@@ -8,8 +8,38 @@ import signal
 import random
 import subprocess
 import tempfile
+import gzip
+import json
 
 from typing import *
 
+def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> None:
+    """
+    Saves a list of dicts into a JSONL file.
+    :param data_list: (list) list of dicts to be stored.
+    :param filename: (str) path to the output file. If the .jsonl suffix is not given, the method
+        appends it to the filename.
+    :param compress: (bool) should the file be compressed into a gzip archive?
+    """
+    sjsonl = '.jsonl'
+    sgz = '.gz'
+    # Check filename
+    if not filename.endswith(sjsonl):
+        filename = filename + sjsonl
+    # Save data
+
+    if compress:
+        filename = filename + sgz
+        with gzip.open(filename, 'w') as compressed:
+            for ddict in data_list:
+                jout = json.dumps(ddict) + '\n'
+                jout = jout.encode('utf-8')
+                compressed.write(jout)
+    else:
+        with open(filename, 'w') as out:
+            for ddict in data_list:
+                jout = json.dumps(ddict) + '\n'
+                out.write(jout)
+
 
 def check_correctness(
@@ -52,8 +82,8 @@ def check_correctness(
                     # does not perform destructive actions on their host or network.
                     # Once you have read this disclaimer and taken appropriate precautions,
                     # uncomment the following line and proceed at your own risk:
-                    # exec(sample["test_code"], exec_globals)
-                    result.append("passed")
+                    exec(sample["test_code"], exec_globals)
+                    result.append("passed")
             except TimeoutException:
                 result.append("timed out")
             except AssertionError as e:
@@ -92,7 +122,7 @@ def check_correctness(
             # does not perform destructive actions on their host or network.
             # Once you have read this disclaimer and taken appropriate precautions,
             # uncomment the following line and proceed at your own risk:
-            # exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
+            exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
 
             if exec_result.returncode == 0:
                 result.append("passed")
@@ -137,7 +167,7 @@ def check_correctness(
             # does not perform destructive actions on their host or network.
             # Once you have read this disclaimer and taken appropriate precautions,
             # uncomment the following line and proceed at your own risk:
-            # exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
+            exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
 
             if exec_result.stderr.decode():
                 err = exec_result.stderr.decode()
@@ -190,7 +220,7 @@ def check_correctness(
             # does not perform destructive actions on their host or network.
             # Once you have read this disclaimer and taken appropriate precautions,
             # uncomment the following line and proceed at your own risk:
-            # exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
+            exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
 
             if exec_result.returncode == 0:
                 result.append("passed")
@@ -210,6 +240,71 @@ def check_correctness(
                 result.append("timed out")
 
         shutil.rmtree(tmp_dir)
+    elif "rust" in language_type.lower():
+        import os
+
+        WD: str = os.path.dirname(os.path.abspath(__file__))
+        RUST_DIR: str = os.path.join(WD, "rust")
+        RUST_SRC: str = os.path.join(RUST_DIR, "src")
+        RUST_BIN: str = os.path.join(RUST_SRC, "bin")
+        RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
+        RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
+        RUST_EXT: str = ".rs"
+
+        # Create mandatory tmp directories
+        os.makedirs(RUST_TMP_DIR, exist_ok=True)
+        os.makedirs(RUST_LOGS, exist_ok=True)
+        os.makedirs(RUST_SRC, exist_ok=True)
+        os.makedirs(RUST_BIN, exist_ok=True)
+
+        with tempfile.NamedTemporaryFile(dir=RUST_BIN, delete=False) as f:
+            # Temporary file name derived from the task id
+            file_prefix = sample["task_id"].lower().replace("/", "_")
+            file_name: str = file_prefix + RUST_EXT
+
+            os.rename(f.name, os.path.join(RUST_BIN, file_name))
+
+            # The sample is already a complete Rust program (main + declaration + generation + tests)
+            rust_code: str = sample["test_code"]
+
+            # Dump the Rust source code into the target temporary file
+            f.write(rust_code.encode('utf-8'))
+
+        # Proceed to compile the Rust binary; move to the Rust module root dir.
+        os.chdir(RUST_DIR)
+
+        # Two possible outcomes: the compilation passes or fails.
+        log_filename: str = file_prefix + ".jsonl"
+        log_path: str = os.path.join(RUST_LOGS, log_filename)
+        cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
+
+        # Overwrite any previous log content, then run the compilation check.
+        # The return value is the compilation build status; 0 means success.
+        if os.path.exists(log_path):
+            os.remove(log_path)
+        returned_val_compilation: int = os.system(cargo_check)
+
+        if returned_val_compilation == 0:
+            # Execution pipeline
+            cargo_test: str = "cargo test --bin " + file_prefix + " --message-format json >> " + log_path
+            returned_val_execution = os.system(cargo_test)
+
+            if returned_val_execution == 0:
+                result.append("passed")
+            else:
+                result.append("failed: execution error")
+        else:
+            result.append("failed: compilation error")
+
     elif "java" in language_type.lower():
         assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
 
@@ -264,7 +359,7 @@ def check_correctness(
             result.append(res)
 
         shutil.rmtree(tmp_dir)
-        
+
     manager = multiprocessing.Manager()
     result = manager.list()
diff --git a/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py b/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py
index 54be678..076e0e8 100644
--- a/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py
+++ b/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py
@@ -74,6 +74,10 @@ def process_humaneval_test(sample, problems, example_test=False):
             test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
         else:
             test_string = test_setup + "\n" + prompt + code + "\n" + test
+    elif language == "rust":
+        main = "\nfn main(){ \n } \n"
+        declaration = problems[task_id]["declaration"]
+        test_string = main + declaration + prompt + code + test
 
     return test_string
 
@@ -96,7 +100,7 @@ def evaluate_functional_correctness(
         input_file: str = None,
         tmp_dir: str = "./",
         n_workers: int = 32,
-        timeout: float = 5.0,
+        timeout: float = 500.0,
         problem_file: str = "../data/humaneval_python.jsonl.gz",
         out_dir: str = None,
         k: List[int] = [1, 10, 100],
diff --git a/codegeex/benchmark/humaneval-x/rust/data/humaneval_rust.jsonl.gz b/codegeex/benchmark/humaneval-x/rust/data/humaneval_rust.jsonl.gz
new file mode 100644
index 0000000..35a8186
Binary files /dev/null and b/codegeex/benchmark/humaneval-x/rust/data/humaneval_rust.jsonl.gz differ
diff --git a/codegeex/benchmark/rust/Cargo.lock b/codegeex/benchmark/rust/Cargo.lock
new file mode 100644
index 0000000..e2e0655
--- /dev/null
+++ b/codegeex/benchmark/rust/Cargo.lock
@@ -0,0 +1,121 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "aho-corasick"
+version = "0.7.20"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "fuchsia-cprng"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba"
+
+[[package]]
+name = "libc"
+version = "0.2.139"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
+
+[[package]]
+name = "md5"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
+
+[[package]]
+name = "memchr"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
+
+[[package]]
+name = "rand"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293"
+dependencies = [
+ "fuchsia-cprng",
+ "libc",
+ "rand_core 0.3.1",
+ "rdrand",
+ "winapi",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b"
+dependencies = [
+ "rand_core 0.4.2",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc"
+
+[[package]]
+name = "rdrand"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2"
+dependencies = [
+ "rand_core 0.3.1",
+]
+
+[[package]]
+name = "regex"
+version = "1.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.6.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
+
+[[package]]
+name = "rust"
+version = "0.1.0"
+dependencies = [
+ "md5",
+ "rand",
+ "regex",
+]
+
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
diff --git a/codegeex/benchmark/rust/Cargo.toml b/codegeex/benchmark/rust/Cargo.toml
new file mode 100644
index 0000000..1ed0c3a
--- /dev/null
+++ b/codegeex/benchmark/rust/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "rust"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+rand = "0.4"
+regex = "1"
+md5 = "0.7.0"
+
diff --git a/generations/humaneval_python_generations.jsonl.gz b/generations/humaneval_python_generations.jsonl.gz
new file mode 100644
index 0000000..22f1d4e
Binary files /dev/null and b/generations/humaneval_python_generations.jsonl.gz differ
diff --git a/generations/humaneval_rust_generations.jsonl.gz b/generations/humaneval_rust_generations.jsonl.gz
new file mode 100644
index 0000000..b0d19e4
Binary files /dev/null and b/generations/humaneval_rust_generations.jsonl.gz differ
diff --git a/scripts/evaluate_humaneval_x.py b/scripts/evaluate_humaneval_x.py
new file mode 100644
index 0000000..eb91b9e
--- /dev/null
+++ b/scripts/evaluate_humaneval_x.py
@@ -0,0 +1,71 @@
+import argparse
+import os
+from pathlib import Path
+
+from codegeex.benchmark.evaluate_humaneval_x import evaluate_functional_correctness
+
+# GLOBALS
+INPUT_FILE: str
+LANGUAGE: str
+N_WORKERS: int
+TIMEOUT: int
+
+parser = argparse.ArgumentParser("Debugging evaluate humaneval_x")
+
+# Path to the .jsonl file that contains the generated codes.
+parser.add_argument("-s", "--samples", type=str)
+
+# Target programming language, currently supports one of ["python", "java", "cpp", "js", "go", "rust"].
+parser.add_argument("-l", "--language", default="python", type=str)
+
+# Number of parallel workers.
+parser.add_argument("-w", "--workers", default=64, type=int)
+
+# Timeout in seconds.
+parser.add_argument("-t", "--timeout", default=300, type=int)
+
+args = parser.parse_args()
+
+INPUT_FILE = args.samples
+LANGUAGE = args.language
+N_WORKERS = args.workers
+TIMEOUT = args.timeout
+
+SCRIPT_PATH: Path = Path(os.path.abspath(__file__))
+print(SCRIPT_PATH)
+SCRIPT_DIR: str = os.path.dirname(SCRIPT_PATH)
+print(SCRIPT_DIR)
+MAIN_DIR: str = os.path.dirname(SCRIPT_DIR)
+print(MAIN_DIR)
+
+# Problem file for the chosen language.
+DATA_DIR = os.path.join(MAIN_DIR, "codegeex/benchmark/humaneval-x/" + LANGUAGE + "/data/humaneval_" + LANGUAGE + ".jsonl.gz")
+print(DATA_DIR)
+
+# Note: no leading slash here; os.path.join discards every component before an absolute path.
+TMP_DIR = os.path.join(MAIN_DIR, "codegeex/benchmark/humaneval-x/")
+
+# Debugging overrides, kept commented out so the CLI arguments above stay in effect:
+# INPUT_FILE = '/home/rog0d/Escritorio/CodeGeeX/generations/humaneval_rust_generations.jsonl.gz'
+# LANGUAGE = 'rust'
+# DATA_DIR = os.path.join(MAIN_DIR, "codegeex/benchmark/humaneval-x/" + LANGUAGE + "/data/humaneval_" + LANGUAGE + ".jsonl.gz")
+
+"""
+evaluate_functional_correctness defaults, for reference:
+    input_file: str = None,
+    tmp_dir: str = "./",
+    n_workers: int = 32,
+    timeout: float = 5.0,
+    problem_file: str = "../data/humaneval_python.jsonl.gz",
+    out_dir: str = None,
+    k: List[int] = [1, 10, 100],
+    test_groundtruth: bool = False,
+    example_test: bool = False,
+"""
+
+evaluate_functional_correctness(input_file=INPUT_FILE,
+                                n_workers=N_WORKERS,
+                                tmp_dir=TMP_DIR,
+                                problem_file=DATA_DIR,
+                                timeout=TIMEOUT)
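
Usage note: with this patch applied, the Rust split can be evaluated through the helper script or by calling the evaluator directly. A minimal sketch, assuming the repository root as the working directory and the generations file added above; the paths are illustrative:

    from codegeex.benchmark.evaluate_humaneval_x import evaluate_functional_correctness

    # Paths mirror the files added in this patch; adjust them to your checkout.
    evaluate_functional_correctness(
        input_file="generations/humaneval_rust_generations.jsonl.gz",
        problem_file="codegeex/benchmark/humaneval-x/rust/data/humaneval_rust.jsonl.gz",
        tmp_dir="codegeex/benchmark/humaneval-x/",
        n_workers=32,
        timeout=300.0,
    )

Equivalently: python scripts/evaluate_humaneval_x.py -s generations/humaneval_rust_generations.jsonl.gz -l rust -w 32 -t 300. Results land next to the input file as humaneval_rust_generations_results.jsonl.gz, per the suffix logic in evaluate_functional_correctness.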