diff options
Diffstat (limited to 'bsie/reader/face.py')
-rw-r--r-- | bsie/reader/face.py | 179 |
1 files changed, 179 insertions, 0 deletions
diff --git a/bsie/reader/face.py b/bsie/reader/face.py new file mode 100644 index 0000000..c5374e0 --- /dev/null +++ b/bsie/reader/face.py @@ -0,0 +1,179 @@ + +# standard imports +import operator +import typing + +# external imports +from facenet_pytorch import MTCNN, InceptionResnetV1 +import PIL.Image +import torch + +# bsie imports +from bsie.utils import bsfs, errors, node, ns + +# inner-module imports +from . import base + +# exports +__all__: typing.Sequence[str] = ( + 'FaceExtract', + ) + + +## code ## + +class FaceExtract(base.Reader): + """Extract faces and their feature vector from an image file.""" + + # Face patch size. + _target_size: int + + # Lower bound on the detected face's probability. + _min_face_prob: float + + # Face detector network. + _detector: MTCNN + + # Face feature extractor network. + _embedder: InceptionResnetV1 + + def __init__( + self, + target_size: int = 1000, + min_face_size: int = 40, + min_face_prob: float = 0.992845, + cuda_device: str = 'cuda:0', + ext_face_size: int = 160, + thresholds: typing.Tuple[float, float, float] = [0.5, 0.6, 0.6], + factor: float = 0.709, + ): + # initialize + self._device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu') + # initialize the face detection network + self._target_size = target_size + self._min_face_prob = min_face_prob + self._carghash = hash((min_face_size, ext_face_size, tuple(thresholds), factor)) + self._detector = MTCNN( + min_face_size=min_face_size, + image_size=ext_face_size, + thresholds=thresholds, + factor=factor, + device=self._device, + keep_all=True, + ).to(self._device) + # initialize the face embedding netwrok + self._embedder = InceptionResnetV1('vggface2').to(self._device).eval() + + def __repr__(self) -> str: + return f'{bsfs.typename(self)}({self._min_face_prob})' + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) \ + and self._target_size == other._target_size \ + and self._min_face_prob == other._min_face_prob \ + and self._carghash == other._carghash + + def __hash__(self) -> int: + return hash((super().__hash__(), self._target_size, self._min_face_prob, self._carghash)) + + @staticmethod + def preprocess( + img: PIL.Image.Image, + target_size: int, + rotate: typing.Union[bool, int] = True, + ) -> typing.Tuple[PIL.Image.Image, typing.Callable[[typing.Tuple[float, float]], typing.Tuple[float, float]]]: + """Preprocess an image. Return the image and a coordinate back-transformation function. + 1. Scale larger side to *target_size* + 2. Rotate by angle *rotate*, or auto-rotate if *rotate=None* (the default). + """ + # FIXME: re-using reader.Image would cover more file formats! + + # >>> from PIL import ExifTags + # >>> exif_ori = [k for k, tag in ExifTags.TAGS.items() if tag == 'Orientation'] + # >>> exif_ori = exif_ori[0] + exif_ori = 274 + + # scale image + orig_size = img.size + if img.size[0] > img.size[1]: # landscape + img = img.resize((target_size, int(img.height / img.width * target_size)), reducing_gap=3) + elif img.size[0] < img.size[1]: # portrait + img = img.resize((int(img.width / img.height * target_size), target_size), reducing_gap=3) + else: # square + img = img.resize(( + int(img.width / img.height * target_size), + int(img.width / img.height * target_size), + ), reducing_gap=3) + + # get scale factors + sX = orig_size[0] / img.width + sY = orig_size[1] / img.height + + # rotate image (if need be) + denorm = lambda xy: (sX*xy[0], sY*xy[1]) + if rotate is not None: + # auto-rotate according to EXIF information + img_ori = img.getexif().get(exif_ori, None) + if img_ori == 3 or rotate == 180: + img = img.rotate(180, expand=True) + denorm = lambda xy: (orig_size[0] - sX*xy[0], orig_size[1] - sY*xy[1]) + elif img_ori == 6 or rotate == 270: + img = img.rotate(270, expand=True) + denorm = lambda xy: (orig_size[0] - sX*xy[1], sY*xy[0]) + elif img_ori == 8 or rotate == 90: + img = img.rotate(90, expand=True) + denorm = lambda xy: (sX*xy[1], orig_size[1] - sY*xy[0]) + + # return image and denormalization function + return img, denorm + + def __call__(self, path: str) -> typing.Sequence[dict]: + try: + # open the image + img = PIL.Image.open(path) + # rotate and scale the image + img, denorm = self.preprocess(img, self._target_size) + + # detect faces + boxes, probs = self._detector.detect(img) + if boxes is None: # no faces detected + return [] + # ignore boxes with probability below threshold + boxes = [box for box, p in zip(boxes, probs) if p >= self._min_face_prob] + if len(boxes) == 0: # no faces detected + return [] + # compute face embeddings + faces_img = self._detector.extract(img, boxes, None).to(self._device) + embeds = self._embedder(faces_img) + + faces = [] + for bbox, face, emb in zip(boxes, faces_img, embeds): + # face hash + ucid = bsfs.uuid.UCID.from_bytes(bytes(face.detach().cpu().numpy())) + # position / size + x0, y0 = denorm(bbox[:2]) + x1, y1 = denorm(bbox[2:]) + x, y = min(x0, x1), min(y0, y1) + width, height = max(x0, x1) - x, max(y0, y1) - y + # assembled + faces.append(dict( + ucid=ucid, # str + x=x, # float + y=y, # float + width=width, # float + height=height, # float + embedding=emb, # np.array + )) + + return faces + + except PIL.UnidentifiedImageError as err: # format not supported by PIL + raise errors.UnsupportedFileFormatError(path) from err + except IOError as err: # file not found and file open errors + raise errors.ReaderError(path) from err + except RuntimeError as err: # pytorch errors + raise errors.ReaderError(path) from err + except ValueError as err: # negative seek value + raise errors.ReaderError(path) from err + +## EOF ## |