# standard imports
import contextlib
import io
import os
import unittest

# external imports
import requests

# bsie imports
from bsie.extractor import base
from bsie.matcher import nodes
from bsie.reader.face import FaceExtract
from bsie.utils import bsfs, ns

# objects to test
from bsie.extractor.image.face.identify import FaceIdentify, bsf


## code ##

def _path(name):
    """Return the absolute path of test data file *name*, located next to this module."""
    return os.path.join(os.path.dirname(__file__), name)


def fetch(source, target):
    """Download *source* to *target* (relative to this module) unless it already exists.

    The HTTP status is checked before anything touches the filesystem, and the
    payload is first written to a temporary sibling file that is then moved into
    place. This way a failed or interrupted download never leaves an empty or
    partial *target* behind, which later runs would mistake for a cached copy.

    Raises:
        requests.RequestException: if the request fails or returns an error status.
    """
    target = _path(target)
    if os.path.exists(target):
        return
    ans = requests.get(source, timeout=60)
    # Do not cache HTML error pages (404/500) as test data.
    ans.raise_for_status()
    partial = target + '.part'
    with open(partial, 'wb') as ofile:
        ofile.write(ans.content)
    # Atomic rename: *target* only ever appears fully written.
    os.replace(partial, target)


class TestFaceIdentify(unittest.TestCase):
    """Tests for the FaceIdentify extractor."""

    def setUp(self):
        """Make sure all test images, reference embeddings and mappings are available locally."""
        # download test images
        fetch('https://www.bsfs.io/testdata/iepahGee1uch5ahr3ic1.jpg', 'testface1.jpg')
        fetch('https://www.bsfs.io/testdata/Woayiesae8eiL9aivoba.jpg', 'testface2.jpg')
        fetch('https://www.bsfs.io/testdata/ATiagheiduth4So5ohxi.jpg', 'testface3.jpg')
        # download reference vectors
        fetch('https://www.bsfs.io/testdata/aetie3foo0faiDaiBahk.npy', 'ref_embeds.npy')
        fetch('https://www.bsfs.io/testdata/uopoS8gei8Phiek3shei.npy', 'ref_embeds_alt1.npy')
        fetch('https://www.bsfs.io/testdata/Otoo7ain6Ied2Iep2ein.npy', 'ref_embeds_alt2.npy')
        fetch('https://www.bsfs.io/testdata/ie0keriChafahroeRo7i.npy', 'ref_embeds_extra.npy')
        fetch('https://www.bsfs.io/testdata/phoophui3teeni4hieKu.csv', 'ref_mapping.csv')
        fetch('https://www.bsfs.io/testdata/Quit4Wum8ael7Zeis4ei.csv', 'ref_mapping_alt.csv')
        fetch('https://www.bsfs.io/testdata/Angu5cioVei5pohgh0aa.csv', 'ref_mapping_id_reuse.csv')
        fetch('https://www.bsfs.io/testdata/ooshooK1bai5Queengae.csv', 'ref_mapping_name_reuse.csv')
        fetch('https://www.bsfs.io/testdata/eixuepah3Ronge7oe4qu.csv', 'ref_mapping_restklasse.csv')

    def test_essentials(self):
        """str/repr, equality and hashing of FaceIdentify respect all constructor arguments."""
        # setup
        pth_embeds = _path('ref_embeds.npy')
        pth_embeds_alt1 = _path('ref_embeds_alt1.npy')
        pth_embeds_alt2 = _path('ref_embeds_alt2.npy')
        pth_mapping = _path('ref_mapping.csv')
        pth_mapping_alt = _path('ref_mapping_alt.csv')
        restklasse = 'https://example.com/user/fake_anon'
        ext = FaceIdentify(pth_embeds, pth_mapping)
        # string conversion returns class name
        self.assertEqual(str(ext), 'FaceIdentify')
        # representation respects number of embeddings
        self.assertEqual(repr(ext),
            'FaceIdentify(N=2, restklasse=https://example.com/user/anon)')
        # representation respects restklasse
        self.assertEqual(repr(FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse)),
            'FaceIdentify(N=2, restklasse=https://example.com/user/fake_anon)')
        # identity
        self.assertEqual(ext, FaceIdentify(pth_embeds, pth_mapping))
        self.assertEqual(hash(ext), hash(FaceIdentify(pth_embeds, pth_mapping))) # FIXME!
        # comparison respects embeddings
        self.assertNotEqual(ext, FaceIdentify(pth_embeds_alt1, pth_mapping))
        self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds_alt1, pth_mapping)))
        self.assertNotEqual(ext, FaceIdentify(pth_embeds_alt2, pth_mapping))
        self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds_alt2, pth_mapping)))
        # comparison respects mappings
        self.assertNotEqual(ext, FaceIdentify(pth_embeds, pth_mapping_alt))
        self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds, pth_mapping_alt)))
        # comparison respects threshold
        self.assertNotEqual(ext, FaceIdentify(pth_embeds, pth_mapping, thres=0.1))
        self.assertNotEqual(hash(ext), hash(FaceIdentify(pth_embeds, pth_mapping, thres=0.1)))
        # comparison respects restklasse
        self.assertNotEqual(ext, FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse))
        self.assertNotEqual(hash(ext),
            hash(FaceIdentify(pth_embeds, pth_mapping, restklasse=restklasse)))

    def test_construct(self):
        """FaceIdentify accepts valid inputs and rejects inconsistent embeddings/mappings."""
        pth_embeds = _path('ref_embeds.npy')
        pth_mapping = _path('ref_mapping.csv')
        # valid construction
        self.assertIsInstance(FaceIdentify(pth_embeds, pth_mapping), FaceIdentify)
        # restklasse may be part of the mapping
        ext = FaceIdentify(pth_embeds, _path('ref_mapping_restklasse.csv'))
        self.assertIsInstance(ext, FaceIdentify)
        self.assertEqual(ext._restidx, 1)
        # pass invalid mapping (name re-use)
        self.assertRaises(Exception, FaceIdentify,
            pth_embeds, _path('ref_mapping_name_reuse.csv'))
        # pass invalid mapping (id re-use)
        self.assertRaises(Exception, FaceIdentify,
            pth_embeds, _path('ref_mapping_id_reuse.csv'))
        # pass invalid embeds (extra embeddings)
        self.assertRaises(Exception, FaceIdentify,
            _path('ref_embeds_extra.npy'), pth_mapping)

    def test_extract(self):
        """extract() links faces to identified persons and yields nothing otherwise."""
        with contextlib.redirect_stderr(io.StringIO()): # NOTE: hide warnings from facenet_pytorch
            # setup
            rdr = FaceExtract()
            ext = FaceIdentify(
                _path('ref_embeds.npy'),
                _path('ref_mapping.csv'),
                )
            subject = nodes.Entity(ucid='abc123')
            content = rdr(_path('testface1.jpg'))
            principals = set(ext.principals)
            face = nodes.Face(
                ucid='2a7203c1515e0caa66a7461452c0b4552f1433a613cb3033e59ed2361790ad45')
            person = nodes.Person(uri='https://example.com/user/Angelina_Jolie')
            triples = list(ext.extract(subject, content, principals))
            # principals is bse:face, bsf:depicts
            self.assertSetEqual(set(ext.principals), {
                ext.schema.predicate(ns.bse.face),
                ext.schema.predicate(bsf.depicts),
                })
            # produces two triples ...
            self.assertEqual(len(triples), 2)
            # ... one if at least one person was identified
            self.assertIn((subject, ext.schema.predicate(ns.bse.face), face), triples)
            # ... one for each identified person
            self.assertIn((face, ext.schema.predicate(bsf.depicts), person), triples)

            # produces no triples if no person was identified
            content = rdr(_path('testface2.jpg'))
            self.assertListEqual(list(ext.extract(subject, content, principals)), [])

            # identifies the correct person despite somewhat similar options
            content = rdr(_path('testface3.jpg'))
            face = nodes.Face(
                ucid='f61fac01ef686ee05805afef1e7a10ba54c30dc1aa095d9e77d79ccdfeb40dc5')
            triples = list(ext.extract(subject, content, principals))
            self.assertEqual(len(triples), 2)
            person = nodes.Person(uri='https://example.com/user/Paul_Rudd')
            self.assertIn((subject, ext.schema.predicate(ns.bse.face), face), triples)
            self.assertIn((face, ext.schema.predicate(bsf.depicts), person), triples)

            # no triples on principal mismatch
            self.assertListEqual(list(ext.extract(subject, content, set())), [])
            # no triples on no content
            self.assertListEqual(list(ext.extract(subject, [], principals)), [])


## main ##

if __name__ == '__main__':
    unittest.main()

## EOF ##