aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bsfs/graph/nodes.py51
-rw-r--r--test/graph/test_nodes.py109
2 files changed, 158 insertions, 2 deletions
diff --git a/bsfs/graph/nodes.py b/bsfs/graph/nodes.py
index 18ab30d..85e5fdb 100644
--- a/bsfs/graph/nodes.py
+++ b/bsfs/graph/nodes.py
@@ -93,6 +93,57 @@ class Nodes():
"""Return the store's local schema."""
return self._backend.schema
+ def __add__(self, other: typing.Any) -> 'Nodes':
+ """Concatenate guids. Backend, user, and node type must match."""
+ if not isinstance(other, type(self)):
+ return NotImplemented
+ if self._backend != other._backend:
+ raise ValueError(other)
+ if self._user != other._user:
+ raise ValueError(other)
+ if self.node_type != other.node_type:
+ raise ValueError(other)
+ return Nodes(self._backend, self._user, self.node_type, self._guids | other._guids)
+
+ def __or__(self, other: typing.Any) -> 'Nodes':
+ """Concatenate guids. Backend, user, and node type must match."""
+ return self.__add__(other)
+
+ def __sub__(self, other: typing.Any) -> 'Nodes':
+ """Subtract guids. Backend, user, and node type must match."""
+ if not isinstance(other, type(self)):
+ return NotImplemented
+ if self._backend != other._backend:
+ raise ValueError(other)
+ if self._user != other._user:
+ raise ValueError(other)
+ if self.node_type != other.node_type:
+ raise ValueError(other)
+ return Nodes(self._backend, self._user, self.node_type, self._guids - other._guids)
+
+ def __and__(self, other: typing.Any) -> 'Nodes':
+ """Intersect guids. Backend, user, and node type must match."""
+ if not isinstance(other, type(self)):
+ return NotImplemented
+ if self._backend != other._backend:
+ raise ValueError(other)
+ if self._user != other._user:
+ raise ValueError(other)
+ if self.node_type != other.node_type:
+ raise ValueError(other)
+ return Nodes(self._backend, self._user, self.node_type, self._guids & other._guids)
+
+ def __len__(self) -> int:
+ """Return the number of guids."""
+ return len(self._guids)
+
+ def __iter__(self) -> typing.Iterator['Nodes']:
+ """Iterate over individual guids. Returns `Nodes` instances."""
+ return iter(
+ Nodes(self._backend, self._user, self.node_type, {guid})
+ for guid in self._guids
+ )
+
def __getattr__(self, name: str):
try:
return super().__getattr__(name) # type: ignore [misc] # parent has no getattr
diff --git a/test/graph/test_nodes.py b/test/graph/test_nodes.py
index 4eae250..c07fa53 100644
--- a/test/graph/test_nodes.py
+++ b/test/graph/test_nodes.py
@@ -4,10 +4,14 @@ Part of the bsfs test suite.
A copy of the license is provided with the project.
Author: Matthias Baumgartner, 2022
"""
-# imports
-import rdflib
+# standard imports
+from functools import partial
+import operator
import unittest
+# external imports
+import rdflib
+
# bsie imports
from bsfs import schema as bsc
from bsfs.graph.walk import Walk
@@ -476,6 +480,107 @@ class TestNodes(unittest.TestCase):
# invalid step raises an error
self.assertRaises(ValueError, getattr, nodes, 'foobar')
+ def test_schema(self):
+ self.assertEqual(Nodes(self.backend, self.user, self.ent_type,
+ {URI('http://example.com/me/entity#1234')}).schema, self.backend.schema)
+
+ def test_operators(self): # __add__, __or__, __sub__, __and__
+ gen = partial(Nodes, self.backend, self.user, self.ent_type)
+ nodes = gen({URI('http://example.com/me/entity#1234')})
+ # add/or concatenates guids
+ self.assertEqual(
+ gen({URI('http://example.com/me/entity#1234')}) +
+ gen({URI('http://example.com/me/entity#4321')}),
+ # target
+ gen({
+ URI('http://example.com/me/entity#1234'),
+ URI('http://example.com/me/entity#4321')}))
+ self.assertEqual(
+ gen({URI('http://example.com/me/entity#1234')}) |
+ gen({URI('http://example.com/me/entity#4321')}),
+ # target
+ gen({
+ URI('http://example.com/me/entity#1234'),
+ URI('http://example.com/me/entity#4321')}))
+ # repeated guids are ignored
+ self.assertEqual(
+ gen({URI('http://example.com/me/entity#1234')}) +
+ gen({URI('http://example.com/me/entity#1234')}),
+ # target
+ gen({URI('http://example.com/me/entity#1234')}))
+ self.assertEqual(
+ gen({URI('http://example.com/me/entity#1234')}) |
+ gen({URI('http://example.com/me/entity#1234')}),
+ # target
+ gen({URI('http://example.com/me/entity#1234')}))
+
+ # sub substracts guids
+ self.assertEqual(
+ gen({URI('http://example.com/me/entity#1234'),
+ URI('http://example.com/me/entity#4321')}) -
+ gen({URI('http://example.com/me/entity#4321')}),
+ # target
+ gen({URI('http://example.com/me/entity#1234')}))
+ # missing guids are ignored
+ self.assertEqual(
+ gen({URI('http://example.com/me/entity#1234')}) -
+ gen({URI('http://example.com/me/entity#4321')}),
+ # target
+ gen({URI('http://example.com/me/entity#1234')}))
+
+ # and intersects guids
+ self.assertEqual(
+ gen({URI('http://example.com/me/entity#1234'),
+ URI('http://example.com/me/entity#4321')}) &
+ gen({URI('http://example.com/me/entity#4321'),
+ URI('http://example.com/me/entity#5678')}),
+ # target
+ gen({URI('http://example.com/me/entity#4321')}))
+
+ for op in (operator.add, operator.or_, operator.sub, operator.and_):
+ # type must match
+ self.assertRaises(TypeError, op, nodes, 1234)
+ self.assertRaises(TypeError, op, nodes, 'hello world')
+ # backend must match
+ self.assertRaises(ValueError, op, nodes,
+ Nodes(None, self.user, self.ent_type, {URI('http://example.com/me/entity#1234')}))
+ # user must match
+ self.assertRaises(ValueError, op, nodes,
+ Nodes(self.backend, '', self.ent_type, {URI('http://example.com/me/entity#1234')}))
+ # node type must match
+ self.assertRaises(ValueError, op, nodes,
+ Nodes(self.backend, self.user, self.tag_type, {URI('http://example.com/me/entity#1234')}))
+
+ def test_len(self):
+ self.assertEqual(1, len(Nodes(self.backend, self.user, self.ent_type, {
+ URI('http://example.com/me/entity#1234'),
+ })))
+ self.assertEqual(2, len(Nodes(self.backend, self.user, self.ent_type, {
+ URI('http://example.com/me/entity#1234'),
+ URI('http://example.com/me/entity#4321'),
+ })))
+ self.assertEqual(4, len(Nodes(self.backend, self.user, self.ent_type, {
+ URI('http://example.com/me/entity#1234'),
+ URI('http://example.com/me/entity#4321'),
+ URI('http://example.com/me/entity#5678'),
+ URI('http://example.com/me/entity#8765'),
+ })))
+
+ def test_iter(self): # __iter__
+ gen = partial(Nodes, self.backend, self.user, self.ent_type)
+ self.assertSetEqual(set(Nodes(self.backend, self.user, self.ent_type, {
+ URI('http://example.com/me/entity#1234'),
+ URI('http://example.com/me/entity#4321'),
+ URI('http://example.com/me/entity#5678'),
+ URI('http://example.com/me/entity#8765'),
+ })), {
+ gen({URI('http://example.com/me/entity#1234')}),
+ gen({URI('http://example.com/me/entity#4321')}),
+ gen({URI('http://example.com/me/entity#5678')}),
+ gen({URI('http://example.com/me/entity#8765')}),
+ })
+
+
## main ##
if __name__ == '__main__':