mirror of https://github.com/menyifang/DCT-Net
update
parent
0f3af40868
commit
147e4f2bb7
@ -0,0 +1,62 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
|
||||
from modelscope.exporters.cv import CartoonTranslationExporter
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.trainers.cv import CartoonTranslationTrainer
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class TestImagePortraitStylizationTrainer(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.image_portrait_stylization
|
||||
self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_name(self):
|
||||
model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
|
||||
|
||||
data_dir = MsDataset.load(
|
||||
'dctnet_train_clipart_mini_ms',
|
||||
namespace='menyifang',
|
||||
split='train').config_kwargs['split_config']['train']
|
||||
|
||||
data_photo = os.path.join(data_dir, 'face_photo')
|
||||
data_cartoon = os.path.join(data_dir, 'face_cartoon')
|
||||
work_dir = 'exp_localtoon'
|
||||
max_steps = 10
|
||||
trainer = CartoonTranslationTrainer(
|
||||
model=model_id,
|
||||
work_dir=work_dir,
|
||||
photo=data_photo,
|
||||
cartoon=data_cartoon,
|
||||
max_steps=max_steps)
|
||||
trainer.train()
|
||||
|
||||
# export pb file
|
||||
ckpt_path = os.path.join(work_dir, 'saved_models', 'model-' + str(0))
|
||||
pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb')
|
||||
exporter = CartoonTranslationExporter()
|
||||
exporter.export_frozen_graph_def(
|
||||
ckpt_path=ckpt_path, frozen_graph_path=pb_path)
|
||||
|
||||
# infer with pb file
|
||||
self.pipeline_person_image_cartoon(trainer.model_dir)
|
||||
|
||||
def pipeline_person_image_cartoon(self, model_dir):
|
||||
pipeline_cartoon = pipeline(task=self.task, model=model_dir)
|
||||
result = pipeline_cartoon(input=self.test_image)
|
||||
if result is not None:
|
||||
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
|
||||
print(f'Output written to {os.path.abspath("result.png")}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue