aboutsummaryrefslogtreecommitdiffstats
path: root/bsie/extractor/image/face/identify.py
diff options
context:
space:
mode:
Diffstat (limited to 'bsie/extractor/image/face/identify.py')
-rw-r--r--bsie/extractor/image/face/identify.py176
1 files changed, 176 insertions, 0 deletions
diff --git a/bsie/extractor/image/face/identify.py b/bsie/extractor/image/face/identify.py
new file mode 100644
index 0000000..152f113
--- /dev/null
+++ b/bsie/extractor/image/face/identify.py
@@ -0,0 +1,176 @@
+
+# standard imports
+import csv
+import typing
+
+# external imports
+from facenet_pytorch import MTCNN, InceptionResnetV1
+import numpy as np
+import torch
+
+# bsie imports
+from bsie.utils import bsfs, node, ns
+
+# inner-module imports
+from ... import base
+
+# exports
+__all__: typing.Sequence[str] = (
+ 'FaceIdentify',
+ )
+
+
+## code ##
+
+bsf = ns.bsn.Face()
+
+class FaceIdentify(base.Extractor):
+
+ CONTENT_READER = 'bsie.reader.face.FaceExtract'
+
+ _restklasse: bsfs.URI
+ _thres: float
+ _device: torch.device
+ _restidx: int
+ _id2name: typing.Dict[int, str]
+ _embeds: torch.Tensor
+ _targets: torch.Tensor
+
+
+
+ # FIXME: This could be a bsfs maintenance function instead of a bsie function
+
+ def __init__(
+ self,
+ # FIXME: Initialize from bsfs storage instead of files
+ ref_embeds: str,
+ ref_mapping: str,
+ thres: float = 0.9,
+ cuda_device: str = 'cuda:0',
+ restklasse: str = 'https://example.com/user/anon',
+ ):
+ # initialize parent with the schema
+ super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + f'''
+ bsn:Face rdfs:subClassOf bsfs:Node .
+ bsn:Person rdfs:subClassOf bsfs:Node .
+ <https://schema.bsfs.io/ie/Node/Face#depicts> rdfs:subClassOf bsfs:Predicate ;
+ rdfs:domain bsn:Face ;
+ rdfs:range bsn:Person .
+ # FIXME: Entity -> Face?
+ bse:face rdfs:subClassOf bsfs:Predicate ;
+ rdfs:domain bsn:Entity ;
+ rdfs:range bsn:Face .
+ # FIXME: Face -> Embedding?
+ #<https://schema.bsfs.io/ie/Node/Face#embedding>
+ # rdfs:subClassOf bsfs:Predicate ;
+ # rdfs:domain bsn:Face ;
+ # rdfs:range <https://schema.bsfs.io/ie/Literal/Array/Feature/Face#resnet512> ;
+ # bsfs:unique "true"^^xsd:boolean .
+ #<https://schema.bsfs.io/ie/Literal/Array/Feature/Face#resnet512>
+ # rdfs:subClassOf bsa:Feature ;
+ # bsfs:distance <https://schema.bsfs.io/core/distance#euclidean> ;
+ # bsfs:dtype <https://schema.bsfs.io/core/dtype#f32>;
+ # bsfs:dimension "512"^^xsd:integer .
+
+ '''))
+ # store extra members
+ self._restklasse = bsfs.URI(restklasse)
+ self._thres = thres
+ # get face instances
+ self._device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu')
+ with open(ref_embeds, 'rb') as ifile:
+ emb_with_trg = np.load(ifile)
+ targets, embeds = emb_with_trg[:, 0], emb_with_trg[:, 1:]
+ self._targets = torch.tensor(targets, dtype=torch.int32).to(self._device)
+ self._embeds = torch.tensor(embeds).to(self._device)
+ with open(ref_mapping, 'rt') as ifile:
+ mapping = [(int(idx), name) for name, idx in csv.reader(ifile)]
+ # ensure that the mapping is unique
+ ids, names = zip(*mapping)
+ if len(set(names)) != len(names):
+ raise Exception('people identifiers must be unique')
+ if len(set(ids)) != len(ids):
+ raise Exception('people indices must be unique')
+ # ensure that all targets are accounted for
+ if not {int(i) for i in self._targets.tolist()}.issubset(set(ids)):
+ raise Exception('all targets must be labelled')
+ # ensure and fetch the index of the restklasse
+ if self._restklasse not in names:
+ mapping.append((max(ids) + 1, self._restklasse))
+ # store mapping
+ self._restidx = [idx for idx, name in mapping if name == self._restklasse][0]
+ self._id2name = dict(mapping)
+ # discard the restklasse from the reference points
+ self._embeds = self._embeds[self._targets != self._restidx]
+ self._targets = self._targets[self._targets != self._restidx]
+
+ @property
+ def principals(self) -> typing.Iterator[bsfs.schema.Predicate]:
+ """Return the principal predicates, i.e., relations from/to the extraction subject."""
+ yield from super().principals
+ yield self.schema.predicate(bsf.depicts)
+
+ def __repr__(self) -> str:
+ return f'{bsfs.typename(self)}(N={len(self._embeds)}, restklasse={self._restklasse})'
+
+ def __eq__(self, other: typing.Any) -> bool:
+ return super().__eq__(other) \
+ and self._thres == other._thres \
+ and self._id2name == other._id2name \
+ and torch.equal(self._embeds, other._embeds) \
+ and torch.equal(self._targets, other._targets) \
+ and self._restklasse == other._restklasse \
+ and self._restidx == other._restidx
+
+ def __hash__(self) -> int:
+ return hash((super().__hash__(),
+ tuple(sorted(self._id2name.items())),
+ self._thres,
+ tuple(self._embeds.detach().cpu().numpy().reshape(-1).tolist()),
+ tuple(self._targets.detach().cpu().numpy().reshape(-1).tolist()),
+ self._restklasse,
+ self._restidx,
+ ))
+
+ def _classify(self, emb: torch.Tensor) -> torch.Tensor: # [Nx512] -> [N]
+ # nearest neighbour approach
+ dist = torch.cdist(emb, self._embeds) # pairwise distances
+ best = dist.argmin(dim=1) # idx of lowest distance, per row
+ labels = self._targets[best] # label (int) of nearest neighbour
+ acc = dist[range(len(best)), best] < self._thres # check if distance is below threshold
+ return [lbl.item() if cnd == True else self._restidx for cnd, lbl in zip(acc, labels)]
+
+ def extract(
+ self,
+ subject: node.Node,
+ content: typing.Any,
+ principals: typing.Iterable[bsfs.schema.Predicate],
+ ) -> typing.Iterator[typing.Tuple[node.Node, bsfs.schema.Predicate, typing.Any]]:
+ # check principals
+ #if self.schema.predicate(bsf.depicts) not in principals:
+ if self.schema.predicate(ns.bse.face) not in principals:
+ # nothing to do; abort
+ return
+ # check content
+ if len(content) == 0:
+ return
+
+ # collect embeddings
+ emb = torch.vstack([face['embedding'] for face in content]).to(self._device)
+ # apply classifier
+ labels = self._classify(emb)
+ # walk through faces
+ for face, idx in zip(content, labels):
+ lbl = bsfs.URI(self._id2name[idx]) # label (uri) of nearest neighbour
+ if lbl == self._restklasse: # suppress
+ continue
+ pnode = node.Node(ns.bsn.Person, uri=lbl)
+ fnode = node.Node(ns.bsn.Face, ucid=face['ucid'])
+ # emit triple
+ yield fnode, self.schema.predicate(bsf.depicts), pnode
+ # FIXME: emit subject -> face -> fnode?
+ yield subject, self.schema.predicate(ns.bse.face), fnode
+ # FIXME: emit embedding?
+ #yield fnode, bsf.embedding, face['embedding']
+
+## EOF ##