mirror of https://github.com/THUDM/CodeGeeX.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import os
|
|
import sys
|
|
import fire
|
|
import glob
|
|
|
|
|
|
def _list_outputs(output_dir, output_prefix):
    """Return the sorted paths in *output_dir* whose names start with *output_prefix*."""
    pattern = "*" if output_prefix is None else output_prefix + "*"
    # output_dir / output_prefix are passed to glob unescaped on purpose:
    # callers may already rely on supplying wildcard patterns here.
    return sorted(glob.glob(os.path.join(output_dir, pattern)))


def _rank_order(path):
    """Sort key that puts ``*_rank<N>.jsonl`` files in numeric rank order."""
    rank = os.path.basename(path).rpartition("_rank")[2].split(".", 1)[0]
    if rank.isdecimal():
        return (0, int(rank), path)
    # Anything without a numeric rank goes last, ordered by name.
    return (1, 0, path)


def _append_file(src_path, dst):
    """Copy *src_path* into the open binary file *dst*.

    A newline is appended when the source does not end with one, so the last
    record of one file is never fused with the first record of the next.
    """
    last = b"\n"
    with open(src_path, "rb") as src:
        for line in src:
            dst.write(line)
            last = line
    if not last.endswith(b"\n"):
        dst.write(b"\n")


def gather_output(
    output_dir: str = "./output",
    output_prefix: str = None,
    if_remove_rank_files: int = 0,
):
    """Merge per-rank JSONL shards and then finished/unfinished results.

    Stage 1: for every ``<prefix>_rank0.jsonl`` found, concatenate all
    ``<prefix>_rank*.jsonl`` files (in numeric rank order) into
    ``<prefix>.jsonl``.

    Stage 2: for every ``<prefix>_finished.jsonl`` found (including files
    produced by stage 1), concatenate it with ``<prefix>_unfinished.jsonl``
    (when that file exists) into ``<prefix>_all.jsonl``.

    Files are matched by the suffix of their base name, so directory names or
    prefixes that happen to contain "rank" or "all" do not affect the result.
    Data is copied as bytes, independent of the platform's default encoding.

    Args:
        output_dir: Directory that holds the output files.
        output_prefix: If given, only files whose names start with this
            prefix are considered; otherwise every file in ``output_dir``.
        if_remove_rank_files: When truthy, delete the per-rank shards after
            their merged file has been written completely.
    """
    rank0_suffix = "_rank0.jsonl"
    finished_suffix = "_finished.jsonl"

    # Stage 1: gather the shards written by each rank.
    for output_file in _list_outputs(output_dir, output_prefix):
        if not os.path.basename(output_file).endswith(rank0_suffix):
            continue
        shard_prefix = output_file[: -len(rank0_suffix)]
        # shard_prefix is a real path, so escape any glob metacharacters in it.
        rank_files = sorted(
            glob.glob(glob.escape(shard_prefix) + "_rank*.jsonl"),
            key=_rank_order,
        )
        with open(shard_prefix + ".jsonl", "wb") as f_out:
            for rank_file in rank_files:
                _append_file(rank_file, f_out)
        # Delete shards only after the merged file is closed, so a failure
        # while merging never loses data.
        if if_remove_rank_files:
            for rank_file in rank_files:
                os.remove(rank_file)
                print(f"Removing {rank_file}...")

    # Stage 2: list again so files created by stage 1 are picked up.
    for output_file in _list_outputs(output_dir, output_prefix):
        if not os.path.basename(output_file).endswith(finished_suffix):
            continue
        merged_prefix = output_file[: -len(finished_suffix)]
        unfinished_file = merged_prefix + "_unfinished.jsonl"
        all_file = merged_prefix + "_all.jsonl"
        with open(all_file, "wb") as f_out:
            _append_file(output_file, f_out)
            # A run in which everything finished has no "_unfinished" file.
            if os.path.isfile(unfinished_file):
                _append_file(unfinished_file, f_out)

        print("Gathering finished. Saved in {}".format(all_file))
|
|
|
|
|
|
def main():
    """Expose ``gather_output`` as a command-line interface through ``fire.Fire``."""
    fire.Fire(component=gather_output)
|
|
|
|
|
|
# Script entry point: whatever main() returns is handed to sys.exit().
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
|