mirror of https://github.com/THUDM/CodeGeeX.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
287 lines
12 KiB
Python
287 lines
12 KiB
Python
import os
|
|
import sys
|
|
import fire
|
|
import json
|
|
import glob
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from collections import defaultdict
|
|
from codegeex.benchmark.metric import estimate_pass_at_k
|
|
|
|
# Per-language list of outcome categories. The first entry is always
# "accepted"; the remaining entries are the failure buckets a sample can be
# counted under. List order is significant: it is used as the column order of
# the statistics table built from these lists.
error_types = {
    "python": [
        "accepted",
        "assertion error",
        "undefined error",
        "runtime error",
        "syntax error",
        "timeout error",
        "type error",
    ],
    "java": [
        "accepted",
        "compilation error",
        "assertion error",
        "timeout error",
        "index error",
        "class cast error",
        "stack overflow",
        "null pointer error",
        "unsupported operation",
        "number format error",
        "no such element",
        "illegal argument",
        "out of memory",
        "arithmetic error",
        "others",
    ],
    "cpp": [
        "accepted",
        "compilation error",
        "assertion error",
        "range error",
        "invalid argument",
        "pointer error",
        "out of memory",
        "package error",
        "others",
    ],
    "javascript": [
        "accepted",
        "assertion error",
        "undefined error",
        "runtime error",
        "syntax error",
        "timeout error",
        "range error",
        "type error",
    ],
    "go": [
        "accepted",
        "assertion error",
        "undefined error",
        "runtime error",
        "syntax error",
        "timeout error",
        "type error",
        "notused error",
    ],
}
|
|
|
|
|
|
# Task-id prefixes (already lower-cased) that do not match an ``error_types``
# key, mapped to the key they should be counted under.
_LANGUAGE_ALIASES = {"humaneval": "python", "c++": "cpp"}

# Ordered (category, needles) rules. The first rule with ANY needle present in
# the message wins, so the order of the rules is significant.
_PYTHON_RULES = (  # matched against the lower-cased message
    ("assertion error", ("assertionerror",)),
    ("syntax error", ("syntax", "indent", "literal")),
    ("undefined error", ("not defined",)),
    ("timeout error", ("timeout",)),
    ("type error", ("type",)),
)
_JAVA_RULES = (  # matched case-sensitively
    ("assertion error", ("wrong answer",)),
    ("compilation error", ("compilation error",)),
    ("timeout error", ("time out",)),
    ("index error", ("IndexOutOfBounds",)),
    ("unsupported operation", ("UnsupportedOperation",)),
    ("class cast error", ("ClassCast",)),
    ("null pointer error", ("NullPointer",)),
    ("number format error", ("NumberFormat",)),
    ("no such element", ("NoSuchElement",)),
    ("stack overflow", ("StackOverflow",)),
    ("arithmetic error", ("Arithmetic",)),
    ("out of memory", ("OutOfMemory",)),
    ("illegal argument", ("IllegalArgument",)),
)
_CPP_RULES = (  # matched against the lower-cased message
    ("compilation error", ("compilation error",)),
    ("assertion error", ("int main(): assertion",)),
    ("range error", ("out_of_range", "corrupted top size", "length_error")),
    ("invalid argument", ("invalid_argument",)),
    (
        "pointer error",
        (
            "invalid pointer",
            "double free",
            "free()",
            "logic_error",
            "sysmalloc: assertion",
        ),
    ),
    ("out of memory", ("stack smashing", "bad_alloc")),
    ("package error", ("terminate called after throwing an instance of",)),
)
_JAVASCRIPT_RULES = (  # matched case-sensitively
    ("assertion error", ("Assertion failed",)),
    ("syntax error", ("SyntaxError",)),
    ("undefined error", ("ReferenceError",)),
    ("timeout error", ("timed out",)),
    ("type error", ("TypeError",)),
    ("range error", ("RangeError",)),
)
# Go rules that follow the assertion / undefined / syntax checks.
_GO_TAIL_RULES = (  # matched case-sensitively
    ("runtime error", ("FAIL",)),
    ("timeout error", ("timed out",)),
    ("notused error", ("not used",)),
    ("type error", ("type",)),
)


def _first_match(text, rules, default):
    """Return the category of the first rule with a needle in *text*, else *default*."""
    for category, needles in rules:
        if any(needle in text for needle in needles):
            return category
    return default


def _classify_go_failure(error):
    """Return the Go error category for *error*, or None when no pattern matches."""
    # NOTE(review): the run of whitespace inside this literal may have been
    # collapsed in this copy of the file -- confirm it against the actual
    # output of the Go test harness.
    if "Error: \tNot equal:" in error:
        return "assertion error"
    if "undefined" in error:
        return "undefined error"
    if (
        ("expected" in error and "found" in error)
        or "illegal" in error
        or "unexpected" in error
    ):
        return "syntax error"
    return _first_match(error, _GO_TAIL_RULES, None)


