aboutsummaryrefslogtreecommitdiffstats
path: root/test/extractor/image/face/test_detect.py
blob: 92375a284b0633d98d908beb402730480e1ad7e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62

# standard imports
import contextlib
import io
import os
import requests
import unittest

# bsie imports
from bsie.extractor import base
from bsie.reader.face import FaceExtract
from bsie.utils import bsfs, node as _node, ns

# objects to test
from bsie.extractor.image.face.detect import FaceDetect, bsf


## code ##

class TestFaceDetect(unittest.TestCase):
    """Integration test for the FaceDetect extractor on a downloaded sample image."""

    # Local cache location of the test image; reused across runs once downloaded.
    IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'testface1.jpg')
    # Remote source of the test image.
    IMAGE_URL = 'https://www.bsfs.io/testdata/iepahGee1uch5ahr3ic1.jpg'
    # Seconds to wait for the image server before failing instead of hanging forever.
    DOWNLOAD_TIMEOUT = 30

    def setUp(self):
        """Download the test image unless a cached copy already exists.

        The HTTP request is completed and checked *before* the target file is
        opened. A network error or an HTTP error status therefore raises without
        leaving an empty or corrupt file behind, which the existence check would
        otherwise accept as a valid cache on every later run.
        """
        if os.path.exists(self.IMAGE_PATH):
            return
        ans = requests.get(self.IMAGE_URL, timeout=self.DOWNLOAD_TIMEOUT)
        ans.raise_for_status()
        with open(self.IMAGE_PATH, 'wb') as ofile:
            ofile.write(ans.content)

    def _single_value(self, triples, subject, predicate):
        """Return the object of the unique ``(subject, predicate, *)`` triple.

        Fails the test if there is not exactly one such triple. Extracting the
        value (instead of testing membership) yields a readable failure message
        showing the actual value.
        """
        values = [o for s, p, o in triples if s == subject and p == predicate]
        self.assertEqual(len(values), 1, f'expected exactly one value for {predicate}')
        return values[0]

    def test_extract(self):
        """FaceDetect emits the face node, its bounding box and one embedding."""
        with contextlib.redirect_stderr(io.StringIO()): # NOTE: hide warnings from facenet_pytorch
            # setup
            rdr = FaceExtract()
            ext = FaceDetect()
            subject = _node.Node(ns.bsfs.Entity)
            content = rdr(self.IMAGE_PATH)
            principals = set(ext.principals)
            face = _node.Node(ns.bsn.Face, ucid='2a7203c1515e0caa66a7461452c0b4552f1433a613cb3033e59ed2361790ad45')
            triples = list(ext.extract(subject, content, principals))
            # principals is bse:face
            self.assertSetEqual(principals, {ext.schema.predicate(ns.bse.face)})
            # check triples
            self.assertIn((subject, ns.bse.face, face), triples)
            # bounding box: compared approximately so last-bit float drift
            # across platforms/library versions does not fail the test
            self.assertAlmostEqual(self._single_value(triples, face, bsf.x), 575.4721153898192)
            self.assertAlmostEqual(self._single_value(triples, face, bsf.y), 265.3955625)
            self.assertAlmostEqual(self._single_value(triples, face, bsf.width), 626.3928904791771)
            self.assertAlmostEqual(self._single_value(triples, face, bsf.height), 858.6870625)
            # check embedding (exactly one, with a known checksum)
            emb = self._single_value(triples, face, bsf.embedding)
            self.assertAlmostEqual(emb.sum(), -1.9049968)
            # no triples on principal mismatch
            self.assertListEqual(list(ext.extract(subject, content, set())), [])
            # no triples on no content
            self.assertListEqual(list(ext.extract(subject, [], principals)), [])


## main ##

if __name__ == '__main__':
    # Script entry point: discover and run this module's TestCases via unittest's CLI runner.
    unittest.main(module='__main__')

## EOF ##