diff options
Diffstat (limited to 'bsie/extractor/image/face/identify.py')
-rw-r--r-- | bsie/extractor/image/face/identify.py | 176 |
1 files changed, 176 insertions, 0 deletions
# standard imports
import csv
import typing

# external imports
from facenet_pytorch import MTCNN, InceptionResnetV1
import numpy as np
import torch

# bsie imports
from bsie.utils import bsfs, node, ns

# inner-module imports
from ... import base

# exports
__all__: typing.Sequence[str] = (
    'FaceIdentify',
    )


## code ##

bsf = ns.bsn.Face()

class FaceIdentify(base.Extractor):
    """Assign a person to each face by nearest-neighbour search over reference embeddings.

    The reference embeddings and their integer targets are loaded from
    a numpy file, the target-to-name mapping from a csv file.
    A face is assigned to the person of its nearest reference embedding
    if the distance is below a threshold. Otherwise, it is assigned
    to the *restklasse* (rest class), for which no triples are emitted.
    """

    CONTENT_READER = 'bsie.reader.face.FaceExtract'

    # name of the rest class, i.e., the label of unidentified faces.
    _restklasse: bsfs.URI
    # distance below which the nearest neighbour is accepted.
    _thres: float
    # device on which the reference points are stored.
    _device: torch.device
    # index of the rest class in the mapping.
    _restidx: int
    # mapping from target index to person name.
    _id2name: typing.Dict[int, str]
    # reference embeddings (one per row), without the rest class.
    _embeds: torch.Tensor
    # target index of each reference embedding.
    _targets: torch.Tensor

    # FIXME: This could be a bsfs maintenance function instead of a bsie function

    def __init__(
            self,
            # FIXME: Initialize from bsfs storage instead of files
            ref_embeds: str,
            ref_mapping: str,
            thres: float = 0.9,
            cuda_device: str = 'cuda:0',
            restklasse: str = 'https://example.com/user/anon',
            ):
        """Load the reference points and their labels.

        *ref_embeds* is the path to a numpy file that holds a 2D array
        whose first column is the target index and whose remaining
        columns are the embedding.
        *ref_mapping* is the path to a csv file with rows `name,index`.
        *thres* is the distance below which a nearest neighbour is accepted.
        *cuda_device* is used if cuda is available, otherwise the cpu is used.
        *restklasse* is the name assigned to unidentified faces. It is
        added to the mapping if it isn't part of it yet.

        Raises a ValueError if names or indices in the mapping are not
        unique, or if a target has no entry in the mapping.
        """
        # initialize parent with the schema
        super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + '''
        bsn:Face rdfs:subClassOf bsfs:Node .
        bsn:Person rdfs:subClassOf bsfs:Node .
        <https://schema.bsfs.io/ie/Node/Face#depicts> rdfs:subClassOf bsfs:Predicate ;
            rdfs:domain bsn:Face ;
            rdfs:range bsn:Person .
        # FIXME: Entity -> Face?
        bse:face rdfs:subClassOf bsfs:Predicate ;
            rdfs:domain bsn:Entity ;
            rdfs:range bsn:Face .
        # FIXME: Face -> Embedding?
        #<https://schema.bsfs.io/ie/Node/Face#embedding>
        #    rdfs:subClassOf bsfs:Predicate ;
        #    rdfs:domain bsn:Face ;
        #    rdfs:range <https://schema.bsfs.io/ie/Literal/Array/Feature/Face#resnet512> ;
        #    bsfs:unique "true"^^xsd:boolean .

        #<https://schema.bsfs.io/ie/Literal/Array/Feature/Face#resnet512>
        #    rdfs:subClassOf bsa:Feature ;
        #    bsfs:distance <https://schema.bsfs.io/core/distance#euclidean> ;
        #    bsfs:dtype <https://schema.bsfs.io/core/dtype#f32>;
        #    bsfs:dimension "512"^^xsd:integer .

        '''))
        # store extra members
        self._restklasse = bsfs.URI(restklasse)
        self._thres = thres
        # get face instances
        self._device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu')
        with open(ref_embeds, 'rb') as ifile:
            emb_with_trg = np.load(ifile)
        targets, embeds = emb_with_trg[:, 0], emb_with_trg[:, 1:]
        self._targets = torch.tensor(targets, dtype=torch.int32).to(self._device)
        self._embeds = torch.tensor(embeds).to(self._device)
        # newline='' lets the csv module do its own newline handling
        with open(ref_mapping, 'rt', newline='') as ifile:
            # blank rows are skipped
            mapping = [(int(row[1]), row[0]) for row in csv.reader(ifile) if row]
        # ensure that the mapping is unique
        ids = [idx for idx, _ in mapping]
        names = [name for _, name in mapping]
        if len(set(names)) != len(names):
            raise ValueError('people identifiers must be unique')
        if len(set(ids)) != len(ids):
            raise ValueError('people indices must be unique')
        # ensure that all targets are accounted for
        if not {int(i) for i in self._targets.tolist()}.issubset(set(ids)):
            raise ValueError('all targets must be labelled')
        # ensure and fetch the index of the restklasse
        if self._restklasse not in names:
            # the default covers an empty mapping
            mapping.append((max(ids, default=-1) + 1, self._restklasse))
        # store mapping
        self._restidx = next(idx for idx, name in mapping if name == self._restklasse)
        self._id2name = dict(mapping)
        # discard the restklasse from the reference points
        keep = self._targets != self._restidx
        self._embeds = self._embeds[keep]
        self._targets = self._targets[keep]

    @property
    def principals(self) -> typing.Iterator[bsfs.schema.Predicate]:
        """Return the principal predicates, i.e., relations from/to the extraction subject."""
        yield from super().principals
        yield self.schema.predicate(bsf.depicts)

    def __repr__(self) -> str:
        return f'{bsfs.typename(self)}(N={len(self._embeds)}, restklasse={self._restklasse})'

    def __eq__(self, other: typing.Any) -> bool:
        return super().__eq__(other) \
           and self._thres == other._thres \
           and self._id2name == other._id2name \
           and torch.equal(self._embeds, other._embeds) \
           and torch.equal(self._targets, other._targets) \
           and self._restklasse == other._restklasse \
           and self._restidx == other._restidx

    def __hash__(self) -> int:
        return hash((super().__hash__(),
            tuple(sorted(self._id2name.items())),
            self._thres,
            tuple(self._embeds.detach().cpu().numpy().reshape(-1).tolist()),
            tuple(self._targets.detach().cpu().numpy().reshape(-1).tolist()),
            self._restklasse,
            self._restidx,
            ))

    def _classify(self, emb: torch.Tensor) -> typing.List[int]: # [Nx512] -> [N]
        """Return the target index of the nearest reference point for each row of *emb*.

        The index of the restklasse is returned for rows whose nearest
        reference point is not closer than the threshold, and for all rows
        if there are no reference points.
        """
        # nothing can be identified without reference points
        if len(self._embeds) == 0:
            return [self._restidx] * len(emb)
        # nearest neighbour approach
        dist = torch.cdist(emb, self._embeds) # pairwise distances
        mindist, best = dist.min(dim=1) # lowest distance and its idx, per row
        labels = self._targets[best].tolist() # label (int) of nearest neighbour
        accepted = (mindist < self._thres).tolist() # check if distance is below threshold
        return [lbl if ok else self._restidx for ok, lbl in zip(accepted, labels)]

    def extract(
            self,
            subject: node.Node,
            content: typing.Any,
            principals: typing.Iterable[bsfs.schema.Predicate],
            ) -> typing.Iterator[typing.Tuple[node.Node, bsfs.schema.Predicate, typing.Any]]:
        """Yield the triples that link *subject* to its faces and the faces to identified persons.

        *content* is a sequence of faces, each providing the items
        `embedding` and `ucid`. Faces that are assigned to the restklasse
        are skipped.
        """
        # check principals
        #if self.schema.predicate(bsf.depicts) not in principals:
        if self.schema.predicate(ns.bse.face) not in principals:
            # nothing to do; abort
            return
        # check content
        if len(content) == 0:
            return

        # collect embeddings
        emb = torch.vstack([face['embedding'] for face in content]).to(self._device)
        # apply classifier
        labels = self._classify(emb)
        # walk through faces
        for face, idx in zip(content, labels):
            lbl = bsfs.URI(self._id2name[idx]) # label (uri) of nearest neighbour
            if lbl == self._restklasse: # suppress
                continue
            pnode = node.Node(ns.bsn.Person, uri=lbl)
            fnode = node.Node(ns.bsn.Face, ucid=face['ucid'])
            # emit triple
            yield fnode, self.schema.predicate(bsf.depicts), pnode
            # FIXME: emit subject -> face -> fnode?
            yield subject, self.schema.predicate(ns.bse.face), fnode
            # FIXME: emit embedding?
            #yield fnode, bsf.embedding, face['embedding']

## EOF ##