# language -> callable mapping a failure message to an error category.
_FAILURE_CLASSIFIERS = {
    "python": lambda error: _first_match(
        error.lower(), _PYTHON_RULES, "runtime error"
    ),
    "java": lambda error: _first_match(error, _JAVA_RULES, "others"),
    "cpp": lambda error: _first_match(error.lower(), _CPP_RULES, "others"),
    "javascript": lambda error: _first_match(
        error, _JAVASCRIPT_RULES, "runtime error"
    ),
    "go": _classify_go_failure,
}


def inspect_result(
    input_dir: str = None,
    input_file: str = None,
    output_dir: str = None,
    pass_at_k_outpath: str = None,
):
    """Summarise evaluation result files into per-task error statistics.

    Each input file is read as JSON lines. Every record must provide
    ``task_id`` (``"<language>/<integer index>"``), ``result`` (a string;
    records whose result contains ``"failed"`` are bucketed into an error
    category by substring matching) and, for non-failed records, ``passed``.

    For every file, the per-task counters are written to
    ``<output_dir>/<file stem>_stats.xlsx`` and pass@k is computed with
    ``estimate_pass_at_k`` for each k in (1, 10, 100, 1000) that every task
    has at least k samples for. A file that cannot be processed is reported
    on stdout and skipped; it does not stop the remaining files.

    Args:
        input_dir: Directory scanned for ``*_results.jsonl`` files. Takes
            precedence over ``input_file``.
        input_file: Single result file, used when ``input_dir`` is None.
        output_dir: Directory for generated files; created if missing.
            Defaults to the current working directory.
        pass_at_k_outpath: If given, file name (joined onto ``output_dir``)
            that receives one JSON line of pass@k values per processed file.

    Raises:
        ValueError: If neither ``input_dir`` nor ``input_file`` is given.
    """
    if input_dir is not None:
        # Sorted so the order of the pass@k output lines is reproducible.
        input_files = sorted(glob.glob(input_dir + "/*_results.jsonl"))
    elif input_file is not None:
        input_files = [input_file]
    else:
        raise ValueError("Either input_dir or input_file must be provided.")

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
    else:
        # The previous default was "/", i.e. the filesystem root.
        output_dir = "."

    pass_at_k_outs = []
    for path in input_files:
        result_stats = {}
        df = None
        incompleted = False
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                obj = json.loads(line)
                task_id = obj["task_id"]
                language_type = task_id.split("/")[0].lower()
                if language_type not in error_types:
                    # The prefix is lower-cased above, so aliases must be too
                    # ("C++" could never match).
                    language_type = _LANGUAGE_ALIASES.get(language_type)
                    if language_type is None:
                        incompleted = True
                        break

                stats = result_stats.get(task_id)
                if stats is None:
                    stats = {"task_id": task_id, "n_sample": 0}
                    stats.update(dict.fromkeys(error_types[language_type], 0))
                    result_stats[task_id] = stats
                    if df is None:
                        df = pd.DataFrame(columns=error_types[language_type])
                stats["n_sample"] += 1

                if "failed" in obj["result"]:
                    classify = _FAILURE_CLASSIFIERS.get(language_type)
                    if classify is None:
                        incompleted = True
                        break
                    category = classify(obj["result"])
                    # Only the Go classifier can return None; such a sample
                    # is counted in n_sample but in no error category.
                    if category is not None:
                        stats[category] += 1
                elif obj["passed"]:
                    stats["accepted"] += 1

        if incompleted:
            print(f"Language not supported, aborted. {path}")
            continue

        # Deliberately broad: any failure while summarising one file is
        # reported and the remaining files are still processed.
        try:
            if not result_stats:
                raise ValueError("no samples found")

            total, correct = [], []
            for task_id, stats in result_stats.items():
                total.append(stats["n_sample"])
                correct.append(stats["accepted"])
                # One row per task, indexed by the task's numeric suffix.
                df_res = pd.DataFrame(stats, index=[int(task_id.split("/")[-1])])
                df = pd.concat([df, df_res], axis=0)

            total = np.array(total)
            correct = np.array(correct)

            ks = [1, 10, 100, 1000]
            pass_at_k = {
                f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
                for k in ks
                if (total >= k).all()
            }

            print(pass_at_k)
            pass_at_k["file"] = path
            # Sample count of the last task read; presumably every task has
            # the same number of samples -- this is not checked here.
            pass_at_k["n"] = stats["n_sample"]
            pass_at_k_outs.append(pass_at_k)

            output_prefix = os.path.basename(path).split(".jsonl")[0]
            output_file = os.path.join(output_dir, output_prefix + "_stats.xlsx")
            df = df.sort_index(ascending=True)
            df.to_excel(output_file)

            print(f"Stats saved in {output_file}")
        except Exception as e:
            print(e)
            print(f"Data incompleted, aborted. {path}")

    if pass_at_k_outpath is not None:
        jsonl_path = os.path.join(output_dir, pass_at_k_outpath)
        with open(jsonl_path, "w") as f_out:
            for p in pass_at_k_outs:
                f_out.write(json.dumps(p) + "\n")
        print(f"Pass at k saved in {jsonl_path}")
|
|
|
|
|
|
def main():
    """Command-line entry point: expose ``inspect_result`` through Fire."""
    # Fire builds the CLI from inspect_result's signature; its return value
    # is discarded, so this function returns None.
    fire.Fire(component=inspect_result)
|
|
|
|
|
|
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|