aboutsummaryrefslogtreecommitdiffstats
path: root/bsie/reader/face.py
diff options
context:
space:
mode:
Diffstat (limited to 'bsie/reader/face.py')
-rw-r--r--bsie/reader/face.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/bsie/reader/face.py b/bsie/reader/face.py
new file mode 100644
index 0000000..c5374e0
--- /dev/null
+++ b/bsie/reader/face.py
@@ -0,0 +1,179 @@
+
+# standard imports
+import operator
+import typing
+
+# external imports
+from facenet_pytorch import MTCNN, InceptionResnetV1
+import PIL.Image
+import torch
+
+# bsie imports
+from bsie.utils import bsfs, errors, node, ns
+
+# inner-module imports
+from . import base
+
+# exports
+__all__: typing.Sequence[str] = (
+ 'FaceExtract',
+ )
+
+
+## code ##
+
+class FaceExtract(base.Reader):
+ """Extract faces and their feature vector from an image file."""
+
+ # Face patch size.
+ _target_size: int
+
+ # Lower bound on the detected face's probability.
+ _min_face_prob: float
+
+ # Face detector network.
+ _detector: MTCNN
+
+ # Face feature extractor network.
+ _embedder: InceptionResnetV1
+
+ def __init__(
+ self,
+ target_size: int = 1000,
+ min_face_size: int = 40,
+ min_face_prob: float = 0.992845,
+ cuda_device: str = 'cuda:0',
+ ext_face_size: int = 160,
+ thresholds: typing.Tuple[float, float, float] = [0.5, 0.6, 0.6],
+ factor: float = 0.709,
+ ):
+ # initialize
+ self._device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu')
+ # initialize the face detection network
+ self._target_size = target_size
+ self._min_face_prob = min_face_prob
+ self._carghash = hash((min_face_size, ext_face_size, tuple(thresholds), factor))
+ self._detector = MTCNN(
+ min_face_size=min_face_size,
+ image_size=ext_face_size,
+ thresholds=thresholds,
+ factor=factor,
+ device=self._device,
+ keep_all=True,
+ ).to(self._device)
+ # initialize the face embedding netwrok
+ self._embedder = InceptionResnetV1('vggface2').to(self._device).eval()
+
+ def __repr__(self) -> str:
+ return f'{bsfs.typename(self)}({self._min_face_prob})'
+
+ def __eq__(self, other: typing.Any) -> bool:
+ return super().__eq__(other) \
+ and self._target_size == other._target_size \
+ and self._min_face_prob == other._min_face_prob \
+ and self._carghash == other._carghash
+
+ def __hash__(self) -> int:
+ return hash((super().__hash__(), self._target_size, self._min_face_prob, self._carghash))
+
+ @staticmethod
+ def preprocess(
+ img: PIL.Image.Image,
+ target_size: int,
+ rotate: typing.Union[bool, int] = True,
+ ) -> typing.Tuple[PIL.Image.Image, typing.Callable[[typing.Tuple[float, float]], typing.Tuple[float, float]]]:
+ """Preprocess an image. Return the image and a coordinate back-transformation function.
+ 1. Scale larger side to *target_size*
+ 2. Rotate by angle *rotate*, or auto-rotate if *rotate=None* (the default).
+ """
+ # FIXME: re-using reader.Image would cover more file formats!
+
+ # >>> from PIL import ExifTags
+ # >>> exif_ori = [k for k, tag in ExifTags.TAGS.items() if tag == 'Orientation']
+ # >>> exif_ori = exif_ori[0]
+ exif_ori = 274
+
+ # scale image
+ orig_size = img.size
+ if img.size[0] > img.size[1]: # landscape
+ img = img.resize((target_size, int(img.height / img.width * target_size)), reducing_gap=3)
+ elif img.size[0] < img.size[1]: # portrait
+ img = img.resize((int(img.width / img.height * target_size), target_size), reducing_gap=3)
+ else: # square
+ img = img.resize((
+ int(img.width / img.height * target_size),
+ int(img.width / img.height * target_size),
+ ), reducing_gap=3)
+
+ # get scale factors
+ sX = orig_size[0] / img.width
+ sY = orig_size[1] / img.height
+
+ # rotate image (if need be)
+ denorm = lambda xy: (sX*xy[0], sY*xy[1])
+ if rotate is not None:
+ # auto-rotate according to EXIF information
+ img_ori = img.getexif().get(exif_ori, None)
+ if img_ori == 3 or rotate == 180:
+ img = img.rotate(180, expand=True)
+ denorm = lambda xy: (orig_size[0] - sX*xy[0], orig_size[1] - sY*xy[1])
+ elif img_ori == 6 or rotate == 270:
+ img = img.rotate(270, expand=True)
+ denorm = lambda xy: (orig_size[0] - sX*xy[1], sY*xy[0])
+ elif img_ori == 8 or rotate == 90:
+ img = img.rotate(90, expand=True)
+ denorm = lambda xy: (sX*xy[1], orig_size[1] - sY*xy[0])
+
+ # return image and denormalization function
+ return img, denorm
+
+ def __call__(self, path: str) -> typing.Sequence[dict]:
+ try:
+ # open the image
+ img = PIL.Image.open(path)
+ # rotate and scale the image
+ img, denorm = self.preprocess(img, self._target_size)
+
+ # detect faces
+ boxes, probs = self._detector.detect(img)
+ if boxes is None: # no faces detected
+ return []
+ # ignore boxes with probability below threshold
+ boxes = [box for box, p in zip(boxes, probs) if p >= self._min_face_prob]
+ if len(boxes) == 0: # no faces detected
+ return []
+ # compute face embeddings
+ faces_img = self._detector.extract(img, boxes, None).to(self._device)
+ embeds = self._embedder(faces_img)
+
+ faces = []
+ for bbox, face, emb in zip(boxes, faces_img, embeds):
+ # face hash
+ ucid = bsfs.uuid.UCID.from_bytes(bytes(face.detach().cpu().numpy()))
+ # position / size
+ x0, y0 = denorm(bbox[:2])
+ x1, y1 = denorm(bbox[2:])
+ x, y = min(x0, x1), min(y0, y1)
+ width, height = max(x0, x1) - x, max(y0, y1) - y
+ # assembled
+ faces.append(dict(
+ ucid=ucid, # str
+ x=x, # float
+ y=y, # float
+ width=width, # float
+ height=height, # float
+ embedding=emb, # np.array
+ ))
+
+ return faces
+
+ except PIL.UnidentifiedImageError as err: # format not supported by PIL
+ raise errors.UnsupportedFileFormatError(path) from err
+ except IOError as err: # file not found and file open errors
+ raise errors.ReaderError(path) from err
+ except RuntimeError as err: # pytorch errors
+ raise errors.ReaderError(path) from err
+ except ValueError as err: # negative seek value
+ raise errors.ReaderError(path) from err
+
+## EOF ##