mirror of https://github.com/menyifang/DCT-Net
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
3.9 KiB
Python
121 lines
3.9 KiB
Python
3 years ago
|
import os
|
||
|
import cv2
|
||
|
import tensorflow as tf
|
||
|
import numpy as np
|
||
|
from source.facelib.facer import FaceAna
|
||
|
import source.utils as utils
|
||
|
from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points
|
||
|
|
||
|
# Under TensorFlow 2.x and later, switch to the v1 compatibility API (graph
# mode) used by the Session-based code in this file.
# The major version is compared numerically: a plain string comparison such as
# `tf.__version__ >= '2.0'` is lexicographic and would be False for '10.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
    tf.disable_eager_execution()
|
||
|
|
||
|
|
||
|
class Cartoonizer():
    """Cartoonize an image with two frozen TensorFlow graphs.

    One graph ('model_bg') stylizes the whole frame; the other ('model_head')
    stylizes an aligned crop of each detected face. Every stylized face is
    warped back into the frame and alpha-blended over the background result
    using a fixed mask loaded from ``alpha.jpg``.
    """

    def __init__(self, dataroot):
        """Load the face analyzer, both model sessions and the blend mask.

        Args:
            dataroot: Directory containing 'cartoon_anime_h.pb',
                'cartoon_anime_bg.pb' and 'alpha.jpg'; it is also passed to
                ``FaceAna``.

        Raises:
            OSError: If 'alpha.jpg' cannot be read by OpenCV.
        """
        self.facer = FaceAna(dataroot)
        self.sess_head = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_h.pb'), 'model_head')
        self.sess_bg = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_bg.pb'), 'model_bg')

        # Side length of the square face crop; the blend mask is resized to
        # the same size so both share one inverse transform.
        self.box_width = 288

        mask_path = os.path.join(dataroot, 'alpha.jpg')
        global_mask = cv2.imread(mask_path)
        if global_mask is None:
            # cv2.imread signals failure by returning None rather than
            # raising; fail here with a message that names the file.
            raise OSError(f'failed to read blend mask image: {mask_path}')
        global_mask = cv2.resize(
            global_mask, (self.box_width, self.box_width),
            interpolation=cv2.INTER_AREA)
        # Single-channel float mask scaled to [0, 1], used as blend weights.
        self.global_mask = cv2.cvtColor(
            global_mask, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    def load_sess(self, model_path, name):
        """Create a session and import a serialized GraphDef into its graph.

        Args:
            model_path: Path to the serialized ``GraphDef`` (.pb) file.
            name: Prefix given to the imported nodes; tensors are later
                looked up as '<name>/...'.

        Returns:
            The ``tf.Session`` whose graph contains the imported nodes.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        print(f'loading model from {model_path}')

        graph_def = tf.GraphDef()
        # GFile replaces the deprecated FastGFile alias.
        with tf.gfile.GFile(model_path, 'rb') as f:
            graph_def.ParseFromString(f.read())

        # as_default() only takes effect as a context manager; import the
        # nodes and build the initializer inside the session's own graph.
        with sess.graph.as_default():
            tf.import_graph_def(graph_def, name=name)
            sess.run(tf.global_variables_initializer())

        print(f'load model {model_path} done.')
        return sess

    def detect_face(self, img):
        """Run the face analyzer on an image.

        Args:
            img: Image array that is converted with ``cv2.COLOR_BGR2RGB``
                before being handed to the analyzer.

        Returns:
            The landmarks returned by the analyzer, or ``None`` when it
            reports zero boxes.
        """
        src_x = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        boxes, landmarks, _ = self.facer.run(src_x)
        if boxes.shape[0] == 0:
            return None
        return landmarks

    def cartoonize(self, img):
        """Cartoonize an image and return it at the input's height and width.

        Args:
            img: RGB input image as an (H, W, C) array.

        Returns:
            The stylized image resized to the original (H, W). This holds on
            both paths: when no face is detected, the background-only result
            is resized as well.
        """
        ori_h, ori_w, _ = img.shape
        img = utils.resize_size(img, size=720)

        # Channel-reversed view of the RGB input.
        img_brg = img[:, :, ::-1]

        # background process
        pad_bg, pad_h, pad_w = utils.padTo16x(img_brg)
        bg_res = self.sess_bg.run(
            self.sess_bg.graph.get_tensor_by_name(
                'model_bg/output_image:0'),
            feed_dict={'model_bg/input_image:0': pad_bg})
        # Crop the model output to the extent reported by padTo16x
        # (presumably the pre-padding size -- confirm against utils).
        res = bg_res[:pad_h, :pad_w, :]

        landmarks = self.detect_face(img_brg)
        if landmarks is None:
            print('No face detected!')
            # Resize here too so the output size does not depend on whether
            # a face was found.
            return cv2.resize(
                res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)

        print('%d faces detected!' % len(landmarks))

        # Loop-invariant values, computed once for all faces.
        frame_size = (np.size(img, 1), np.size(img, 0))  # (width, height)
        head_output = self.sess_head.graph.get_tensor_by_name(
            'model_head/output_image:0')
        mask = self.global_mask

        for landmark in landmarks:
            # get facial 5 points
            f5p = utils.get_f5p(landmark, img_brg)

            # face alignment
            head_img, trans_inv = warp_and_crop_face(
                img,
                f5p,
                ratio=0.75,
                reference_pts=get_reference_facial_points(default_square=True),
                crop_size=(self.box_width, self.box_width),
                return_trans_inv=True)

            # head process (the crop is channel-reversed, as for the
            # background model)
            head_res = self.sess_head.run(
                head_output,
                feed_dict={
                    'model_head/input_image:0': head_img[:, :, ::-1]
                })

            # merge head and background: warp the stylized crop and the mask
            # back into frame coordinates with the same inverse transform
            head_trans_inv = cv2.warpAffine(
                head_res,
                trans_inv, frame_size,
                borderValue=(0, 0, 0))
            mask_trans_inv = cv2.warpAffine(
                mask,
                trans_inv, frame_size,
                borderValue=(0, 0, 0))
            # Add a channel axis so the mask broadcasts over colour channels.
            mask_trans_inv = np.expand_dims(mask_trans_inv, 2)

            res = mask_trans_inv * head_trans_inv + (1 - mask_trans_inv) * res

        res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)

        return res
|
||
|
|
||
|
|
||
|
|
||
|
|