aboutsummaryrefslogtreecommitdiffstats
path: root/test/reader/test_face.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/reader/test_face.py')
-rw-r--r--test/reader/test_face.py227
1 files changed, 227 insertions, 0 deletions
diff --git a/test/reader/test_face.py b/test/reader/test_face.py
new file mode 100644
index 0000000..6164143
--- /dev/null
+++ b/test/reader/test_face.py
@@ -0,0 +1,227 @@
+
+# standard imports
+import contextlib
+import io
+import os
+import unittest
+
+# external imports
+import requests
+import PIL.Image
+
+# bsie imports
+from bsie.utils import errors
+
+# objects to test
+from bsie.reader.face import FaceExtract
+
+
+## code ##
+
+def fetch(source, target):
+ target = os.path.join(os.path.dirname(__file__), target)
+ if not os.path.exists(target):
+ with open(target, 'wb') as ofile:
+ ans = requests.get(source)
+ ofile.write(ans.content)
+
+class TestFaceExtract(unittest.TestCase):
+ def setUp(self):
+ # download test image w/o face
+ fetch('https://www.bsfs.io/testdata/Quiejoore1ahxa9jahma.jpg', 'faces-noface.jpg')
+ # download test image w/ face
+ fetch('https://www.bsfs.io/testdata/ONekai7Ohphooch3aege.jpg', 'faces-ivan.jpg')
+ # download errounous images
+ fetch('https://www.bsfs.io/testdata/kie7wo7sheix7thieG2f.gif', 'faces-dimerr.gif')
+ fetch('https://www.bsfs.io/testdata/Mee1aunooneoSaexohTh.gif', 'faces-valueerr.gif')
+
+ def test_essentials(self):
+ # repr respects min_face_prob
+ self.assertEqual(repr(FaceExtract(min_face_prob=1.0)), 'FaceExtract(1.0)')
+ self.assertEqual(repr(FaceExtract(min_face_prob=0.5)), 'FaceExtract(0.5)')
+ # repr respects type
+ class Foo(FaceExtract): pass
+ self.assertEqual(repr(Foo(min_face_prob=0.5)), 'Foo(0.5)')
+
+ # comparison respects type
+ class Foo(): pass
+ self.assertNotEqual(FaceExtract(), 1234)
+ self.assertNotEqual(hash(FaceExtract()), hash(1234))
+ self.assertNotEqual(FaceExtract(), 'hello')
+ self.assertNotEqual(hash(FaceExtract()), hash('hello'))
+ self.assertNotEqual(FaceExtract(), Foo())
+ self.assertNotEqual(hash(FaceExtract()), hash(Foo()))
+ # comparison respects constructor arguments (except cuda_device)
+ self.assertEqual(FaceExtract(), FaceExtract())
+ self.assertEqual(hash(FaceExtract()), hash(FaceExtract()))
+ self.assertNotEqual(FaceExtract(), FaceExtract(target_size=10))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(target_size=10)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(min_face_size=10))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(min_face_size=10)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(min_face_prob=1.))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(min_face_prob=1.)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(ext_face_size=100))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(ext_face_size=100)))
+ self.assertNotEqual(FaceExtract(), FaceExtract(thresholds=[0.1,0.1,0.1]))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(thresholds=[0.1,0.1,0.1])))
+ self.assertNotEqual(FaceExtract(), FaceExtract(factor=1.))
+ self.assertNotEqual(hash(FaceExtract()), hash(FaceExtract(factor=1.)))
+ # comparison ignores cuda_device
+ self.assertEqual(FaceExtract(), FaceExtract(cuda_device='cuda:123'))
+ self.assertEqual(hash(FaceExtract()), hash(FaceExtract(cuda_device='cuda:123')))
+
+ def test_preprocess(self):
+ testpath = os.path.join(os.path.dirname(__file__), 'faces-noface.jpg')
+ with PIL.Image.open(testpath) as img:
+ self.assertEqual(img.size, (199, 148))
+ # landscape, downscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, None)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (10*1.99, 10*2.0))
+ # landscape, upscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, None)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (10*0.5, 10*0.5))
+ # landscape, downscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, 90)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (10.0*1.99, 64*2.0))
+ # landscape, upscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, 90)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (10*0.5, 286*0.5))
+ # landscape, downscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, 270)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (90*1.99, 10*2.0))
+ # landscape, upscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, 270)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (388*0.5, 10*0.5))
+ # landscape, downscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 100, 180)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (90*1.99, 64*2.0))
+ # landscape, upscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 398, 180)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (388*0.5, 286*0.5))
+ # portrait, downscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, None)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (10*2.0, 10*1.99))
+ # portrait, upscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, None)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (10*0.5, 10*0.5))
+ # portrait, downscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, 90)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (10.0*2.0, 90*1.99))
+ # portrait, upscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, 90)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (10*0.5, 388*0.5))
+ # portrait, downscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, 270)
+ self.assertEqual(img.size, (100, 74))
+ self.assertEqual(denorm((10,10)), (64*2.0, 10*1.99))
+ # portrait, upscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, 270)
+ self.assertEqual(img.size, (398, 296))
+ self.assertEqual(denorm((10,10)), (286*0.5, 10*0.5))
+ # portrait, downscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 100, 180)
+ self.assertEqual(img.size, (74, 100))
+ self.assertEqual(denorm((10,10)), (64*2.0, 90*1.99))
+ # portrait, upscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath).rotate(90, expand=True), 398, 180)
+ self.assertEqual(img.size, (296, 398))
+ self.assertEqual(denorm((10,10)), (286*0.5, 388*0.5))
+
+ # square image
+ testpath = os.path.join(os.path.dirname(__file__), 'faces-ivan.jpg')
+ with PIL.Image.open(testpath) as img:
+ self.assertEqual(img.size, (561, 561))
+ # square, downscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, None)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (10*11, 10*11))
+ # square, upscale, no rotation
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, None)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (10*0.5, 10*0.5))
+ # square, downscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, 90)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (10.0*11, 41*11))
+ # square, upscale, 90cw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, 90)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (10*0.5, 1112*0.5))
+ # square, downscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, 270)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (41*11, 10*11))
+ # square, upscale, 90ccw
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, 270)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (1112*0.5, 10*0.5))
+ # square, downscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 51, 180)
+ self.assertEqual(img.size, (51, 51))
+ self.assertEqual(denorm((10,10)), (41*11, 41*11))
+ # square, upscale, 180
+ img, denorm = FaceExtract.preprocess(PIL.Image.open(testpath), 1122, 180)
+ self.assertEqual(img.size, (1122, 1122))
+ self.assertEqual(denorm((10,10)), (1112*0.5, 1112*0.5))
+
+ def test_call(self):
+ with contextlib.redirect_stderr(io.StringIO()): # NOTE: hide warnings from facenet_pytorch
+ rdr = FaceExtract()
+ # discards non-image files
+ self.assertRaises(errors.UnsupportedFileFormatError, rdr,
+ __file__)
+ # raises on invalid image
+ self.assertRaises(errors.UnsupportedFileFormatError, rdr,
+ os.path.join(os.path.dirname(__file__), 'testimage_exif_corrupted.jpg'))
+ # raises on missing file
+ self.assertRaises(errors.ReaderError, rdr,
+ os.path.join(os.path.dirname(__file__), 'invalid.jpg'))
+ # raises on dimensions error (pytorch RuntimeError)
+ self.assertRaises(errors.ReaderError, rdr,
+ os.path.join(os.path.dirname(__file__), 'faces-dimerr.gif'))
+ # raises on content error (ValueError)
+ self.assertRaises(errors.ReaderError, rdr,
+ os.path.join(os.path.dirname(__file__), 'faces-valueerr.gif'))
+
+ # may return empty list
+ self.assertListEqual(FaceExtract(min_face_prob=1)(
+ os.path.join(os.path.dirname(__file__), 'faces-noface.jpg')), [])
+ self.assertListEqual(FaceExtract(min_face_prob=1)(
+ os.path.join(os.path.dirname(__file__), 'faces-ivan.jpg')), [])
+ # returns faces
+ faces = rdr(os.path.join(os.path.dirname(__file__), 'faces-ivan.jpg'))
+ # check if face was detected
+ self.assertEqual(len(faces), 1)
+ # check ucid
+ self.assertSetEqual({f['ucid'] for f in faces}, {
+ '926dc1684dd453aa2c3c8daf1c82ecf918514ef0de416b6b842235c23bec32ee',
+ })
+ # check embedding
+ for face in faces:
+ self.assertEqual(face['embedding'].shape, (512, ))
+ # check bbox
+ self.assertAlmostEqual(faces[0]['x'], 275.8, 2)
+ self.assertAlmostEqual(faces[0]['y'], 91.67, 2)
+ self.assertAlmostEqual(faces[0]['width'], 50.5, 2)
+ self.assertAlmostEqual(faces[0]['height'], 65.42, 2)
+
+
+
+## main ##
+
+if __name__ == '__main__':
+ unittest.main()
+
+## EOF ##