aboutsummaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
authorMatthias Baumgartner <dev@igsor.net>2023-07-25 18:48:21 +0200
committerMatthias Baumgartner <dev@igsor.net>2023-07-25 18:48:21 +0200
commitbb8130b093e51474a7ce6f6431c7f9a02c4f930b (patch)
tree6fb1d3811686ba4520946a48020ff5cf984db946 /test
parentbe6027859c815e18b08a49ca1a45df3fc0aac301 (diff)
parent5d0ff7b2e0d1c63d9551e44ed3ffd96c695b69d9 (diff)
downloadbsie-bb8130b093e51474a7ce6f6431c7f9a02c4f930b.tar.gz
bsie-bb8130b093e51474a7ce6f6431c7f9a02c4f930b.tar.bz2
bsie-bb8130b093e51474a7ce6f6431c7f9a02c4f930b.zip
Merge branch 'mb/faces' into develop
Diffstat (limited to 'test')
-rw-r--r--test/extractor/image/face/__init__.py0
-rw-r--r--test/extractor/image/face/test_detect.py62
-rw-r--r--test/extractor/image/face/test_identify.py148
-rw-r--r--test/lib/test_naming_policy.py16
-rw-r--r--test/reader/image/load_nef.py2
-rw-r--r--test/reader/preview/load_nef.py2
-rw-r--r--test/reader/test_face.py227
7 files changed, 455 insertions, 2 deletions
diff --git a/test/extractor/image/face/__init__.py b/test/extractor/image/face/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/extractor/image/face/__init__.py
diff --git a/test/extractor/image/face/test_detect.py b/test/extractor/image/face/test_detect.py
new file mode 100644
index 0000000..92375a2
--- /dev/null
+++ b/test/extractor/image/face/test_detect.py
@@ -0,0 +1,62 @@
+
+# standard imports
+import contextlib
+import io
+import os
+import requests
+import unittest
+
+# bsie imports
+from bsie.extractor import base
+from bsie.reader.face import FaceExtract
+from bsie.utils import bsfs, node as _node, ns
+
+# objects to test
+from bsie.extractor.image.face.detect import FaceDetect, bsf
+
+
+## code ##
+
+class TestFaceDetect(unittest.TestCase):
+ def setUp(self):
+ # download test image
+ target = os.path.join(os.path.dirname(__file__), 'testface1.jpg')
+ if not os.path.exists(target):
+ with open(target, 'wb') as ofile:
+ ans = requests.get('https://www.bsfs.io/testdata/iepahGee1uch5ahr3ic1.jpg')
+ ofile.write(ans.content)
+
+ def test_extract(self):
+ with contextlib.redirect_stderr(io.StringIO()): # NOTE: hide warnings from facenet_pytorch
+ # setup
+ rdr = FaceExtract()
+ ext = FaceDetect()
+ subject = _node.Node(ns.bsfs.Entity)
+ content = rdr(os.path.join(os.path.dirname(__file__), 'testface1.jpg'))
+ principals = set(ext.principals)
+ face = _node.Node(ns.bsn.Face, ucid='2a7203c1515e0caa66a7461452c0b4552f1433a613cb3033e59ed2361790ad45')
+ triples = list(ext.extract(subject, content, principals))
+ # principals is bse:face
+ self.assertSetEqual(principals, {ext.schema.predicate(ns.bse.face)})
+ # check triples
+ self.assertIn((subject, ns.bse.face, face), triples)
+ self.assertIn((face, bsf.x, 575.4721153898192), triples)
+ self.assertIn((face, bsf.y, 265.3955625), triples)
+ self.assertIn((face, bsf.width, 626.3928904791771), triples)
+ self.assertIn((face, bsf.height,858.6870625), triples)
+ # check embedding
+ emb = [o for s, p, o in triples if s == face and p == bsf.embedding]
+ self.assertEqual(len(emb), 1)
+ self.assertAlmostEqual(emb[0].sum(), -1.9049968)
+ # no triples on principal mismatch
+ self.assertListEqual(list(ext.extract(subject, content, set())), [])
+ # no triples on no content
+ self.assertListEqual(list(ext.extract(subject, [], principals)), [])
+
+
+## main ##
+
+if __name__ == '__main__':
+ unittest.main()
+
+## EOF ##
diff --git a/test/extractor/image/face/test_identify.py b/test/extractor/image/face/test_identify.py
new file mode 100644
index 0000000..dde41db
--- /dev/null
+++ b/test/extractor/image/face/test_identify.py
@@ -0,0 +1,148 @@
+
+# standard imports
+import contextlib
+import io
+import os
+import unittest
+
+# external imports
+import requests
+
+# bsie imports
+from bsie.extractor import base
+from bsie.reader.face import FaceExtract
+from bsie.utils import bsfs, node as _node, ns
+
+# objects to test
+from bsie.extractor.image.face.identify import FaceIdentify, bsf
+
+
+## code ##
+
+def fetch(source, target):
+ target = os.path.join(os.path.dirname(__file__), target)
+ if not os.path.exists(target):
+ with open(target, 'wb') as ofile:
+ ans = requests.get(source)
+ ofile.write(ans.content)
+
+class TestFaceIdentify(unittest.TestCase):
+ def setUp(self):
+ # download test images
+ fetch('https://www.bsfs.io/testdata/iepahGee1uch5ahr3ic1.jpg', 'testface1.jpg')
+ fetch('https://www.bsfs.io/testdata/Woayiesae8eiL9aivoba.jpg', 'testface2.jpg')
+ fetch('https://www.bsfs.io/testdata/ATiagheiduth4So5ohxi.jpg', 'testface3.jpg')
+ # download reference vectors
+ fetch('https://www.bsfs.io/testdata/aetie3foo0faiDaiBahk.npy', 'ref_embeds.npy')
+ fetch('https://www.bsfs.io/testdata/uopoS8gei8Phiek3shei.npy', 'ref_embeds_alt1.npy')
+ fetch('https://www.bsfs.io/testdata/Otoo7ain6Ied2Iep2ein.npy', 'ref_embeds_alt2.npy')
+ fetch('https://www.bsfs.io/testdata/ie0keriChafahroeRo7i.npy', 'ref_embeds_extra.npy')
+ fetch('https://www.bsfs.io/testdata/phoophui3teeni4hieKu.csv', 'ref_mapping.csv')
+ fetch('https://www.bsfs.io/testdata/Quit4Wum8ael7Zeis4ei.csv', 'ref_mapping_alt.csv')
+ fetch('https://www.bsfs.io/testdata/Angu5cioVei5pohgh0aa.csv', 'ref_mapping_id_reuse.csv')
+ fetch('https://www.bsfs.io/testdata/ooshooK1bai5Queengae.csv', 'ref_mapping_name_reuse.csv')
+ fetch('https://www.bsfs.io/testdata/eixuepah3Ronge7oe4qu.csv', 'ref_mapping_restklasse.csv')
+
+ def test_essentials(self):
+ # setup
+ pth_embeds = os.path.join(os.path.dirname(__file__), 'ref_embeds.npy')
+ pth_embeds_alt1 = os.path.join(os.path.dirname(__file__), 'ref_embeds_alt1.npy')
+ pth_embeds_alt2 = os.path.join(os.path.dirname(__file__), 'ref_embeds_alt2.npy')
+ pth_mapping = os.path.join(os.path.dirname(__file__), 'ref_mapping.csv')
+ pth_mapping_alt = os.path.join(os.path.dirname(__file__), 'ref_mapping_alt.csv')
+ restklasse = 'https://example.com/user/fake_anon'
+ ext = FaceIdentify(pth_embeds, pth_mapping)
+ # string conversion returns class name
+ self.assertEqual(str(ext), 'FaceIdentify')
+ # representation respects number of embeddings
+ self.assertEqual(repr(ext), 'FaceIdentify(N=2, restklasse=https://example.com/user/anon)')
+ # representation respects restklasse
+ self.assertEqual(repr(FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse)),
+ 'FaceIdentify(N=2, restklasse=https://example.com/user/fake_anon)')
+ # identity
+ self.assertEqual(ext, FaceIdentify(pth_embeds, pth_mapping))
+ self.assertEqual(hash(ext), hash(FaceIdentify(pth_embeds, pth_mapping))) # FIXME!
+ # comparison respects embeddings
+ self.assertNotEqual(ext, FaceIdentify(pth_embeds_alt1, pth_mapping))
+ self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds_alt1, pth_mapping)))
+ self.assertNotEqual(ext, FaceIdentify(pth_embeds_alt2, pth_mapping))
+ self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds_alt2, pth_mapping)))
+ # comparison respects mappings
+ self.assertNotEqual(ext, FaceIdentify(pth_embeds, pth_mapping_alt))
+ self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds, pth_mapping_alt)))
+ # comparison respects threshold
+ self.assertNotEqual(ext, FaceIdentify(pth_embeds, pth_mapping, thres=0.1))
+ self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds, pth_mapping, thres=0.1)))
+ # comparison respects restklasse
+ self.assertNotEqual(ext, FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse))
+ self.assertNotEqual(hash(ext),
+ hash(FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse)))
+
+ def test_construct(self):
+ pth_embeds = os.path.join(os.path.dirname(__file__), 'ref_embeds.npy')
+ pth_mapping = os.path.join(os.path.dirname(__file__), 'ref_mapping.csv')
+ # valid construction
+ self.assertIsInstance(FaceIdentify(pth_embeds, pth_mapping), FaceIdentify)
+ # restklasse may be part of the mapping
+ ext = FaceIdentify(pth_embeds, os.path.join(os.path.dirname(__file__), 'ref_mapping_restklasse.csv'))
+ self.assertIsInstance(ext, FaceIdentify)
+ self.assertEqual(ext._restidx, 1)
+ # pass invalid mapping (name re-use)
+ self.assertRaises(Exception, FaceIdentify, pth_embeds,
+ os.path.join(os.path.dirname(__file__), 'ref_mapping_name_reuse.csv'))
+ # pass invalid mapping (id re-use)
+ self.assertRaises(Exception, FaceIdentify, pth_embeds,
+ os.path.join(os.path.dirname(__file__), 'ref_mapping_id_reuse.csv'))
+ # pass invalid embeds (extra embeddings)
+ self.assertRaises(Exception, FaceIdentify,
+ os.path.join(os.path.dirname(__file__), 'ref_embeds_extra.npy'),
+ pth_mapping)
+
+ def test_extract(self):
+ with contextlib.redirect_stderr(io.StringIO()): # NOTE: hide warnings from facenet_pytorch
+ # setup
+ rdr = FaceExtract()
+ ext = FaceIdentify(
+ os.path.join(os.path.dirname(__file__), 'ref_embeds.npy'),
+ os.path.join(os.path.dirname(__file__), 'ref_mapping.csv'),
+ )
+ subject = _node.Node(ns.bsfs.Entity)
+ content = rdr(os.path.join(os.path.dirname(__file__), 'testface1.jpg'))
+ principals = set(ext.principals)
+ face = _node.Node(ns.bsn.Face, ucid='2a7203c1515e0caa66a7461452c0b4552f1433a613cb3033e59ed2361790ad45')
+ person = _node.Node(ns.bsn.Person, uri='https://example.com/user/Angelina_Jolie')
+ triples = list(ext.extract(subject, content, principals))
+ # principls is bse:face, bsf:depicts
+ self.assertSetEqual(set(ext.principals), {
+ ext.schema.predicate(ns.bse.face),
+ ext.schema.predicate(bsf.depicts)
+ })
+ # produces two triples ...
+ self.assertEqual(len(triples), 2)
+ # ... one if at least one person was identified
+ self.assertIn((subject, ext.schema.predicate(ns.bse.face), face), triples)
+ # ... one for each identified person
+ self.assertIn((face, ext.schema.predicate(bsf.depicts), person), triples)
+ # produces no triples if no person was identified
+ content = rdr(os.path.join(os.path.dirname(__file__), 'testface2.jpg'))
+ self.assertListEqual(list(ext.extract(subject, content, principals)), [])
+ # identifies the correct person despite somewhat similar options
+ content = rdr(os.path.join(os.path.dirname(__file__), 'testface3.jpg'))
+ face = _node.Node(ns.bsn.Face, ucid='f61fac01ef686ee05805afef1e7a10ba54c30dc1aa095d9e77d79ccdfeb40dc5')
+ triples = list(ext.extract(subject, content, principals))
+ self.assertEqual(len(triples), 2)
+ person = _node.Node(ns.bsn.Person, uri='https://example.com/user/Paul_Rudd')
+ self.assertIn((subject, ext.schema.predicate(ns.bse.face), face), triples)
+ self.assertIn((face, ext.schema.predicate(bsf.depicts), person), triples)
+ # no triples on principal mismatch
+ self.assertListEqual(list(ext.extract(subject, content, set())), [])
+ # no triples on no content
+ self.assertListEqual(list(ext.extract(subject, [], principals)), [])
+
+
+## main ##
+
+if __name__ == '__main__':
+ unittest.main()
+
+## EOF ##
diff --git a/test/lib/test_naming_policy.py b/test/lib/test_naming_policy.py
index 09fd6f6..a078fbd 100644
--- a/test/lib/test_naming_policy.py
+++ b/test/lib/test_naming_policy.py
@@ -35,6 +35,10 @@ class TestDefaultNamingPolicy(unittest.TestCase):
self.assertEqual(policy.handle_node(
Node(ns.bsn.Tag, label='hello')).uri,
URI('http://example.com/me/tag#hello'))
+ # processes bsn:Face
+ self.assertEqual(policy.handle_node(
+ Node(ns.bsn.Face, ucid='hello')).uri,
+ URI('http://example.com/me/face#hello'))
# raises an exception on unknown types
self.assertRaises(errors.ProgrammingError, policy.handle_node,
Node(ns.bsn.Invalid, ucid='abc123cba', size=123))
@@ -99,6 +103,18 @@ class TestDefaultNamingPolicy(unittest.TestCase):
self.assertTrue(policy.name_tag(
Node(ns.bsn.Tag,)).uri.startswith('http://example.com/me/tag#'))
+ def test_name_face(self):
+ # setup
+ policy = DefaultNamingPolicy('http://example.com', 'me')
+ # name_face uses ucid
+ self.assertEqual(policy.name_face(
+ Node(ns.bsn.Face, ucid='hello_world')).uri,
+ URI('http://example.com/me/face#hello_world'))
+ # name_face falls back to a random guid
+ self.assertTrue(policy.name_face(
+ Node(ns.bsn.Face)).uri.startswith('http://example.com/me/face#'))
+
+
class TestNamingPolicyIterator(unittest.TestCase):
def test_call(self): # NOTE: We test NamingPolicy.__call__ here
diff --git a/test/reader/image/load_nef.py b/test/reader/image/load_nef.py
index 02be470..ded9b6e 100644
--- a/test/reader/image/load_nef.py
+++ b/test/reader/image/load_nef.py
@@ -6,7 +6,7 @@ import os
import requests
# constants
-IMAGE_URL = 'http://igsor.net/eik7AhvohghaeN5.nef'
+IMAGE_URL = 'https://www.bsfs.io/testdata/eik7AhvohghaeN5.nef'
## code ##
diff --git a/test/reader/preview/load_nef.py b/test/reader/preview/load_nef.py
index 02be470..ded9b6e 100644
--- a/test/reader/preview/load_nef.py
+++ b/test/reader/preview/load_nef.py
@@ -6,7 +6,7 @@ import os
import requests
# constants
-IMAGE_URL = 'http://igsor.net/eik7AhvohghaeN5.nef'
+IMAGE_URL = 'https://www.bsfs.io/testdata/eik7AhvohghaeN5.nef'
## code ##
diff --git a/test/reader/test_face.py b/test/reader/test_face.py
new file mode 100644
index 0000000..6164143
--- /dev/null
+++ b/test/reader/test_face.py
@@ -0,0 +1,227 @@
+
+# standard imports
+import contextlib
+import io
+import os
+import unittest
+
+# external imports
+import requests
+import PIL.Image
+
+# bsie imports
+from bsie.utils import errors
+
+# objects to test
+from bsie.reader.face import FaceExtract
+
+
+## code ##
+
+def fetch(source, target):
+ target = os.path.join(os.path.dirname(__file__), target)
+ if not os.path.exists(target):
+ with open(target, 'wb') as ofile:
+ ans = requests.get(source)
+ ofile.write(ans.content)
+
+class TestFaceExtract(unittest.TestCase):
+ def setUp(self):
+ # download test image w/o face
+ fetch('https://www.bsfs.io/testdata/Quiejoore1ahxa9jahma.jpg', 'faces-noface.jpg')
+ # download test image w/ face
+ fetch('https://www.bsfs.io/testdata/ONekai7Ohphooch3aege.jpg', 'faces-ivan.jpg')
+ # download errounous images
+ fetch('https://www.bsfs.io/testdata/kie7wo7sheix7thieG2f.gif', 'faces-dimerr.gif')
+ fetch('https://www.bsfs.io/testdata/Mee1aunooneoSaexohTh.gif', 'faces-valueerr.gif')
+
+ def test_essentials(self):
+ # repr respects min_face_prob
+ self.assertEqual(repr(FaceExtract(min_face_prob=1.0)), 'FaceExtract(1.0)')
+ self.assertEqual(repr(FaceExtract(min_face_prob=0.5)), 'FaceExtract(0.5)')
+ # repr respects type
+ class Foo(FaceExtract): pass
+ self.assertEqual(repr(Foo(min_face_prob=0.5)), 'Foo(0.5)')
+
+ # comparison respects type
+ class Foo(): pass
+ self.assertNotEqual(FaceExtract(), 1234)
+ self.assertNotEqual(hash(FaceExtract()), hash(1234))
+ self.assertNotEqual(FaceExtract(), 'hello')
+ self.assertNotEqual(hash(FaceExtract()), hash('hello'))
+ self.assertNotEqual(FaceExtract(), Foo())
+ self.assertNotEqual(hash(FaceExtract()), hash(Foo()))
+ # comparison respects constructor arguments (except cuda_device)
+ self.assertEqual(FaceExtract(), FaceExtract())
+ self.assertEqual(hash(FaceExtract()), hash(FaceExtract()))
+ self.assertNotEqual(FaceExtract(), FaceExtract(target_size=10))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(target_size=10)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(min_face_size=10))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(min_face_size=10)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(min_face_prob=1.))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(min_face_prob=1.)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(ext_face_size=100))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(ext_face_size=100)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(thresholds=[0.1,0.1,0.1]))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(thresholds=[0.1,0.1,0.1])))
+ self.assertNotEqual(FaceExtract(), FaceExtract(factor=1.))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(factor=1.)))
+ # comparison ignores cuda_device
+ self.assertEqual(FaceExtract(), FaceExtract(cuda_device='cuda:123'))
+ self.assertEqual(hash(FaceExtract()), hash(FaceExtract(cuda_device='cuda:123')))
+
+ def test_preprocess(self):
+ testpath = os.path.join(os.path.dirname(__file__), 'faces-noface.jpg')
+ with PIL.Image.open(testpath) as img:
+ self.assertEqual(img.size, (199, 148))
+ # landscape, downscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, None)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (10*1.99, 10*2.0))
+ # landscape, upscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, None)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (10*0.5, 10*0.5))
+ # landscape, downscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, 90)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (10.0*1.99, 64*2.0))
+ # landscape, upscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, 90)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (10*0.5, 286*0.5))
+ # landscape, downscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, 270)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (90*1.99, 10*2.0))
+ # landscape, upscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, 270)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (388*0.5, 10*0.5))
+ # landscape, downscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, 180)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (90*1.99, 64*2.0))
+ # landscape, upscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, 180)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (388*0.5, 286*0.5))
+ # portrait, downscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, None)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (10*2.0, 10*1.99))
+ # portrait, upscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, None)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (10*0.5, 10*0.5))
+ # portrait, downscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, 90)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (10.0*2.0, 90*1.99))
+ # portrait, upscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, 90)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (10*0.5, 388*0.5))
+ # portrait, downscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, 270)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (64*2.0, 10*1.99))
+ # portrait, upscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, 270)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (286*0.5, 10*0.5))
+ # portrait, downscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, 180)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (64*2.0, 90*1.99))
+ # portrait, upscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, 180)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (286*0.5, 388*0.5))
+
+ # square image
+ testpath = os.path.join(os.path.dirname(__file__), 'faces-ivan.jpg')
+ with PIL.Image.open(testpath) as img:
+ self.assertEqual(img.size, (561, 561))
+ # square, downscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, None)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (10*11, 10*11))
+ # square, upscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, None)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (10*0.5, 10*0.5))
+ # square, downscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, 90)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (10.0*11, 41*11))
+ # square, upscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, 90)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (10*0.5, 1112*0.5))
+ # square, downscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, 270)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (41*11, 10*11))
+ # square, upscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, 270)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (1112*0.5, 10*0.5))
+ # square, downscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, 180)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (41*11, 41*11))
+ # square, upscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, 180)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (1112*0.5, 1112*0.5))
+
+ def test_call(self):
+ with contextlib.redirect_stderr(io.StringIO()): # NOTE: hide warnings from facenet_pytorch
+ rdr = FaceExtract()
+ # discards non-image files
+ self.assertRaises(errors.UnsupportedFileFormatError, rdr,
+ __file__)
+ # raises on invalid image
+ self.assertRaises(errors.UnsupportedFileFormatError, rdr,
+ os.path.join(os.path.dirname(__file__), 'testimage_exif_corrupted.jpg'))
+ # raises on missing file
+ self.assertRaises(errors.ReaderError, rdr,
+ os.path.join(os.path.dirname(__file__), 'invalid.jpg'))
+ # raises on dimensions error (pytorch RuntimeError)
+ self.assertRaises(errors.ReaderError, rdr,
+ os.path.join(os.path.dirname(__file__), 'faces-dimerr.gif'))
+ # raises on content error (ValueError)
+ self.assertRaises(errors.ReaderError, rdr,
+ os.path.join(os.path.dirname(__file__), 'faces-valueerr.gif'))
+
+ # may return empty list
+ self.assertListEqual(FaceExtract(min_face_prob=1)(
+ os.path.join(os.path.dirname(__file__), 'faces-noface.jpg')), [])
+ self.assertListEqual(FaceExtract(min_face_prob=1)(
+ os.path.join(os.path.dirname(__file__), 'faces-ivan.jpg')), [])
+ # returns faces
+ faces = rdr(os.path.join(os.path.dirname(__file__), 'faces-ivan.jpg'))
+ # check if face was detected
+ self.assertEqual(len(faces), 1)
+ # check ucid
+ self.assertSetEqual({f['ucid'] for f in faces}, {
+ '926dc1684dd453aa2c3c8daf1c82ecf918514ef0de416b6b842235c23bec32ee',
+ })
+ # check embedding
+ for face in faces:
+ self.assertEqual(face['embedding'].shape, (512, ))
+ # check bbox
+ self.assertAlmostEqual(faces[0]['x'], 275.8, 2)
+ self.assertAlmostEqual(faces[0]['y'], 91.67, 2)
+ self.assertAlmostEqual(faces[0]['width'], 50.5, 2)
+ self.assertAlmostEqual(faces[0]['height'], 65.42, 2)
+
+
+
+## main ##
+
+if __name__ == '__main__':
+ unittest.main()
+
+## EOF ##