From 1392951dfc82af05e7a5999baa1c0a4fc72083b8 Mon Sep 17 00:00:00 2001 From: Matthias Baumgartner Date: Sat, 21 Jan 2023 23:03:50 +0100 Subject: Nodes magic methods for convenience --- bsfs/graph/nodes.py | 51 ++++++++++++++++++++++ test/graph/test_nodes.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 158 insertions(+), 2 deletions(-) diff --git a/bsfs/graph/nodes.py b/bsfs/graph/nodes.py index 18ab30d..85e5fdb 100644 --- a/bsfs/graph/nodes.py +++ b/bsfs/graph/nodes.py @@ -93,6 +93,57 @@ class Nodes(): """Return the store's local schema.""" return self._backend.schema + def __add__(self, other: typing.Any) -> 'Nodes': + """Concatenate guids. Backend, user, and node type must match.""" + if not isinstance(other, type(self)): + return NotImplemented + if self._backend != other._backend: + raise ValueError(other) + if self._user != other._user: + raise ValueError(other) + if self.node_type != other.node_type: + raise ValueError(other) + return Nodes(self._backend, self._user, self.node_type, self._guids | other._guids) + + def __or__(self, other: typing.Any) -> 'Nodes': + """Concatenate guids. Backend, user, and node type must match.""" + return self.__add__(other) + + def __sub__(self, other: typing.Any) -> 'Nodes': + """Subtract guids. Backend, user, and node type must match.""" + if not isinstance(other, type(self)): + return NotImplemented + if self._backend != other._backend: + raise ValueError(other) + if self._user != other._user: + raise ValueError(other) + if self.node_type != other.node_type: + raise ValueError(other) + return Nodes(self._backend, self._user, self.node_type, self._guids - other._guids) + + def __and__(self, other: typing.Any) -> 'Nodes': + """Intersect guids. Backend, user, and node type must match.""" + if not isinstance(other, type(self)): + return NotImplemented + if self._backend != other._backend: + raise ValueError(other) + if self._user != other._user: + raise ValueError(other) + if self.node_type != other.node_type: + raise ValueError(other) + return Nodes(self._backend, self._user, self.node_type, self._guids & other._guids) + + def __len__(self) -> int: + """Return the number of guids.""" + return len(self._guids) + + def __iter__(self) -> typing.Iterator['Nodes']: + """Iterate over individual guids. Returns `Nodes` instances.""" + return iter( + Nodes(self._backend, self._user, self.node_type, {guid}) + for guid in self._guids + ) + def __getattr__(self, name: str): try: return super().__getattr__(name) # type: ignore [misc] # parent has no getattr diff --git a/test/graph/test_nodes.py b/test/graph/test_nodes.py index 4eae250..c07fa53 100644 --- a/test/graph/test_nodes.py +++ b/test/graph/test_nodes.py @@ -4,10 +4,14 @@ Part of the bsfs test suite. A copy of the license is provided with the project. Author: Matthias Baumgartner, 2022 """ -# imports -import rdflib +# standard imports +from functools import partial +import operator import unittest +# external imports +import rdflib + # bsie imports from bsfs import schema as bsc from bsfs.graph.walk import Walk @@ -476,6 +480,107 @@ class TestNodes(unittest.TestCase): # invalid step raises an error self.assertRaises(ValueError, getattr, nodes, 'foobar') + def test_schema(self): + self.assertEqual(Nodes(self.backend, self.user, self.ent_type, + {URI('http://example.com/me/entity#1234')}).schema, self.backend.schema) + + def test_operators(self): # __add__, __or__, __sub__, __and__ + gen = partial(Nodes, self.backend, self.user, self.ent_type) + nodes = gen({URI('http://example.com/me/entity#1234')}) + # add/or concatenates guids + self.assertEqual( + gen({URI('http://example.com/me/entity#1234')}) + + gen({URI('http://example.com/me/entity#4321')}), + # target + gen({ + URI('http://example.com/me/entity#1234'), + URI('http://example.com/me/entity#4321')})) + self.assertEqual( + gen({URI('http://example.com/me/entity#1234')}) | + gen({URI('http://example.com/me/entity#4321')}), + # target + gen({ + URI('http://example.com/me/entity#1234'), + URI('http://example.com/me/entity#4321')})) + # repeated guids are ignored + self.assertEqual( + gen({URI('http://example.com/me/entity#1234')}) + + gen({URI('http://example.com/me/entity#1234')}), + # target + gen({URI('http://example.com/me/entity#1234')})) + self.assertEqual( + gen({URI('http://example.com/me/entity#1234')}) | + gen({URI('http://example.com/me/entity#1234')}), + # target + gen({URI('http://example.com/me/entity#1234')})) + + # sub substracts guids + self.assertEqual( + gen({URI('http://example.com/me/entity#1234'), + URI('http://example.com/me/entity#4321')}) - + gen({URI('http://example.com/me/entity#4321')}), + # target + gen({URI('http://example.com/me/entity#1234')})) + # missing guids are ignored + self.assertEqual( + gen({URI('http://example.com/me/entity#1234')}) - + gen({URI('http://example.com/me/entity#4321')}), + # target + gen({URI('http://example.com/me/entity#1234')})) + + # and intersects guids + self.assertEqual( + gen({URI('http://example.com/me/entity#1234'), + URI('http://example.com/me/entity#4321')}) & + gen({URI('http://example.com/me/entity#4321'), + URI('http://example.com/me/entity#5678')}), + # target + gen({URI('http://example.com/me/entity#4321')})) + + for op in (operator.add, operator.or_, operator.sub, operator.and_): + # type must match + self.assertRaises(TypeError, op, nodes, 1234) + self.assertRaises(TypeError, op, nodes, 'hello world') + # backend must match + self.assertRaises(ValueError, op, nodes, + Nodes(None, self.user, self.ent_type, {URI('http://example.com/me/entity#1234')})) + # user must match + self.assertRaises(ValueError, op, nodes, + Nodes(self.backend, '', self.ent_type, {URI('http://example.com/me/entity#1234')})) + # node type must match + self.assertRaises(ValueError, op, nodes, + Nodes(self.backend, self.user, self.tag_type, {URI('http://example.com/me/entity#1234')})) + + def test_len(self): + self.assertEqual(1, len(Nodes(self.backend, self.user, self.ent_type, { + URI('http://example.com/me/entity#1234'), + }))) + self.assertEqual(2, len(Nodes(self.backend, self.user, self.ent_type, { + URI('http://example.com/me/entity#1234'), + URI('http://example.com/me/entity#4321'), + }))) + self.assertEqual(4, len(Nodes(self.backend, self.user, self.ent_type, { + URI('http://example.com/me/entity#1234'), + URI('http://example.com/me/entity#4321'), + URI('http://example.com/me/entity#5678'), + URI('http://example.com/me/entity#8765'), + }))) + + def test_iter(self): # __iter__ + gen = partial(Nodes, self.backend, self.user, self.ent_type) + self.assertSetEqual(set(Nodes(self.backend, self.user, self.ent_type, { + URI('http://example.com/me/entity#1234'), + URI('http://example.com/me/entity#4321'), + URI('http://example.com/me/entity#5678'), + URI('http://example.com/me/entity#8765'), + })), { + gen({URI('http://example.com/me/entity#1234')}), + gen({URI('http://example.com/me/entity#4321')}), + gen({URI('http://example.com/me/entity#5678')}), + gen({URI('http://example.com/me/entity#8765')}), + }) + + ## main ## if __name__ == '__main__': -- cgit v1.2.3