aboutsummaryrefslogtreecommitdiffstats
path: root/test/extractor/image/face/test_identify.py
blob: 2d52353f36c0e37953cc3a31f4c396cee2849626 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151

# standard imports
import contextlib
import io
import os
import unittest

# external imports
import requests

# bsie imports
from bsie.extractor import base
from bsie.matcher import nodes
from bsie.reader.face import FaceExtract
from bsie.utils import bsfs, ns

# objects to test
from bsie.extractor.image.face.identify import FaceIdentify, bsf


## code ##

def fetch(source, target, timeout=60):
    """Download *source* into this module's directory as *target*, unless it already exists.

    The HTTP response is fetched and validated before anything is written.
    The payload goes to a temporary file that is atomically renamed into place.
    A failed or interrupted download therefore never leaves behind an empty or
    partial file. Such a file would make every later run skip the download and
    test against corrupt fixtures.

    Args:
        source: URL to download.
        target: File name, resolved relative to the directory of this test module.
        timeout: Seconds to wait for the server before giving up (passed to requests).

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    target = os.path.join(os.path.dirname(__file__), target)
    if os.path.exists(target):
        return
    ans = requests.get(source, timeout=timeout)
    # Refuse to cache error pages (404/500 bodies) as test fixtures.
    ans.raise_for_status()
    tmp = target + '.part'
    try:
        with open(tmp, 'wb') as ofile:
            ofile.write(ans.content)
        # Atomic rename: the target either exists completely or not at all.
        os.replace(tmp, target)
    finally:
        # Only reached with an existing tmp file if writing or renaming failed.
        if os.path.exists(tmp):
            os.remove(tmp)

class TestFaceIdentify(unittest.TestCase):
    """Tests for the FaceIdentify extractor.

    Test images, reference embeddings and reference mappings are downloaded from
    bsfs.io into this module's directory once per test class. Files that already
    exist locally are reused (see ``fetch``).
    """

    # (source URL, local file name) pairs for every fixture used below.
    _FIXTURES = (
        # test images
        ('https://www.bsfs.io/testdata/iepahGee1uch5ahr3ic1.jpg', 'testface1.jpg'),
        ('https://www.bsfs.io/testdata/Woayiesae8eiL9aivoba.jpg', 'testface2.jpg'),
        ('https://www.bsfs.io/testdata/ATiagheiduth4So5ohxi.jpg', 'testface3.jpg'),
        # reference vectors
        ('https://www.bsfs.io/testdata/aetie3foo0faiDaiBahk.npy', 'ref_embeds.npy'),
        ('https://www.bsfs.io/testdata/uopoS8gei8Phiek3shei.npy', 'ref_embeds_alt1.npy'),
        ('https://www.bsfs.io/testdata/Otoo7ain6Ied2Iep2ein.npy', 'ref_embeds_alt2.npy'),
        ('https://www.bsfs.io/testdata/ie0keriChafahroeRo7i.npy', 'ref_embeds_extra.npy'),
        # reference mappings
        ('https://www.bsfs.io/testdata/phoophui3teeni4hieKu.csv', 'ref_mapping.csv'),
        ('https://www.bsfs.io/testdata/Quit4Wum8ael7Zeis4ei.csv', 'ref_mapping_alt.csv'),
        ('https://www.bsfs.io/testdata/Angu5cioVei5pohgh0aa.csv', 'ref_mapping_id_reuse.csv'),
        ('https://www.bsfs.io/testdata/ooshooK1bai5Queengae.csv', 'ref_mapping_name_reuse.csv'),
        ('https://www.bsfs.io/testdata/eixuepah3Ronge7oe4qu.csv', 'ref_mapping_restklasse.csv'),
    )

    @classmethod
    def setUpClass(cls):
        # Fixtures are read-only for all tests, so checking/downloading them once
        # per class is enough (the original re-checked them before every test).
        super().setUpClass()
        for source, target in cls._FIXTURES:
            fetch(source, target)

    @staticmethod
    def _path(name):
        """Return the absolute path of fixture *name* located next to this test module."""
        return os.path.join(os.path.dirname(__file__), name)

    def test_essentials(self):
        # setup
        pth_embeds = self._path('ref_embeds.npy')
        pth_embeds_alt1 = self._path('ref_embeds_alt1.npy')
        pth_embeds_alt2 = self._path('ref_embeds_alt2.npy')
        pth_mapping = self._path('ref_mapping.csv')
        pth_mapping_alt = self._path('ref_mapping_alt.csv')
        restklasse = 'https://example.com/user/fake_anon'
        ext = FaceIdentify(pth_embeds, pth_mapping)
        # string conversion returns class name
        self.assertEqual(str(ext), 'FaceIdentify')
        # representation respects number of embeddings (and shows the default restklasse)
        self.assertEqual(repr(ext), 'FaceIdentify(N=2, restklasse=https://example.com/user/anon)')
        # representation respects restklasse
        self.assertEqual(repr(FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse)),
            'FaceIdentify(N=2, restklasse=https://example.com/user/fake_anon)')
        # identity
        same = FaceIdentify(pth_embeds, pth_mapping)
        self.assertEqual(ext, same)
        # NOTE(review): this check was marked "FIXME!" upstream without explanation;
        # confirm whether hash stability across separately built instances is a known issue.
        self.assertEqual(hash(ext), hash(same))
        # comparison (and hash) respects embeddings, mappings, threshold, and restklasse.
        # Each variant is constructed once and checked for both == and hash.
        variants = {
            'embeddings (alt1)': FaceIdentify(pth_embeds_alt1, pth_mapping),
            'embeddings (alt2)': FaceIdentify(pth_embeds_alt2, pth_mapping),
            'mapping': FaceIdentify(pth_embeds, pth_mapping_alt),
            'threshold': FaceIdentify(pth_embeds, pth_mapping, thres=0.1),
            'restklasse': FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse),
        }
        for aspect, other in variants.items():
            with self.subTest(aspect=aspect):
                self.assertNotEqual(ext, other)
                self.assertNotEqual(hash(ext), hash(other))

    def test_construct(self):
        pth_embeds = self._path('ref_embeds.npy')
        pth_mapping = self._path('ref_mapping.csv')
        # valid construction
        self.assertIsInstance(FaceIdentify(pth_embeds, pth_mapping), FaceIdentify)
        # restklasse may be part of the mapping
        ext = FaceIdentify(pth_embeds, self._path('ref_mapping_restklasse.csv'))
        self.assertIsInstance(ext, FaceIdentify)
        # Inspects a private attribute: the restklasse entry is expected at index 1
        # (presumably its row position in ref_mapping_restklasse.csv -- confirm).
        self.assertEqual(ext._restidx, 1)
        # NOTE: the exact exception types raised for invalid inputs are not pinned
        # here; Exception is used deliberately until FaceIdentify documents them.
        # pass invalid mapping (name re-use)
        self.assertRaises(Exception, FaceIdentify, pth_embeds,
            self._path('ref_mapping_name_reuse.csv'))
        # pass invalid mapping (id re-use)
        self.assertRaises(Exception, FaceIdentify, pth_embeds,
            self._path('ref_mapping_id_reuse.csv'))
        # pass invalid embeds (extra embeddings)
        self.assertRaises(Exception, FaceIdentify,
            self._path('ref_embeds_extra.npy'),
            pth_mapping)

    def test_extract(self):
        with contextlib.redirect_stderr(io.StringIO()): # NOTE: hide warnings from facenet_pytorch
            # setup
            rdr = FaceExtract()
            ext = FaceIdentify(self._path('ref_embeds.npy'), self._path('ref_mapping.csv'))
            subject = nodes.Entity(ucid='abc123')
            principals = set(ext.principals)
            face_pred = ext.schema.predicate(ns.bse.face)
            depicts_pred = ext.schema.predicate(bsf.depicts)
            # testface1: one face that matches a known person
            content = rdr(self._path('testface1.jpg'))
            face = nodes.Face(
                ucid='2a7203c1515e0caa66a7461452c0b4552f1433a613cb3033e59ed2361790ad45')
            person = nodes.Person(uri='https://example.com/user/Angelina_Jolie')
            triples = list(ext.extract(subject, content, principals))
            # principals are bse:face and bsf:depicts
            self.assertSetEqual(principals, {face_pred, depicts_pred})
            # produces two triples ...
            self.assertEqual(len(triples), 2)
            # ... one linking the subject to the face (emitted once a person was identified) ...
            self.assertIn((subject, face_pred, face), triples)
            # ... and one per identified person, linking the face to that person
            self.assertIn((face, depicts_pred, person), triples)
            # testface2: produces no triples if no person was identified
            content = rdr(self._path('testface2.jpg'))
            self.assertListEqual(list(ext.extract(subject, content, principals)), [])
            # testface3: identifies the correct person despite somewhat similar options
            content = rdr(self._path('testface3.jpg'))
            face = nodes.Face(
                ucid='f61fac01ef686ee05805afef1e7a10ba54c30dc1aa095d9e77d79ccdfeb40dc5')
            person = nodes.Person(uri='https://example.com/user/Paul_Rudd')
            triples = list(ext.extract(subject, content, principals))
            self.assertEqual(len(triples), 2)
            self.assertIn((subject, face_pred, face), triples)
            self.assertIn((face, depicts_pred, person), triples)
            # no triples on principal mismatch
            self.assertListEqual(list(ext.extract(subject, content, set())), [])
            # no triples on no content
            self.assertListEqual(list(ext.extract(subject, [], principals)), [])


## main ##

if __name__ == '__main__':
    # Allow running this test module directly (python test_identify.py);
    # module='__main__' is unittest.main's default, spelled out for clarity.
    unittest.main(module='__main__')

## EOF ##