Add support of face alignment #110

Open · wants to merge 7 commits into master
models/mtcnn.py (56 additions, 1 deletion)

@@ -4,7 +4,7 @@
import os

from .utils.detect_face import detect_face, extract_face

from .utils.align_trans import get_reference_facial_points, warp_and_crop_face

class PNet(nn.Module):
"""MTCNN PNet.
@@ -210,6 +210,8 @@ def __init__(
if device is not None:
self.device = device
self.to(device)
# the reference landmarks are defined for 112x112 crops; rescale to image_size
scale = float(image_size) / 112
self.facial_reference_points = get_reference_facial_points(default_square=True) * scale

def forward(self, img, save_path=None, return_prob=False):
"""Run MTCNN face detection on a PIL image or numpy array. This method performs both
@@ -280,6 +282,7 @@ def forward(self, img, save_path=None, return_prob=False):
face_path = save_name + '_' + str(i + 1) + ext

face = extract_face(im, box, self.image_size, self.margin, face_path)

if self.post_process:
face = fixed_image_standardization(face)
faces_im.append(face)
@@ -302,6 +305,58 @@ def forward(self, img, save_path=None, return_prob=False):
else:
return faces

def extract_aligned_face(self, img, return_prob=False):
"""Arguments and outputs are similar to those of the forward method, but the
returned faces are aligned using the detected facial landmark points rather
than cropped directly from the bounding boxes.
"""

with torch.no_grad():
batch_boxes, batch_probs, batch_landmarks = self.detect(img, landmarks=True)

# Determine if a batch or single image was passed
batch_mode = True
if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4):
img = [img]
batch_boxes = [batch_boxes]
batch_probs = [batch_probs]
batch_landmarks = [batch_landmarks]
batch_mode = False

# Process all bounding boxes and probabilities
faces, probs = [], []
for im, box_im, prob_im, landmarks in zip(img, batch_boxes, batch_probs, batch_landmarks):
if box_im is None:
faces.append(None)
probs.append([None] if self.keep_all else None)
continue

if not self.keep_all:
box_im = box_im[[0]]

faces_im = []
for facial5points in landmarks:
face = warp_and_crop_face(np.array(im), facial5points, self.facial_reference_points, crop_size=(self.image_size, self.image_size))
faces_im.append(torch.from_numpy(face))

if self.keep_all:
faces_im = torch.stack(faces_im)
else:
faces_im = faces_im[0]
prob_im = prob_im[0]
faces.append(faces_im)
probs.append(prob_im)

if not batch_mode:
faces = faces[0]
probs = probs[0]

if return_prob:
return faces, probs
else:
return faces


def detect(self, img, landmarks=False):
"""Detect all faces in PIL image and return bounding boxes and optional facial landmarks.

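A minimal usage sketch of the new method (the file name, image_size, and keep_all settings below are assumptions for illustration, not part of this PR):

from PIL import Image
from models.mtcnn import MTCNN

mtcnn = MTCNN(image_size=160, keep_all=True)
img = Image.open('test.jpg')  # any RGB image

# Each aligned face is the uint8 HxWxC output of warp_and_crop_face;
# with keep_all=True they are stacked along dim 0.
faces, probs = mtcnn.extract_aligned_face(img, return_prob=True)
if faces is not None:
    print(faces.shape, probs)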
models/utils/align_trans.py (299 additions, 0 deletions)
@@ -0,0 +1,299 @@
import numpy as np
import cv2
from .matlab_cp2tform import get_similarity_transform_for_cv2

"""
Copyright: this code is from https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/tree/master/align
"""

# Reference facial points: (x, y) coordinates of the 5 landmarks for
# crop_size = (96, 112); adjust them accordingly for other crop sizes.
REFERENCE_FACIAL_POINTS = [
[30.29459953, 51.69630051],
[65.53179932, 51.50139999],
[48.02519989, 71.73660278],
[33.54930115, 92.3655014],
[62.72990036, 92.20410156]
]

DEFAULT_CROP_SIZE = (96, 112)


class FaceWarpException(Exception):
def __str__(self):
return 'In File {}: {}'.format(__file__, super().__str__())


def get_reference_facial_points(output_size=None,
inner_padding_factor=0.0,
outer_padding=(0, 0),
default_square=False):
"""
Function:
----------
get reference 5 key points according to crop settings:
0. Set default crop_size:
if default_square:
crop_size = (112, 112)
else:
crop_size = (96, 112)
1. Pad the crop_size by inner_padding_factor in each side;
2. Resize crop_size into (output_size - outer_padding*2),
pad into output_size with outer_padding;
3. Output reference_5point;
Parameters:
----------
@output_size: (w, h) or None
size of aligned face image
@inner_padding_factor: float
padding factor applied to the inner (w, h) region
@outer_padding: (w_pad, h_pad)
padding (in pixels) added around the resized inner region
@default_square: True or False
if True:
default crop_size = (112, 112)
else:
default crop_size = (96, 112);
!!! make sure, if output_size is not None:
(output_size - outer_padding)
= some_scale * (default crop_size * (1.0 + inner_padding_factor))
Returns:
----------
@reference_5point: 5x2 np.array
each row is a pair of transformed coordinates (x, y)
"""

tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

# 0) make the inner region a square
if default_square:
size_diff = max(tmp_crop_size) - tmp_crop_size
tmp_5pts += size_diff / 2
tmp_crop_size += size_diff


if (output_size and
output_size[0] == tmp_crop_size[0] and
output_size[1] == tmp_crop_size[1]):
return tmp_5pts

if (inner_padding_factor == 0 and
outer_padding == (0, 0)):
if output_size is None:
return tmp_5pts
else:
raise FaceWarpException(
'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

# check output size
if not (0 <= inner_padding_factor <= 1.0):
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
and output_size is None):
output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
output_size += np.array(outer_padding)

if not (outer_padding[0] < output_size[0]
and outer_padding[1] < output_size[1]):
raise FaceWarpException('Not (outer_padding[0] < output_size[0] '
'and outer_padding[1] < output_size[1])')

# 1) pad the inner region according to inner_padding_factor
if inner_padding_factor > 0:
size_diff = tmp_crop_size * inner_padding_factor * 2
tmp_5pts += size_diff / 2
tmp_crop_size += np.round(size_diff).astype(np.int32)


# 2) resize the padded inner region
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
raise FaceWarpException('Must have (output_size - outer_padding) '
'= some_scale * (crop_size * (1.0 + inner_padding_factor))')

scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
tmp_5pts = tmp_5pts * scale_factor
tmp_crop_size = size_bf_outer_pad

# 3) add outer_padding to make output_size
reference_5point = tmp_5pts + np.array(outer_padding)
tmp_crop_size = output_size

return reference_5point
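As a quick sanity check of the steps above (a sketch; the expected values follow from the constants defined in this file):

import numpy as np

# default_square=True shifts the 96x112 points by (112 - 96) / 2 = 8 px in x,
# so the first reference point moves from x = 30.29 to x = 38.29.
ref_112 = get_reference_facial_points(default_square=True)
print(np.round(ref_112[0], 2))  # [38.29 51.7]

# mtcnn.py then scales these points linearly for other square crop sizes:
ref_160 = ref_112 * (160.0 / 112)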


def get_affine_transform_matrix(src_pts, dst_pts):
"""
Function:
----------
get affine transform matrix 'tfm' from src_pts to dst_pts
Parameters:
----------
@src_pts: Kx2 np.array
source points matrix, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points matrix, each row is a pair of coordinates (x, y)
Returns:
----------
@tfm: 2x3 np.array
transform matrix from src_pts to dst_pts
"""

tfm = np.float32([[1, 0, 0], [0, 1, 0]])
n_pts = src_pts.shape[0]
ones = np.ones((n_pts, 1), src_pts.dtype)
src_pts_ = np.hstack([src_pts, ones])
dst_pts_ = np.hstack([dst_pts, ones])


A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)


if rank == 3:
tfm = np.float32([
[A[0, 0], A[1, 0], A[2, 0]],
[A[0, 1], A[1, 1], A[2, 1]]
])
elif rank == 2:
tfm = np.float32([
[A[0, 0], A[1, 0], 0],
[A[0, 1], A[1, 1], 0]
])

return tfm
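A small check of the least-squares fit (a sketch; the point sets are invented for illustration):

import numpy as np

src = np.float32([[0, 0], [1, 0], [0, 1], [1, 1]])
dst = src * 2.0 + np.float32([10, 20])  # scale by 2, translate by (10, 20)

tfm = get_affine_transform_matrix(src, dst)
print(np.round(tfm, 3))  # approximately [[2, 0, 10], [0, 2, 20]]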


def warp_and_crop_face(src_img,
facial_pts,
reference_pts=None,
crop_size=(96, 112),
align_type='similarity'):
"""
Function:
----------
apply affine transform 'trans' to uv
Parameters:
----------
@src_img: 3x3 np.array
input image
@facial_pts: could be
1)a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
@reference_pts: could be
1) a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
or
3) None
if None, use default reference facial points
@crop_size: (w, h)
output face image size
@align_type: transform type, could be one of
1) 'similarity': use similarity transform
2) 'cv2_affine': use the first 3 points to do affine transform,
by calling cv2.getAffineTransform()
3) 'affine': use all points to do affine transform
Returns:
----------
@face_img: output face image with size (w, h) = @crop_size
"""

if reference_pts is None:
if crop_size[0] == 96 and crop_size[1] == 112:
reference_pts = REFERENCE_FACIAL_POINTS
else:
default_square = False
inner_padding_factor = 0
outer_padding = (0, 0)
output_size = crop_size

reference_pts = get_reference_facial_points(output_size,
inner_padding_factor,
outer_padding,
default_square)

ref_pts = np.float32(reference_pts)
ref_pts_shp = ref_pts.shape
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
raise FaceWarpException(
'reference_pts.shape must be (K,2) or (2,K) and K>2')

if ref_pts_shp[0] == 2:
ref_pts = ref_pts.T

src_pts = np.float32(facial_pts)
src_pts_shp = src_pts.shape
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
raise FaceWarpException(
'facial_pts.shape must be (K,2) or (2,K) and K>2')

if src_pts_shp[0] == 2:
src_pts = src_pts.T


if src_pts.shape != ref_pts.shape:
raise FaceWarpException(
'facial_pts and reference_pts must have the same shape')

if align_type == 'cv2_affine':
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
elif align_type == 'affine':
tfm = get_affine_transform_matrix(src_pts, ref_pts)
else:
tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)


face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))

return face_img
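End to end, the helper can also be driven directly from MTCNN landmarks, outside of extract_aligned_face; a minimal sketch (the file name and detector settings are assumptions):

import numpy as np
from PIL import Image
from models.mtcnn import MTCNN
from models.utils.align_trans import get_reference_facial_points, warp_and_crop_face

img = Image.open('group_photo.jpg')
mtcnn = MTCNN(keep_all=True)
boxes, probs, landmarks = mtcnn.detect(img, landmarks=True)

if landmarks is not None:
    ref_pts = get_reference_facial_points(default_square=True)  # 112x112 reference
    aligned = [warp_and_crop_face(np.array(img), pts, ref_pts, crop_size=(112, 112))
               for pts in landmarks]  # each pts is a 5x2 array of (x, y) points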