You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
DCT-Net/extract_align_faces.py

163 lines
4.5 KiB
Python

import cv2
import os
import numpy as np
import argparse
from source.facelib.facer import FaceAna
import source.utils as utils
from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points
from modelscope.hub.snapshot_download import snapshot_download
class FaceProcesser:
def __init__(self, dataroot, crop_size = 256, max_face = 1):
self.max_face = max_face
self.crop_size = crop_size
self.facer = FaceAna(dataroot)
def filter_face(self, lm, crop_size):
a = max(lm[:, 0])-min(lm[:, 0])
b = max(lm[:, 1])-min(lm[:, 1])
# print("a:%d, b:%d"%(a,b))
if max(a, b)<int(crop_size*0.3): # 眼间距 70
return 0
else:
return 1
def process(self, img):
warped_face = None
h, w, c = img.shape
if c==4:
img_bgr = img[:,:,:3]
else:
img_bgr = img
src_img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
boxes, landmarks, _ = self.facer.run(src_img)
if boxes.shape[0] == 0:
print('No face detected!')
return warped_face
# process all faces
warped_faces = []
i = 0
for landmark in landmarks:
if self.max_face and i>0:
continue
if self.filter_face(landmark, self.crop_size)==0:
print("filtered!")
continue
f5p = utils.get_f5p(landmark, img_bgr)
# face alignment
warped_face, _ = warp_and_crop_face(
img_bgr,
f5p,
ratio=0.75,
reference_pts=get_reference_facial_points(default_square=True),
crop_size=(self.crop_size, self.crop_size),
return_trans_inv=True)
warped_faces.append(warped_face)
i = i+1
return warped_faces
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="process remove bg result")
parser.add_argument("--src_dir", type=str, default='', help="Path to src images.")
parser.add_argument("--save_dir", type=str, default='', help="Path to save images.")
parser.add_argument("--crop_size", type=int, default=256)
parser.add_argument("--max_face", type=int, default=1)
parser.add_argument("--overwrite", type=int, default=1)
args = parser.parse_args()
args.save_dir = os.path.dirname(args.src_dir) + '/face_cartoon/raw_style_faces'
crop_size = args.crop_size
max_face = args.max_face
overwrite = args.overwrite
# model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.')
# print('model assets saved to %s'%model_dir)
model_dir = 'damo/cv_unet_person-image-cartoon_compound-models'
processer = FaceProcesser(dataroot=model_dir,crop_size=crop_size, max_face =max_face)
src_dir = args.src_dir
save_dir = args.save_dir
# print('Step: start to extract aligned faces ... ...')
print('src_dir:%s'% src_dir)
print('save_dir:%s'% save_dir)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
paths = utils.all_file(src_dir)
print('to process %d images'% len(paths))
for path in sorted(paths):
dirname = path[len(src_dir)+1:].split('/')[0]
outpath = save_dir + path[len(src_dir):]
if not overwrite:
if os.path.exists(outpath):
continue
sub_dir = os.path.dirname(outpath)
# print(sub_dir)
if not os.path.exists(sub_dir):
os.makedirs(sub_dir, exist_ok=True)
imgb = None
imgc = None
img = cv2.imread(path, -1)
if img is None:
continue
if len(img.shape)==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
# print(img.shape)
h,w,c = img.shape
if h<256 or w<256:
continue
imgs = []
# if need resize, resize here
img_h, img_w, _ = img.shape
warped_faces = processer.process(img)
if warped_faces is None:
continue
# ### only for anime faces, single, not detect face
# warped_face = imga
i=0
for res in warped_faces:
# filter small faces
h, w, c = res.shape
if h < 256 or w < 256:
continue
outpath = os.path.join(os.path.dirname(outpath), os.path.basename(outpath)[:-4] + '_' + str(i) + '.png')
cv2.imwrite(outpath, res)
print('save %s' % outpath)
i = i+1