path: root/bsie/extractor/image/face/identify.py

# standard imports
import csv
import typing

# external imports
import numpy as np
import torch

# bsie imports
from bsie.matcher import nodes
from bsie.utils import bsfs, ns

# inner-module imports
from ... import base

# exports
__all__: typing.Sequence[str] = (
    'FaceIdentify',
    )


## code ##

bsf = ns.bsn.Face()

class FaceIdentify(base.Extractor):
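    """Identify the persons depicted by extracted faces.

    Query face embeddings are matched against a set of reference embeddings
    via nearest-neighbour search; faces that match no reference point closely
    enough are assigned to the *restklasse* and suppressed from the output.
    """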

    CONTENT_READER = 'bsie.reader.face.FaceExtract'

    _restklasse: bsfs.URI
    _thres: float
    _device: torch.device
    _restidx: int
    _id2name: typing.Dict[int, str]
    _embeds: torch.Tensor
    _targets: torch.Tensor

    # FIXME: This could be a bsfs maintenance function instead of a bsie function

    def __init__(
            self,
            # FIXME: Initialize from bsfs storage instead of files
            ref_embeds: str,
            ref_mapping: str,
            thres: float = 0.9,
            cuda_device: str = 'cuda:0',
            restklasse: str = 'https://example.com/user/anon',
            ):
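        """Load the reference embeddings and the person mapping.

        *ref_embeds* is the path to a numpy array with one row per reference
        face: the person index followed by the embedding vector. *ref_mapping*
        is the path to a CSV file with one ``<person uri>,<index>`` row per
        person. Faces whose nearest reference point is farther away than
        *thres* are assigned to *restklasse* and suppressed from the output.
        Computation runs on *cuda_device* if CUDA is available, on the CPU
        otherwise.
        """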
        # initialize parent with the schema
        super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + f'''
            bsn:Face rdfs:subClassOf bsfs:Node .
            bsn:Person rdfs:subClassOf bsfs:Node .
            <https://schema.bsfs.io/ie/Node/Face#depicts> rdfs:subClassOf bsfs:Predicate ;
                rdfs:domain bsn:Face ;
                rdfs:range bsn:Person .
            # FIXME: Entity -> Face?
            bse:face rdfs:subClassOf bsfs:Predicate ;
                rdfs:domain bsn:Entity ;
                rdfs:range bsn:Face .
            # FIXME: Face -> Embedding?
            #<https://schema.bsfs.io/ie/Node/Face#embedding>
            #    rdfs:subClassOf bsfs:Predicate ;
            #    rdfs:domain bsn:Face ;
            #    rdfs:range <https://schema.bsfs.io/ie/Literal/Array/Feature/Face#resnet512> ;
            #    bsfs:unique "true"^^xsd:boolean .
            #<https://schema.bsfs.io/ie/Literal/Array/Feature/Face#resnet512>
            #    rdfs:subClassOf bsa:Feature ;
            #    bsfs:distance <https://schema.bsfs.io/core/distance#euclidean> ;
            #    bsfs:dtype <https://schema.bsfs.io/core/dtype#f32>;
            #    bsfs:dimension "512"^^xsd:integer .

            '''))
        # store extra members
        self._restklasse = bsfs.URI(restklasse)
        self._thres = thres
        # get face instances
        self._device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu')
        with open(ref_embeds, 'rb') as ifile:
            emb_with_trg = np.load(ifile)
            targets, embeds = emb_with_trg[:, 0], emb_with_trg[:, 1:]
            self._targets = torch.tensor(targets, dtype=torch.int32).to(self._device)
            self._embeds = torch.tensor(embeds).to(self._device)
        with open(ref_mapping, 'rt') as ifile:
            mapping = [(int(idx), name) for name, idx in csv.reader(ifile)]
        # ensure that the mapping is unique
        ids, names = zip(*mapping)
        if len(set(names)) != len(names):
            raise ValueError('people identifiers must be unique')
        if len(set(ids)) != len(ids):
            raise ValueError('people indices must be unique')
        # ensure that all targets are accounted for
        if not {int(i) for i in self._targets.tolist()}.issubset(set(ids)):
            raise ValueError('all targets must be labelled')
        # ensure and fetch the index of the restklasse
        if self._restklasse not in names:
            mapping.append((max(ids) + 1, self._restklasse))
        # store mapping
        self._restidx = [idx for idx, name in mapping if name == self._restklasse][0]
        self._id2name = dict(mapping)
        # discard the restklasse from the reference points
        self._embeds = self._embeds[self._targets != self._restidx]
        self._targets = self._targets[self._targets != self._restidx]

    @property
    def principals(self) -> typing.Iterator[bsfs.schema.Predicate]:
        """Return the principal predicates, i.e., relations from/to the extraction subject."""
        yield from super().principals
        yield self.schema.predicate(bsf.depicts)

    def __repr__(self) -> str:
        return f'{bsfs.typename(self)}(N={len(self._embeds)}, restklasse={self._restklasse})'

    def __eq__(self, other: typing.Any) -> bool:
        return super().__eq__(other) \
           and self._thres == other._thres \
           and self._id2name == other._id2name \
           and torch.equal(self._embeds, other._embeds) \
           and torch.equal(self._targets, other._targets) \
           and self._restklasse == other._restklasse \
           and self._restidx == other._restidx

    def __hash__(self) -> int:
        return hash((super().__hash__(),
            tuple(sorted(self._id2name.items())),
            self._thres,
            tuple(self._embeds.detach().cpu().numpy().reshape(-1).tolist()),
            tuple(self._targets.detach().cpu().numpy().reshape(-1).tolist()),
            self._restklasse,
            self._restidx,
            ))

    def _classify(self, emb: torch.Tensor) -> typing.List[int]: # [Nx512] -> [N]
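        """Return, for each embedding, the person index of its nearest
        reference embedding, or ``self._restidx`` if that neighbour is
        farther away than ``self._thres``.
        """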
        # nearest neighbour approach
        dist = torch.cdist(emb, self._embeds) # pairwise distances
        best = dist.argmin(dim=1) # idx of lowest distance, per row
        labels = self._targets[best] # label (int) of nearest neighbour
        accept = dist[torch.arange(len(best)), best] < self._thres # accept only if the distance is below the threshold
        return [lbl.item() if ok else self._restidx for ok, lbl in zip(accept, labels)]

    def extract(
            self,
            subject: nodes.Entity,
            content: typing.Any,
            principals: typing.Iterable[bsfs.schema.Predicate],
            ) -> typing.Iterator[typing.Tuple[nodes.Node, bsfs.schema.Predicate, typing.Any]]:
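        """Yield (node, predicate, value) triples that link the subject to its
        face nodes and each face node to the person it depicts. Faces assigned
        to the restklasse are skipped.
        """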
        # check principals
        #if self.schema.predicate(bsf.depicts) not in principals:
        if self.schema.predicate(ns.bse.face) not in principals:
            # nothing to do; abort
            return
        # check content
        if len(content) == 0:
            return

        # collect embeddings
        emb = torch.vstack([face['embedding'] for face in content]).to(self._device)
        # apply classifier
        labels = self._classify(emb)
        # walk through faces
        for face, idx in zip(content, labels):
            lbl = bsfs.URI(self._id2name[idx]) # label (uri) of nearest neighbour
            if lbl == self._restklasse: # suppress
                continue
            pnode = nodes.Person(uri=lbl)
            fnode = nodes.Face(ucid=face['ucid'])
            # emit triple
            yield fnode, self.schema.predicate(bsf.depicts), pnode
            # FIXME: emit subject -> face -> fnode?
            yield subject, self.schema.predicate(ns.bse.face), fnode
            # FIXME: emit embedding?
            #yield fnode, bsf.embedding, face['embedding']
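
# Usage sketch (hypothetical file names; *subject*, *faces*, and *principals*
# are provided by the extraction pipeline, with one dict per face carrying
# 'ucid' and 'embedding' keys):
#
#   identify = FaceIdentify(
#       ref_embeds='reference_embeddings.npy',  # rows: [person index, embedding ...]
#       ref_mapping='people.csv',               # rows: <person uri>,<person index>
#       thres=0.9,
#       )
#   for node, predicate, value in identify.extract(subject, faces, principals):
#       ...  # e.g. hand the triples over to storage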

## EOF ##