aboutsummaryrefslogtreecommitdiffstats
path: root/bsie/extractor/image/face/identify.py
diff options
context:
space:
mode:
Diffstat (limited to 'bsie/extractor/image/face/identify.py')
-rw-r--r--bsie/extractor/image/face/identify.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/bsie/extractor/image/face/identify.py b/bsie/extractor/image/face/identify.py
index 44a75c4..dee935f 100644
--- a/bsie/extractor/image/face/identify.py
+++ b/bsie/extractor/image/face/identify.py
@@ -4,7 +4,6 @@ import csv
import typing
# external imports
-from facenet_pytorch import MTCNN, InceptionResnetV1
import numpy as np
import torch
@@ -26,6 +25,7 @@ __all__: typing.Sequence[str] = (
bsf = ns.bsn.Face()
class FaceIdentify(base.Extractor):
+ """Extract identified people in an image."""
CONTENT_READER = 'bsie.reader.face.FaceExtract'
@@ -49,9 +49,9 @@ class FaceIdentify(base.Extractor):
thres: float = 0.9,
cuda_device: str = 'cuda:0',
restklasse: str = 'https://example.com/user/anon',
- ):
+ ): # pylint: disable=too-many-arguments
# initialize parent with the schema
- super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + f'''
+ super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + '''
bsn:Face rdfs:subClassOf bsfs:Node .
bsn:Person rdfs:subClassOf bsfs:Node .
<https://schema.bsfs.io/ie/Node/Face#depicts> rdfs:subClassOf bsfs:Predicate ;
@@ -84,7 +84,7 @@ class FaceIdentify(base.Extractor):
targets, embeds = emb_with_trg[:, 0], emb_with_trg[:, 1:]
self._targets = torch.tensor(targets, dtype=torch.int32).to(self._device)
self._embeds = torch.tensor(embeds).to(self._device)
- with open(ref_mapping, 'rt') as ifile:
+ with open(ref_mapping, 'rt', encoding='UTF-8') as ifile:
mapping = [(int(idx), name) for name, idx in csv.reader(ifile)]
# ensure that the mapping is unique
ids, names = zip(*mapping)
@@ -133,13 +133,13 @@ class FaceIdentify(base.Extractor):
self._restidx,
))
- def _classify(self, emb: torch.Tensor) -> torch.Tensor: # [Nx512] -> [N]
+ def _classify(self, emb: torch.Tensor) -> typing.List[int]: # [Nx512] -> [N]
# nearest neighbour approach
dist = torch.cdist(emb, self._embeds) # pairwise distances
best = dist.argmin(dim=1) # idx of lowest distance, per row
labels = self._targets[best] # label (int) of nearest neighbour
acc = dist[range(len(best)), best] < self._thres # check if distance is below threshold
- return [lbl.item() if cnd == True else self._restidx for cnd, lbl in zip(acc, labels)]
+ return [lbl.item() if cnd is True else self._restidx for cnd, lbl in zip(acc, labels)]
def extract(
self,