You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

155 lines
5.0 KiB
Python

import cv2
import numpy as np
import tensorflow as tf
from .config import config as cfg
# Under TF 2.x the TF1 graph/session API used in this module (ConfigProto,
# Session, gfile, GraphDef, import_graph_def) lives in tf.compat.v1.
# Compare the major version numerically: a plain string comparison orders
# version strings lexicographically rather than by release.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
class FaceLandmark:
    """Face keypoint / state regressor backed by a frozen TF1 graph.

    The graph is loaded from ``<dir>/keypoints.pb``. Its ``tower_0/prediction``
    output is split into landmark coordinates (first ``keypoint_num`` columns),
    a head-pose slice (columns ``-7:-4``, scaled by 90) and four state values
    (last four columns, passed through a sigmoid). What the four states mean
    is not visible here — NOTE(review): confirm against the training code.
    """

    def __init__(self, dir):
        """Load the model and resolve its input/output tensors.

        Args:
            dir: directory containing ``keypoints.pb``. The name shadows the
                builtin but is kept so keyword callers keep working.
        """
        self.model_path = dir + '/keypoints.pb'
        # Boxes no larger than this on BOTH sides are skipped.
        self.min_face = 60
        # One (x, y) pair per keypoint.
        self.keypoint_num = cfg.KEYPOINTS.p_num * 2
        # init_model returns the graph that actually holds the imported model,
        # so every lookup and derived op below is created in that same graph.
        self._graph, self._sess = self.init_model(self.model_path)
        with self._graph.as_default():
            self.img_input = self._graph.get_tensor_by_name('tower_0/images:0')
            self.embeddings = self._graph.get_tensor_by_name(
                'tower_0/prediction:0')
            self.training = self._graph.get_tensor_by_name('training_flag:0')
            self.landmark = self.embeddings[:, :self.keypoint_num]
            self.headpose = self.embeddings[:, -7:-4] * 90.
            self.state = tf.nn.sigmoid(self.embeddings[:, -4:])

    def __call__(self, img, bboxes):
        """Run the landmark model on every box in ``bboxes``.

        Boxes rejected by ``_one_shot_run`` (too small) are dropped, so the
        outputs may have fewer entries than ``bboxes``.

        Returns:
            ``(landmarks, states)`` as numpy arrays stacked over the accepted
            boxes.
        """
        landmark_result = []
        state_result = []
        for i, bbox in enumerate(bboxes):
            landmark, state = self._one_shot_run(img, bbox, i)
            if landmark is None:
                continue
            landmark_result.append(landmark)
            state_result.append(state)
        return np.array(landmark_result), np.array(state_result)

    def simple_run(self, cropped_img):
        """Run the model on a single already-cropped, already-resized image.

        Args:
            cropped_img: one image (no batch axis); a batch axis of size 1 is
                added before feeding.

        Returns:
            ``(landmark, states)``: the landmark slice (``keypoint_num``
            columns) and the sigmoid state slice (4 columns), each with a
            leading batch axis of size 1.
        """
        batch = np.expand_dims(cropped_img, axis=0)
        landmark, states = self._sess.run(
            [self.landmark, self.state],
            feed_dict={
                self.img_input: batch,
                self.training: False,
            })
        return landmark, states

    def _one_shot_run(self, image, bbox, i):
        """Crop one face, run the model on it and map keypoints back.

        Args:
            image: image indexed as ``[row, col, channel]``.
            bbox: array-like ``[x1, y1, x2, y2, ...]``. It is copied, so the
                caller's array is never modified.
            i: index of the box in the caller's sequence (unused).

        Returns:
            ``(landmark, state)``, where ``landmark`` is a float32 ``(K, 2)``
            array in ``image`` pixel coordinates, or ``(None, None)`` when the
            box is too small.
        """
        # Work on a copy: the in-place updates below would otherwise write
        # straight through to the caller's ``bboxes`` array (rows are views).
        bbox = np.array(bbox)
        bbox_width = bbox[2] - bbox[0]
        bbox_height = bbox[3] - bbox[1]
        if bbox_width <= self.min_face and bbox_height <= self.min_face:
            return None, None

        # Pad by the box's larger side so the enlarged square crop can extend
        # past the original image border.
        add = int(max(bbox_width, bbox_height))
        bimg = cv2.copyMakeBorder(
            image, add, add, add, add,
            borderType=cv2.BORDER_CONSTANT,
            value=cfg.DATA.pixel_means)
        bbox += add  # shift into padded-image coordinates

        # Square crop of side (1 + 2 * extend) * box width, centred on the box.
        one_edge = (1 + 2 * cfg.KEYPOINTS.base_extend_range[0]) * bbox_width
        center_x = (bbox[0] + bbox[2]) // 2
        center_y = (bbox[1] + bbox[3]) // 2
        half = one_edge // 2
        bbox[0] = center_x - half
        bbox[1] = center_y - half
        bbox[2] = center_x + half
        bbox[3] = center_y + half
        bbox = bbox.astype(int)  # np.int alias was removed in NumPy 1.24

        crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
        h, w, _ = crop_image.shape
        input_shape = cfg.KEYPOINTS.input_shape
        # cv2.resize takes dsize as (width, height); input_shape is (H, W, ...).
        crop_image = cv2.resize(crop_image, (input_shape[1], input_shape[0]))
        crop_image = crop_image.astype(np.float32)

        keypoints, state = self.simple_run(crop_image)
        landmark = self._project_keypoints(
            keypoints[0][:self.keypoint_num], w, h,
            bbox[0] - add, bbox[1] - add, input_shape)
        return landmark, state

    @staticmethod
    def _project_keypoints(raw, crop_w, crop_h, offset_x, offset_y,
                           input_shape):
        """Map raw model keypoints to pixel coordinates of the original image.

        Args:
            raw: flat ``[x0, y0, x1, y1, ...]`` model output for one face.
            crop_w, crop_h: size of the crop before it was resized.
            offset_x, offset_y: top-left corner of the crop in original-image
                pixels.
            input_shape: model input shape indexed as ``(H, W, ...)``.

        Returns:
            float32 ``(K, 2)`` array. Coordinates are truncated toward zero to
            integers (like ``int()``) before the cast back to float32.
        """
        in_h, in_w = input_shape[0], input_shape[1]
        pts = np.array(raw, dtype=np.float32).reshape((-1, 2))
        # NOTE(review): scaling by the input size suggests the model emits
        # keypoints normalised to the crop — confirm against training code.
        pts[:, 0] = pts[:, 0] * crop_w / in_w
        pts[:, 1] = pts[:, 1] * crop_h / in_h
        # x pairs with the input width and y with the input height; the old
        # per-point loop had these swapped (harmless only for square inputs).
        xs = pts[:, 0] * in_w + offset_x
        ys = pts[:, 1] * in_h + offset_y
        return np.stack([xs, ys], axis=1).astype(np.int64).astype(np.float32)

    def init_model(self, *args):
        """Create a graph and a session for the model.

        Args:
            *args: either ``(pb_path,)`` for a frozen GraphDef, or
                ``(meta_path, restore_model_path)`` for a meta graph plus a
                checkpoint to restore.

        Returns:
            ``(graph, sess)``, where ``graph`` is the graph that holds the
            model and ``sess`` is a session bound to it.
        """
        if len(args) == 1:
            return self._init_pb(args[0])
        return self._init_ckpt(args[0], args[1])

    @staticmethod
    def _init_pb(model_path):
        """Import a frozen GraphDef into a fresh graph and open a session."""
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        graph = tf.Graph()
        # ``graph.as_default()`` only takes effect as a context manager;
        # calling it bare (as before) left the import in the caller's graph.
        with graph.as_default():
            with tf.gfile.GFile(model_path, 'rb') as fid:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(fid.read())
            tf.import_graph_def(graph_def, name='')
        sess = tf.Session(graph=graph, config=config)
        return graph, sess

    @staticmethod
    def _init_ckpt(meta_path, restore_model_path):
        """Import a meta graph into a fresh graph and restore its checkpoint."""
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        graph = tf.Graph()
        with graph.as_default():
            saver = tf.train.import_meta_graph(meta_path)
        sess = tf.Session(graph=graph, config=config)
        saver.restore(sess, restore_model_path)
        print('Model restored!')
        return graph, sess