diff options
Diffstat (limited to 'bsie/extractor/image/face/identify.py')
-rw-r--r-- | bsie/extractor/image/face/identify.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/bsie/extractor/image/face/identify.py b/bsie/extractor/image/face/identify.py index 44a75c4..dee935f 100644 --- a/bsie/extractor/image/face/identify.py +++ b/bsie/extractor/image/face/identify.py @@ -4,7 +4,6 @@ import csv import typing # external imports -from facenet_pytorch import MTCNN, InceptionResnetV1 import numpy as np import torch @@ -26,6 +25,7 @@ __all__: typing.Sequence[str] = ( bsf = ns.bsn.Face() class FaceIdentify(base.Extractor): + """Extract identified people in an image.""" CONTENT_READER = 'bsie.reader.face.FaceExtract' @@ -49,9 +49,9 @@ class FaceIdentify(base.Extractor): thres: float = 0.9, cuda_device: str = 'cuda:0', restklasse: str = 'https://example.com/user/anon', - ): + ): # pylint: disable=too-many-arguments # initialize parent with the schema - super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + f''' + super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + ''' bsn:Face rdfs:subClassOf bsfs:Node . bsn:Person rdfs:subClassOf bsfs:Node . <https://schema.bsfs.io/ie/Node/Face#depicts> rdfs:subClassOf bsfs:Predicate ; @@ -84,7 +84,7 @@ class FaceIdentify(base.Extractor): targets, embeds = emb_with_trg[:, 0], emb_with_trg[:, 1:] self._targets = torch.tensor(targets, dtype=torch.int32).to(self._device) self._embeds = torch.tensor(embeds).to(self._device) - with open(ref_mapping, 'rt') as ifile: + with open(ref_mapping, 'rt', encoding='UTF-8') as ifile: mapping = [(int(idx), name) for name, idx in csv.reader(ifile)] # ensure that the mapping is unique ids, names = zip(*mapping) @@ -133,13 +133,13 @@ class FaceIdentify(base.Extractor): self._restidx, )) - def _classify(self, emb: torch.Tensor) -> torch.Tensor: # [Nx512] -> [N] + def _classify(self, emb: torch.Tensor) -> typing.List[int]: # [Nx512] -> [N] # nearest neighbour approach dist = torch.cdist(emb, self._embeds) # pairwise distances best = dist.argmin(dim=1) # idx of lowest distance, per row labels = self._targets[best] # label (int) of nearest neighbour acc = dist[range(len(best)), best] < self._thres # check if distance is below threshold - return [lbl.item() if cnd == True else self._restidx for cnd, lbl in zip(acc, labels)] + return [lbl.item() if cnd is True else self._restidx for cnd, lbl in zip(acc, labels)] def extract( self, |