diff --git a/tests/test_inference.py b/tests/test_inference.py index 95b6ef5..be6ed40 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -137,6 +137,7 @@ def main(): if args.quantize: model = quantize(model, weight_bit_width=8, backend="torch") model.cuda() + torch.cuda.synchronize() with open(args.prompt_file, "r") as f: prompt = f.readlines()