From 147e4f2bb7156d77ecf301a13b63d997ea2ecf10 Mon Sep 17 00:00:00 2001 From: menyifang Date: Wed, 29 Mar 2023 19:10:36 +0800 Subject: [PATCH] update --- export.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 export.py diff --git a/export.py b/export.py new file mode 100644 index 0000000..9ff2ccf --- /dev/null +++ b/export.py @@ -0,0 +1,62 @@ +import os +import unittest + +import cv2 + +from modelscope.exporters.cv import CartoonTranslationExporter +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.trainers.cv import CartoonTranslationTrainer +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TestImagePortraitStylizationTrainer(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.image_portrait_stylization + self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + model_id = 'damo/cv_unet_person-image-cartoon_compound-models' + + data_dir = MsDataset.load( + 'dctnet_train_clipart_mini_ms', + namespace='menyifang', + split='train').config_kwargs['split_config']['train'] + + data_photo = os.path.join(data_dir, 'face_photo') + data_cartoon = os.path.join(data_dir, 'face_cartoon') + work_dir = 'exp_localtoon' + max_steps = 10 + trainer = CartoonTranslationTrainer( + model=model_id, + work_dir=work_dir, + photo=data_photo, + cartoon=data_cartoon, + max_steps=max_steps) + trainer.train() + + # export pb file + ckpt_path = os.path.join(work_dir, 'saved_models', 'model-' + str(0)) + pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb') + exporter = CartoonTranslationExporter() + exporter.export_frozen_graph_def( + ckpt_path=ckpt_path, frozen_graph_path=pb_path) + + # infer with pb file + self.pipeline_person_image_cartoon(trainer.model_dir) + + def pipeline_person_image_cartoon(self, model_dir): + pipeline_cartoon = pipeline(task=self.task, model=model_dir) + result = pipeline_cartoon(input=self.test_image) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {os.path.abspath("result.png")}') + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file