mirror of https://github.com/menyifang/DCT-Net
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
5.0 KiB
Python
155 lines
5.0 KiB
Python
import cv2
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from .config import config as cfg
|
|
|
|
# The code below uses the TensorFlow 1.x graph/session API (tf.Graph,
# tf.Session, tf.ConfigProto, ...). On TensorFlow 2.x rebind ``tf`` to the
# 1.x compatibility namespace. The major version is compared as an integer
# because comparing version strings is lexicographic ('10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
|
|
|
|
|
|
class FaceLandmark:
    """Facial keypoint predictor backed by a TensorFlow 1.x style graph.

    The model is loaded from ``<dir>/keypoints.pb``. For every face box the
    image is padded, a square region around the box centre is cropped and
    resized to ``cfg.KEYPOINTS.input_shape``, the network is run on it, and
    the predicted keypoints are mapped back into the original image.
    """

    def __init__(self, dir):
        """Load the frozen model and look up its input/output tensors.

        Args:
            dir: Directory that contains ``keypoints.pb``.
        """
        self.model_path = dir + '/keypoints.pb'
        # Boxes whose width AND height are both <= this value are skipped.
        self.min_face = 60
        # Two prediction values (x and y) per keypoint.
        self.keypoint_num = cfg.KEYPOINTS.p_num * 2

        self._graph, self._sess = self.init_model(self.model_path)

        # Build the derived (sliced) tensors inside the model's own graph so
        # they can be fetched through ``self._sess``.
        with self._graph.as_default():
            self.img_input = self._graph.get_tensor_by_name(
                'tower_0/images:0')
            self.embeddings = self._graph.get_tensor_by_name(
                'tower_0/prediction:0')
            self.training = self._graph.get_tensor_by_name(
                'training_flag:0')

            # Layout of the prediction vector as used here: the first
            # ``keypoint_num`` values are keypoint coordinates, the last
            # seven values are split into 3 (head pose) + 4 (state).
            self.landmark = self.embeddings[:, :self.keypoint_num]
            self.headpose = self.embeddings[:, -7:-4] * 90.
            self.state = tf.nn.sigmoid(self.embeddings[:, -4:])

    def __call__(self, img, bboxes):
        """Predict keypoints for every sufficiently large box.

        Args:
            img: Image array passed on to ``_one_shot_run``.
            bboxes: Iterable of boxes indexed as ``[x0, y0, x1, y1, ...]``.
                The boxes are not modified.

        Returns:
            Tuple ``(landmarks, states)`` of numpy arrays with one entry per
            box that was not skipped; both are empty when nothing was kept.
        """
        landmark_result = []
        state_result = []
        for i, bbox in enumerate(bboxes):
            landmark, state = self._one_shot_run(img, bbox, i)
            if landmark is not None:
                landmark_result.append(landmark)
                state_result.append(state)
        return np.array(landmark_result), np.array(state_result)

    def simple_run(self, cropped_img):
        """Run the network on a single, already cropped and resized image.

        Args:
            cropped_img: Image array without a batch axis; a leading batch
                axis of size one is added before it is fed to the model.

        Returns:
            Tuple ``(landmark, states)`` with the evaluated ``self.landmark``
            and ``self.state`` tensors (each with a leading batch axis).
        """
        with self._graph.as_default():
            batch = np.expand_dims(cropped_img, axis=0)
            landmark, states = self._sess.run(
                [self.landmark, self.state],
                feed_dict={
                    self.img_input: batch,
                    self.training: False
                })

        return landmark, states

    @staticmethod
    def _square_crop_box(bbox, add, extend):
        """Return the integer square crop box for ``bbox`` in the padded image.

        The box is shifted by ``add`` (the border width added around the
        image) and replaced by a square with edge
        ``(1 + 2 * extend) * box_width`` centred on the shifted box.

        Args:
            bbox: Box indexed as ``[x0, y0, x1, y1, ...]``. It is copied, so
                the caller's data is never modified.
            add: Border width that was added on every side of the image.
            extend: Relative amount by which the box width is extended on
                each side.

        Returns:
            Integer numpy array with the same length as ``bbox``.
        """
        box = np.array(bbox)  # copy (dtype preserved) to protect the caller
        width = box[2] - box[0]
        box += add

        one_edge = (1 + 2 * extend) * width
        half_edge = one_edge // 2
        center_x = (box[0] + box[2]) // 2
        center_y = (box[1] + box[3]) // 2

        box[0] = center_x - half_edge
        box[1] = center_y - half_edge
        box[2] = center_x + half_edge
        box[3] = center_y + half_edge

        # ``np.int`` was only an alias of the builtin ``int`` and has been
        # removed from NumPy, so the builtin is used directly.
        return box.astype(int)

    def _one_shot_run(self, image, bbox, i):
        """Predict the keypoints of one face box.

        Args:
            image: Image array indexed as ``[row, column, channel]``.
            bbox: Box indexed as ``[x0, y0, x1, y1, ...]``; not modified.
            i: Index of the box in the caller's list (unused).

        Returns:
            ``(landmark, state)`` where ``landmark`` is a float32 array of
            ``[x, y]`` rows in coordinates of ``image``, or ``(None, None)``
            when both box width and height are at most ``self.min_face``.
        """
        bbox_width = bbox[2] - bbox[0]
        bbox_height = bbox[3] - bbox[1]
        if bbox_width <= self.min_face and bbox_height <= self.min_face:
            return None, None

        # Pad the image so that the enlarged square crop stays inside it.
        add = int(max(bbox_width, bbox_height))
        bimg = cv2.copyMakeBorder(
            image,
            add,
            add,
            add,
            add,
            borderType=cv2.BORDER_CONSTANT,
            value=cfg.DATA.pixel_means)

        bbox = self._square_crop_box(bbox, add,
                                     cfg.KEYPOINTS.base_extend_range[0])

        crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
        h, w, _ = crop_image.shape
        crop_image = cv2.resize(
            crop_image,
            (cfg.KEYPOINTS.input_shape[1], cfg.KEYPOINTS.input_shape[0]))
        crop_image = crop_image.astype(np.float32)

        keypoints, state = self.simple_run(crop_image)

        res = keypoints[0][:self.keypoint_num].reshape((-1, 2))
        res[:, 0] = res[:, 0] * w / cfg.KEYPOINTS.input_shape[1]
        res[:, 1] = res[:, 1] * h / cfg.KEYPOINTS.input_shape[0]

        # Map back to the un-padded image: undo the crop offset and the
        # border. NOTE(review): x is divided by input_shape[1] above but
        # multiplied by input_shape[0] here (and vice versa for y); the two
        # only cancel for a square network input - confirm before using a
        # non-square ``cfg.KEYPOINTS.input_shape``.
        landmark = [
            [
                int(x_y[0] * cfg.KEYPOINTS.input_shape[0] + bbox[0] - add),
                int(x_y[1] * cfg.KEYPOINTS.input_shape[1] + bbox[1] - add)
            ]
            for x_y in res
        ]

        landmark = np.array(landmark, np.float32)

        return landmark, state

    @staticmethod
    def _init_from_ckpt(meta_path, restore_model_path):
        """Import a meta graph into a new graph and restore its checkpoint."""
        graph = tf.Graph()
        with graph.as_default():
            configProto = tf.ConfigProto()
            configProto.gpu_options.allow_growth = True
            sess = tf.Session(graph=graph, config=configProto)
            saver = tf.train.import_meta_graph(meta_path)
            saver.restore(sess, restore_model_path)

        print('Model restred!')
        return graph, sess

    @staticmethod
    def _init_from_pb(model_path):
        """Import a frozen ``GraphDef`` file into a new graph."""
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        compute_graph = tf.Graph()
        with compute_graph.as_default():
            with tf.gfile.GFile(model_path, 'rb') as fid:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(fid.read())
            tf.import_graph_def(graph_def, name='')
        sess = tf.Session(graph=compute_graph, config=config)
        return compute_graph, sess

    def init_model(self, *args):
        """Create the graph and session for the model.

        Args:
            *args: Either a single path to a frozen ``.pb`` file, or a meta
                graph path followed by the checkpoint path to restore.

        Returns:
            Tuple ``(graph, sess)``; ``graph`` holds the imported model and
            ``sess`` is a session bound to that graph.
        """
        if len(args) == 1:
            graph, sess = self._init_from_pb(args[0])
        else:
            graph, sess = self._init_from_ckpt(args[0], args[1])

        return graph, sess
|