# standard imports
import csv
import typing

# external imports
from facenet_pytorch import MTCNN, InceptionResnetV1
import numpy as np
import torch

# bsie imports
from bsie.matcher import nodes
from bsie.utils import bsfs, ns

# inner-module imports
from ... import base

# exports
__all__: typing.Sequence[str] = (
    'FaceIdentify',
    )


## code ##

bsf = ns.bsn.Face()

class FaceIdentify(base.Extractor):
    """Identify the person shown by a face via nearest-neighbour lookup.

    Each face embedding delivered by the content reader is compared against a
    set of labelled reference embeddings. The label of the closest reference
    is assigned if its distance is below a threshold; otherwise the face is
    assigned to the *restklasse* (the catch-all class), and no triple is
    emitted for it.
    """

    CONTENT_READER = 'bsie.reader.face.FaceExtract'

    # URI of the catch-all class for unidentified faces.
    _restklasse: bsfs.URI

    # Distance threshold below which a nearest neighbour counts as a match.
    _thres: float

    # Device on which reference embeddings are kept and distances are computed.
    _device: torch.device

    # Index (label) of the restklasse within the id-to-name mapping.
    _restidx: int

    # Maps label indices to person identifiers.
    _id2name: typing.Dict[int, str]

    # Reference embeddings, one row per reference face (restklasse rows removed).
    _embeds: torch.Tensor

    # Label index of each row in _embeds.
    _targets: torch.Tensor

    # FIXME: This could be a bsfs maintenance function instead of a bsie function
    def __init__(
            self,
            # FIXME: Initialize from bsfs storage instead of files
            ref_embeds: str,
            ref_mapping: str,
            thres: float = 0.9,
            cuda_device: str = 'cuda:0',
            restklasse: str = 'https://example.com/user/anon',
            ):
        """Load the reference embeddings and the label mapping.

        Args:
            ref_embeds: Path to a file readable by `numpy.load` holding a 2D
                array. Column 0 is the label index, the remaining columns are
                the embedding.
            ref_mapping: Path to a CSV file with one `name,index` row per person.
            thres: Maximum distance to the nearest reference for a match.
            cuda_device: Device to use if CUDA is available; the CPU is used
                otherwise.
            restklasse: Identifier of the catch-all class. It is added to the
                mapping with a fresh index if the mapping does not list it.

        Raises:
            ValueError: If names or indices in the mapping are not unique, or
                if a reference label has no entry in the mapping.
        """
        # initialize parent with the schema
        # The subject of the Face -> Person predicate is interpolated from
        # bsf.depicts so that it matches the lookups in `principals` and `extract`.
        # NOTE(review): the URIs of the commented-out embedding schema are
        # missing here; restore them before enabling that part.
        super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + f'''

            bsn:Face rdfs:subClassOf bsfs:Node .
            bsn:Person rdfs:subClassOf bsfs:Node .

            <{bsf.depicts}> rdfs:subClassOf bsfs:Predicate ;
                rdfs:domain bsn:Face ;
                rdfs:range bsn:Person .

            # FIXME: Entity -> Face?
            bse:face rdfs:subClassOf bsfs:Predicate ;
                rdfs:domain bsn:Entity ;
                rdfs:range bsn:Face .

            # FIXME: Face -> Embedding?
            #
            # rdfs:subClassOf bsfs:Predicate ;
            #     rdfs:domain bsn:Face ;
            #     rdfs:range ;
            #     bsfs:unique "true"^^xsd:boolean .
            #
            # rdfs:subClassOf bsa:Feature ;
            #     bsfs:distance ;
            #     bsfs:dtype ;
            #     bsfs:dimension "512"^^xsd:integer .

            '''))
        # store extra members
        self._restklasse = bsfs.URI(restklasse)
        self._thres = thres

        # get face instances
        self._device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu')
        with open(ref_embeds, 'rb') as ifile:
            emb_with_trg = np.load(ifile)
            targets, embeds = emb_with_trg[:, 0], emb_with_trg[:, 1:]
        self._targets = torch.tensor(targets, dtype=torch.int32).to(self._device)
        self._embeds = torch.tensor(embeds).to(self._device)

        # read the (index, name) mapping; blank rows (e.g., a trailing newline) are skipped
        with open(ref_mapping, 'rt', newline='') as ifile:
            rows = [row for row in csv.reader(ifile) if row]
        mapping = [(int(idx), name) for name, idx in rows]
        ids = tuple(idx for idx, _ in mapping)
        names = tuple(name for _, name in mapping)

        # ensure that the mapping is unique
        if len(set(names)) != len(names):
            raise ValueError('people identifiers must be unique')
        if len(set(ids)) != len(ids):
            raise ValueError('people indices must be unique')
        # ensure that all targets are accounted for
        if not set(self._targets.tolist()).issubset(set(ids)):
            raise ValueError('all targets must be labelled')

        # ensure and fetch the index of the restklasse
        if self._restklasse not in names:
            # default=-1 assigns index 0 if the mapping is empty
            mapping.append((max(ids, default=-1) + 1, self._restklasse))
        # store mapping
        self._restidx = next(idx for idx, name in mapping if name == self._restklasse)
        self._id2name = dict(mapping)

        # discard the restklasse from the reference points
        keep = self._targets != self._restidx
        self._embeds = self._embeds[keep]
        self._targets = self._targets[keep]

    @property
    def principals(self) -> typing.Iterator[bsfs.schema.Predicate]:
        """Return the principal predicates, i.e., relations from/to the extraction subject."""
        yield from super().principals
        yield self.schema.predicate(bsf.depicts)

    def __repr__(self) -> str:
        return f'{bsfs.typename(self)}(N={len(self._embeds)}, restklasse={self._restklasse})'

    def __eq__(self, other: typing.Any) -> bool:
        """Return True if *other* has the same references, mapping, and parameters."""
        return (
            super().__eq__(other)
            and self._thres == other._thres
            and self._id2name == other._id2name
            and torch.equal(self._embeds, other._embeds)
            and torch.equal(self._targets, other._targets)
            and self._restklasse == other._restklasse
            and self._restidx == other._restidx
            )

    def __hash__(self) -> int:
        """Hash over the same members that `__eq__` compares."""
        return hash((
            super().__hash__(),
            tuple(sorted(self._id2name.items())),
            self._thres,
            tuple(self._embeds.detach().cpu().numpy().reshape(-1).tolist()),
            tuple(self._targets.detach().cpu().numpy().reshape(-1).tolist()),
            self._restklasse,
            self._restidx,
            ))

    def _classify(self, emb: torch.Tensor) -> typing.List[int]:
        """Return the label index of each row in *emb*.

        A row receives the label of its nearest reference embedding if that
        distance is below the threshold, and the restklasse index otherwise.
        The original comment gives the shapes as [Nx512] -> [N].
        """
        # without reference points there is no nearest neighbour to look up
        if len(self._embeds) == 0:
            return [self._restidx] * len(emb)
        # nearest neighbour approach
        dist = torch.cdist(emb, self._embeds) # pairwise distances
        mindist, best = dist.min(dim=1) # lowest distance and its idx, per row
        labels = self._targets[best].tolist() # label (int) of nearest neighbour
        accepted = (mindist < self._thres).tolist() # check if distance is below threshold
        return [lbl if ok else self._restidx for ok, lbl in zip(accepted, labels)]

    def extract(
            self,
            subject: nodes.Entity,
            content: typing.Any,
            principals: typing.Iterable[bsfs.schema.Predicate],
            ) -> typing.Iterator[typing.Tuple[nodes.Node, bsfs.schema.Predicate, typing.Any]]:
        """Yield triples that link *subject* to its identified faces and persons.

        Args:
            subject: The entity the faces were found in.
            content: Sequence of faces; each item is indexed by 'embedding'
                and 'ucid'.
            principals: Predicates requested by the caller. Nothing is emitted
                unless the `bse:face` predicate is among them.

        Yields:
            For each face not assigned to the restklasse: a
            (face, depicts, person) triple followed by a
            (subject, bse:face, face) triple.
        """
        # check principals
        # NOTE(review): `principals` advertises bsf.depicts, yet the check is
        # on bse:face. Confirm which predicate should gate the extraction.
        #if self.schema.predicate(bsf.depicts) not in principals:
        face_pred = self.schema.predicate(ns.bse.face)
        if face_pred not in principals:
            # nothing to do; abort
            return
        # check content
        if len(content) == 0:
            return

        depicts_pred = self.schema.predicate(bsf.depicts)
        # collect embeddings
        emb = torch.vstack([face['embedding'] for face in content]).to(self._device)
        # apply classifier
        labels = self._classify(emb)
        # walk through faces
        for face, idx in zip(content, labels):
            lbl = bsfs.URI(self._id2name[idx]) # label (uri) of nearest neighbour
            if lbl == self._restklasse: # suppress
                continue
            pnode = nodes.Person(uri=lbl)
            fnode = nodes.Face(ucid=face['ucid'])
            # emit triple
            yield fnode, depicts_pred, pnode
            # FIXME: emit subject -> face -> fnode?
            yield subject, face_pred, fnode
            # FIXME: emit embedding?
            #yield fnode, bsf.embedding, face['embedding']

## EOF ##