# standard imports
import operator
import typing

# external imports
from facenet_pytorch import MTCNN, InceptionResnetV1
import PIL.Image
import torch

# bsie imports
from bsie.utils import bsfs, errors, node, ns

# inner-module imports
from . import base

# exports
__all__: typing.Sequence[str] = (
    'FaceExtract',
    )


## code ##

class FaceExtract(base.Reader):
    """Extract faces and their feature vectors from an image file."""

    # Face patch size.
    _target_size: int

    # Lower bound on the detected face's probability.
    _min_face_prob: float

    # Face detector network.
    _detector: MTCNN

    # Face feature extractor network.
    _embedder: InceptionResnetV1

    def __init__(
            self,
            target_size: int = 1000,
            min_face_size: int = 40,
            min_face_prob: float = 0.992845,
            cuda_device: str = 'cuda:0',
            ext_face_size: int = 160,
            thresholds: typing.Tuple[float, float, float] = (0.5, 0.6, 0.6),
            factor: float = 0.709,
            ):
        # initialize
        self._device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu')
        # initialize the face detection network
        self._target_size = target_size
        self._min_face_prob = min_face_prob
        self._carghash = hash((min_face_size, ext_face_size, tuple(thresholds), factor))
        self._detector = MTCNN(
            min_face_size=min_face_size,
            image_size=ext_face_size,
            thresholds=thresholds,
            factor=factor,
            device=self._device,
            keep_all=True,
            ).to(self._device)
        # initialize the face embedding network
        self._embedder = InceptionResnetV1(pretrained='vggface2').to(self._device).eval()

    def __repr__(self) -> str:
        return f'{bsfs.typename(self)}({self._min_face_prob})'

    def __eq__(self, other: typing.Any) -> bool:
        return super().__eq__(other) \
           and self._target_size == other._target_size \
           and self._min_face_prob == other._min_face_prob \
           and self._carghash == other._carghash

    def __hash__(self) -> int:
        return hash((super().__hash__(), self._target_size, self._min_face_prob, self._carghash))

    @staticmethod
    def preprocess(
            img: PIL.Image.Image,
            target_size: int,
            rotate: typing.Optional[typing.Union[bool, int]] = True,
            ) -> typing.Tuple[
                PIL.Image.Image,
                typing.Callable[[typing.Tuple[float, float]], typing.Tuple[float, float]]]:
        """Preprocess an image.
        Return the image and a coordinate back-transformation function.

        1. Scale the image's larger side to *target_size*.
        2. Rotate by angle *rotate* (90, 180, or 270), auto-rotate according
           to EXIF metadata if *rotate* is True (the default), or skip the
           rotation entirely if *rotate* is None.

        """
        # FIXME: re-using reader.Image would cover more file formats!
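        # NOTE: *denorm* (constructed below) maps (x, y) coordinates from the
        # scaled and possibly rotated working image back into the original
        # image's coordinate frame, so bounding boxes refer to the file on disk.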
        # >>> from PIL import ExifTags
        # >>> exif_ori = [k for k, tag in ExifTags.TAGS.items() if tag == 'Orientation']
        # >>> exif_ori = exif_ori[0]
        exif_ori = 274

        # scale image
        orig_size = img.size
        if img.size[0] > img.size[1]: # landscape
            img = img.resize((target_size, int(img.height / img.width * target_size)), reducing_gap=3)
        elif img.size[0] < img.size[1]: # portrait
            img = img.resize((int(img.width / img.height * target_size), target_size), reducing_gap=3)
        else: # square
            img = img.resize((target_size, target_size), reducing_gap=3)

        # get scale factors
        sX = orig_size[0] / img.width
        sY = orig_size[1] / img.height

        # rotate image (if need be)
        denorm = lambda xy: (sX*xy[0], sY*xy[1])
        if rotate is not None:
            # auto-rotate according to EXIF information
            img_ori = img.getexif().get(exif_ori, None)
            if img_ori == 3 or rotate == 180:
                img = img.rotate(180, expand=True)
                denorm = lambda xy: (orig_size[0] - sX*xy[0], orig_size[1] - sY*xy[1])
            elif img_ori == 6 or rotate == 270:
                img = img.rotate(270, expand=True)
                denorm = lambda xy: (orig_size[0] - sX*xy[1], sY*xy[0])
            elif img_ori == 8 or rotate == 90:
                img = img.rotate(90, expand=True)
                denorm = lambda xy: (sX*xy[1], orig_size[1] - sY*xy[0])

        # return image and denormalization function
        return img, denorm

    def __call__(self, path: str) -> typing.Sequence[dict]:
        try:
            # open the image
            img = PIL.Image.open(path)
            # rotate and scale the image
            img, denorm = self.preprocess(img, self._target_size)
            # detect faces
            boxes, probs = self._detector.detect(img)
            if boxes is None: # no faces detected
                return []
            # ignore boxes with probability below threshold
            boxes = [box for box, p in zip(boxes, probs) if p >= self._min_face_prob]
            if len(boxes) == 0: # no faces detected
                return []
            # compute face embeddings
            faces_img = self._detector.extract(img, boxes, None).to(self._device)
            embeds = self._embedder(faces_img)

            faces = []
            for bbox, face, emb in zip(boxes, faces_img, embeds):
                # face hash
                ucid = bsfs.uuid.UCID.from_bytes(bytes(face.detach().cpu().numpy()))
                # position / size
                x0, y0 = denorm(bbox[:2])
                x1, y1 = denorm(bbox[2:])
                x, y = min(x0, x1), min(y0, y1)
                width, height = max(x0, x1) - x, max(y0, y1) - y
                # assembled
                faces.append(dict(
                    ucid=ucid, # str
                    x=x, # float
                    y=y, # float
                    width=width, # float
                    height=height, # float
                    embedding=emb, # torch.Tensor
                    ))

            return faces

        except PIL.UnidentifiedImageError as err:
            # format not supported by PIL
            raise errors.UnsupportedFileFormatError(path) from err
        except IOError as err:
            # file not found and file open errors
            raise errors.ReaderError(path) from err
        except RuntimeError as err:
            # pytorch errors
            raise errors.ReaderError(path) from err
        except ValueError as err:
            # negative seek value
            raise errors.ReaderError(path) from err

## EOF ##
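# Usage sketch (illustration only, not part of the module; the file path below
# is hypothetical, while the dict keys match those assembled in __call__ above):
#
#     reader = FaceExtract()
#     for face in reader('/path/to/photo.jpg'):
#         print(face['ucid'], face['x'], face['y'], face['width'], face['height'])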