You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
"""
|
|
python inference_speed_test.py \
|
|
--model-variant mobilenetv3 \
|
|
--resolution 1920 1080 \
|
|
--downsample-ratio 0.25 \
|
|
--precision float32
|
|
"""
|
|
|
|
import argparse
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from model.model import MattingNetwork
|
|
|
|
# Ask cuDNN to benchmark candidate convolution algorithms and keep the fastest.
# This suits this script because the input frame fed to the model has a fixed
# shape for the whole run.
torch.backends.cudnn.benchmark = True
|
|
|
|
class InferenceSpeedTest:
    """Command-line benchmark of MattingNetwork inference throughput.

    Constructing an instance does all the work:
    - parses the command-line options;
    - builds a frozen TorchScript model on CUDA;
    - runs a timed inference loop whose progress (iterations per second) is
      reported by ``tqdm``.
    """

    def __init__(self, argv=None):
        """Parse options, build the model and run the benchmark loop.

        Args:
            argv: Optional list of command-line arguments (without the program
                name). ``None`` makes ``argparse`` read ``sys.argv``, as before.
        """
        self.parse_args(argv)
        self.init_model()
        self.loop()

    def parse_args(self, argv=None):
        """Parse benchmark options into ``self.args``.

        Args:
            argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

        Raises:
            SystemExit: raised by ``argparse`` when a required option is
                missing or a value is invalid (e.g. an unsupported
                ``--precision``).
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-variant', type=str, required=True)
        # Two integers, unpacked as WIDTH HEIGHT by ``loop``.
        parser.add_argument('--resolution', type=int, required=True, nargs=2)
        parser.add_argument('--downsample-ratio', type=float, required=True)
        # Restrict to the dtypes ``init_model`` knows how to map, so a typo
        # produces a clear argparse error instead of a KeyError traceback.
        parser.add_argument('--precision', type=str, default='float32',
                            choices=['float32', 'float16'])
        # Accepted for command-line compatibility only; nothing here reads it.
        parser.add_argument('--disable-refiner', action='store_true')
        # Number of timed forward passes (formerly hard-coded to 1000).
        parser.add_argument('--iterations', type=int, default=1000)
        self.args = parser.parse_args(argv)

    def init_model(self):
        """Build the network on CUDA in the requested precision and compile it.

        The model is moved to the device and dtype, switched to eval mode,
        compiled with ``torch.jit.script`` and then ``torch.jit.freeze``.
        Freezing requires an eval-mode ScriptModule.
        """
        self.device = 'cuda'
        dtypes = {'float32': torch.float32, 'float16': torch.float16}
        self.precision = dtypes[self.args.precision]
        model = MattingNetwork(self.args.model_variant)
        model = model.to(device=self.device, dtype=self.precision).eval()
        model = torch.jit.script(model)
        self.model = torch.jit.freeze(model)

    def loop(self):
        """Run the timed forward passes on one fixed random input frame.

        All outputs after the first two are fed back as the recurrent inputs
        of the next call. ``torch.cuda.synchronize()`` runs after every call
        because CUDA kernels launch asynchronously; without it, the per-step
        wall time seen by ``tqdm`` would not include the GPU work.
        """
        w, h = self.args.resolution
        src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision)
        # Fall back to the historical count if ``self.args`` was populated
        # without ``--iterations`` (e.g. by an overriding ``parse_args``).
        iterations = getattr(self.args, 'iterations', 1000)
        with torch.no_grad():
            # No recurrent state exists before the first call.
            rec = None, None, None, None
            for _ in tqdm(range(iterations)):
                # The first two outputs (presumably foreground and alpha) are
                # not needed for timing; the remainder is the next state.
                _fgr, _pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
                torch.cuda.synchronize()
|
|
|
|
if __name__ == '__main__':
    # Script entry point: the constructor parses the CLI, builds the model and
    # runs the timed loop, so creating the object is all that is needed.
    benchmark = InferenceSpeedTest()