You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

57 lines
1.8 KiB
Python

import os
import sys
import fire
import glob
def _list_outputs(output_dir, output_prefix):
    """Return the sorted paths in *output_dir* whose names start with *output_prefix*."""
    if output_prefix is None:
        pattern = os.path.join(output_dir, "*")
    else:
        pattern = os.path.join(output_dir, output_prefix + "*")
    # glob returns entries in arbitrary order; sort for reproducible runs.
    return sorted(glob.glob(pattern))


def _rank_sort_key(path):
    """Order rank shards numerically so that ``rank10`` comes after ``rank2``."""
    stem = os.path.basename(path)[: -len(".jsonl")]
    rank = stem.rsplit("_rank", 1)[-1]
    if rank.isdecimal():
        return (0, int(rank), path)
    # Shards with a non-numeric rank tag go last, ordered by path.
    return (1, 0, path)


def _concat_jsonl(sources, destination):
    """Concatenate *sources* into *destination*, keeping one record per line."""
    # Binary mode copies the records byte-for-byte, independent of the
    # platform's default text encoding and newline translation.
    with open(destination, "wb") as f_out:
        for source in sources:
            last_line = b""
            with open(source, "rb") as f_in:
                for line in f_in:
                    f_out.write(line)
                    last_line = line
            # A source without a trailing newline would otherwise glue its
            # last record onto the first record of the next source.
            if last_line and not last_line.endswith(b"\n"):
                f_out.write(b"\n")


def gather_output(
    output_dir: str = "./output",
    output_prefix: str = None,
    if_remove_rank_files: int = 0,
):
    """Merge per-rank JSONL shards and finished/unfinished outputs.

    The work happens in two stages over the files in ``output_dir``:

    1. For every ``<prefix>_rank0.jsonl`` file, all ``<prefix>_rank*.jsonl``
       shards are concatenated in numeric rank order into ``<prefix>.jsonl``.
    2. For every ``<prefix>_finished.jsonl`` file, it is concatenated with
       ``<prefix>_unfinished.jsonl`` (when that file exists) into
       ``<prefix>_all.jsonl``.

    File roles are decided from the file name suffix only, so directory or
    run names that happen to contain "rank" or "all" do not affect matching.

    Args:
        output_dir: Directory that holds the output files.
        output_prefix: If given, only files whose name starts with this
            prefix are considered; otherwise every file in ``output_dir``.
        if_remove_rank_files: When truthy, delete the rank shards after they
            have been merged.

    Returns:
        None. Merged files are written next to their inputs.
    """
    rank0_suffix = "_rank0.jsonl"
    finished_suffix = "_finished.jsonl"

    # Stage 1: collapse the per-rank shards of each output into one file.
    for output_file in _list_outputs(output_dir, output_prefix):
        if not os.path.basename(output_file).endswith(rank0_suffix):
            continue
        shard_prefix = output_file[: -len(rank0_suffix)]
        # Escape the derived prefix: it is a literal path, not a pattern.
        rank_files = sorted(
            glob.glob(glob.escape(shard_prefix) + "_rank*.jsonl"),
            key=_rank_sort_key,
        )
        _concat_jsonl(rank_files, shard_prefix + ".jsonl")
        # Delete shards only once the merged file is completely written, so
        # a failure part-way through the merge never loses data.
        if if_remove_rank_files:
            for rank_file in rank_files:
                os.remove(rank_file)
                print(f"Removing {rank_file}...")

    # Stage 2: list again, because stage 1 may have just created the
    # "<prefix>_finished.jsonl" / "<prefix>_unfinished.jsonl" files.
    for output_file in _list_outputs(output_dir, output_prefix):
        if not os.path.basename(output_file).endswith(finished_suffix):
            continue
        run_prefix = output_file[: -len(finished_suffix)]
        candidates = [output_file, run_prefix + "_unfinished.jsonl"]
        # There is no unfinished file when nothing was left unfinished.
        files = [path for path in candidates if os.path.isfile(path)]
        merged_path = run_prefix + "_all.jsonl"
        _concat_jsonl(files, merged_path)
        print("Gathering finished. Saved in {}".format(merged_path))
def main():
    """Expose ``gather_output`` as a command-line interface through ``fire``."""
    fire.Fire(component=gather_output)
if __name__ == "__main__":
    # Hand main()'s return value to sys.exit as the exit status.
    exit_status = main()
    sys.exit(exit_status)