aboutsummaryrefslogtreecommitdiffstats
path: root/bsfs/graph/nodes.py
blob: 84996c7910ac5a1db380342b8ec93e41d9eb4c3a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401

# imports
from collections import abc
import time
import typing

# bsfs imports
from bsfs import schema as bsc
from bsfs.namespace import ns
from bsfs.query import ast, validate
from bsfs.triple_store import TripleStoreBase
from bsfs.utils import errors, URI, typename

# inner-module imports
from . import ac
from . import result
from . import walk

# exports: the names that make up this module's public interface
# (these are the names picked up by `from <module> import *`).
__all__: typing.Sequence[str] = (
    'Nodes',
    )


## code ##

class Nodes():
    """Container for a set of node guids that share a common node type.

    Bundles the guids with the triple store backend and the access controls
    that are used to read (`get`) and write (`set`, `set_from_iterable`)
    their predicates. Instances can be combined via `+`, `|`, `-`, and `&`
    as long as backend, access controls, and node type match.

    NOTE: guids may or may not exist. This is not verified as nodes are created on demand.
    """

    # triple store backend.
    _backend: TripleStoreBase

    # access controls.
    _ac: ac.AccessControlBase

    # node type.
    _node_type: bsc.Node

    # guids of nodes. Can be empty.
    _guids: typing.Set[URI]

    def __init__(
            self,
            backend: TripleStoreBase,
            access_control: ac.AccessControlBase,
            node_type: bsc.Node,
            guids: typing.Iterable[URI],
            ):
        """Create a node set of type *node_type* over *guids*."""
        # set main members
        self._backend = backend
        self._ac = access_control
        self._node_type = node_type
        # convert to URI since this is not guaranteed by Graph
        self._guids = {URI(guid) for guid in guids}

    def __eq__(self, other: typing.Any) -> bool:
        """Instances are equal if backend, AC, node type, and guids are equal."""
        return isinstance(other, Nodes) \
           and self._backend == other._backend \
           and self._ac == other._ac \
           and self._node_type == other._node_type \
           and self._guids == other._guids

    def __hash__(self) -> int:
        # guids are sorted so that the hash does not depend on the set's iteration order
        return hash((type(self), self._backend, self._ac, self._node_type, tuple(sorted(self._guids))))

    def __repr__(self) -> str:
        return f'{typename(self)}({self._backend}, {self._ac}, {self._node_type}, {self._guids})'

    def __str__(self) -> str:
        return f'{typename(self)}({self._node_type}, {self._guids})'

    @property
    def node_type(self) -> bsc.Node:
        """Return the node's type."""
        return self._node_type

    @property
    def guids(self) -> typing.Iterator[URI]:
        """Return all node guids."""
        return iter(self._guids)

    @property
    def schema(self) -> bsc.Schema:
        """Return the store's local schema."""
        return self._backend.schema

    def _combine(
            self,
            other: typing.Any,
            operation: typing.Callable[[typing.Set[URI], typing.Set[URI]], typing.Set[URI]],
            ) -> 'Nodes':
        """Return new Nodes whose guids are *operation* applied to the guids of self and *other*.

        Returns NotImplemented if *other* is not an instance of this type.
        Raises a ValueError if backend, AC, or node type differ.
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        if self._backend != other._backend:
            raise ValueError(other)
        if self._ac != other._ac:
            raise ValueError(other)
        if self.node_type != other.node_type:
            raise ValueError(other)
        return Nodes(self._backend, self._ac, self.node_type, operation(self._guids, other._guids))

    def __add__(self, other: typing.Any) -> 'Nodes':
        """Concatenate guids. Backend, AC, and node type must match."""
        return self._combine(other, set.union)

    def __or__(self, other: typing.Any) -> 'Nodes':
        """Concatenate guids. Backend, AC, and node type must match."""
        return self.__add__(other)

    def __sub__(self, other: typing.Any) -> 'Nodes':
        """Subtract guids. Backend, AC, and node type must match."""
        return self._combine(other, set.difference)

    def __and__(self, other: typing.Any) -> 'Nodes':
        """Intersect guids. Backend, AC, and node type must match."""
        return self._combine(other, set.intersection)

    def __len__(self) -> int:
        """Return the number of guids."""
        return len(self._guids)

    def __iter__(self) -> typing.Iterator['Nodes']:
        """Iterate over individual guids. Returns `Nodes` instances."""
        return (
            Nodes(self._backend, self._ac, self.node_type, {guid})
            for guid in self._guids
            )

    def __getattr__(self, name: str):
        """Interpret the unknown attribute *name* as the first step of a walk.

        Only invoked if the regular attribute lookup fails. Names with a
        leading underscore raise an AttributeError and never start a walk.
        """
        # Underscore names are internals, not walk steps. Without this guard, a
        # missing `_backend` (e.g. on an instance that was created without
        # running __init__, as the copy and pickle modules do) is requested
        # again by the `schema` property used below, which recurses endlessly.
        if name.startswith('_'):
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')
        try:
            return super().__getattr__(name) # type: ignore [misc] # parent has no getattr
        except AttributeError:
            pass
        return walk.Walk(self, walk.Walk.step(self.schema, self.node_type, name))

    def set(
            self,
            pred: URI, # FIXME: URI or bsc.Predicate?
            value: typing.Any,
            ) -> 'Nodes':
        """Set predicate *pred* to *value*."""
        return self.set_from_iterable([(pred, value)])

    def set_from_iterable(
            self,
            predicate_values: typing.Iterable[typing.Tuple[URI, typing.Any]], # FIXME: URI or bsc.Predicate?
            ) -> 'Nodes':
        """Set multiple predicate-value pairs at once.

        The changes are committed to the backend if all pairs could be set.
        If any of them fails, the backend is rolled back and the error is
        passed on to the caller unchanged. Returns the instance itself.

        Errors to expect include:
        * errors.PermissionDeniedError: tried to set a protected predicate (ns.bsm.t_created)
        * errors.ConsistencyError: node types are not in the schema or don't match the predicate
        * errors.InstanceError: guids/values don't have the correct type
        * TypeError: value is supposed to be a Nodes instance
        * ValueError: multiple values passed to unique predicate
        * KeyError: the predicate is not in the schema
        """
        # TODO: Could group predicate_values by predicate to gain some efficiency
        # TODO: ignore errors on some predicates; For now this could leave residual
        #       data (e.g. some nodes were created, some not).
        try:
            # insert triples
            for pred, value in predicate_values:
                self.__set(pred, value)
            # save changes
            self._backend.commit()

        except Exception:
            # Revert on *any* failure, not only on the anticipated ones. Otherwise
            # the triples of earlier pairs would remain pending in the backend.
            self._backend.rollback()
            # notify the client
            raise

        return self

    def get(
            self,
            *paths: typing.Union[URI, typing.Iterable[URI]],
            view: typing.Union[typing.Type[list], typing.Type[dict]] = dict,
            **view_kwargs,
            ) -> typing.Any:
        """Get values or nodes at *paths*.
        Return an iterator (view=list) or a dict (view=dict) over the results.

        Each path is either a single predicate URI or an iterable of predicate
        URIs. Iterable paths are materialized as tuples, and it is the tuple
        that identifies the path in the results.

        The *view_kwargs* are forwarded to the result view. Unless given,
        `node` defaults to True if there is not exactly one guid, `path` to
        True if there is not exactly one path, and `value` to False.

        Raises:
        * AttributeError: no path was given
        * ValueError: *view* is neither dict nor list, or a path is empty
        * TypeError: a path is neither a string nor an iterable of strings
        * errors.ConsistencyError: a path's predicate is not in the schema
        """
        # FIXME: user-provided Fetch query AST?
        # check args
        if len(paths) == 0:
            raise AttributeError('expected at least one path, found none')
        if view not in (dict, list):
            raise ValueError(f'expected dict or list, found {view}')
        # process paths: create fetch ast, build name mapping, and find unique paths
        schema = self.schema
        statements = set()
        name2path = {}
        unique_paths = set() # paths that result in a single (unique) value
        normpath: typing.Tuple[URI, ...]
        key: typing.Union[str, typing.Tuple[str, ...]] # the path as reported to the caller
        for idx, path in enumerate(paths):
            # normalize path
            if isinstance(path, str):
                key = path
                normpath = (URI(path), )
            elif isinstance(path, abc.Iterable):
                # Materialize once: a one-shot iterator would be exhausted by the
                # type check below, and a list could not be added to unique_paths.
                key = tuple(path)
                if not all(isinstance(step, str) for step in key):
                    raise TypeError(path)
                if len(key) == 0:
                    # there would be no tail predicate to fetch
                    raise ValueError('expected a path with at least one predicate, found none')
                normpath = tuple(URI(step) for step in key)
            else:
                raise TypeError(path)
            # check path's schema consistency
            if not all(schema.has_predicate(pred) for pred in normpath):
                raise errors.ConsistencyError(f'path is not fully covered by the schema: {path}')
            # check path's uniqueness
            if all(schema.predicate(pred).unique for pred in normpath):
                unique_paths.add(key)
            # fetch tail predicate
            tail = schema.predicate(normpath[-1])
            # determine tail ast node type
            factory = ast.fetch.Node if isinstance(tail.range, bsc.Node) else ast.fetch.Value
            # assign name
            name = f'fetch{idx}'
            name2path[name] = (key, tail)
            # create tail ast node
            curr: ast.fetch.FetchExpression = factory(tail.uri, name)
            # walk towards front
            hop: URI
            for hop in normpath[-2::-1]:
                curr = ast.fetch.Fetch(hop, curr)
            # add to fetch query
            statements.add(curr)
        # aggregate fetch statements
        if len(statements) == 1:
            fetch = next(iter(statements))
        else:
            fetch = ast.fetch.All(*statements)
        # add access controls to fetch
        fetch = self._ac.fetch_read(self.node_type, fetch)

        if len(self._guids) == 0:
            # shortcut: no need to query; no triples
            # FIXME: if the Fetch query can given by the user, we might want to check its validity
            def triple_iter():
                return []
        else:
            # compose filter ast
            filter = ast.filter.IsIn(self.guids) # pylint: disable=redefined-builtin
            # add access controls to filter
            filter = self._ac.filter_read(self.node_type, filter) # type: ignore [assignment]

            # validate queries
            validate.Filter(self._backend.schema)(self.node_type, filter)
            validate.Fetch(self._backend.schema)(self.node_type, fetch)

            # process results, convert if need be
            def triple_iter():
                # query the backend
                triples = self._backend.fetch(self.node_type, filter, fetch)
                # process triples
                for root, name, raw in triples:
                    # get node
                    node = Nodes(self._backend, self._ac, self.node_type, {root})
                    # get path
                    path, tail = name2path[name]
                    # convert raw to value
                    if isinstance(tail.range, bsc.Node):
                        value = Nodes(self._backend, self._ac, tail.range, {raw})
                    else:
                        value = raw
                    # emit triple
                    yield node, path, value

        # simplify by default
        view_kwargs['node'] = view_kwargs.get('node', len(self._guids) != 1)
        view_kwargs['path'] = view_kwargs.get('path', len(paths) != 1)
        view_kwargs['value'] = view_kwargs.get('value', False)

        # return results view
        if view == list:
            return result.to_list_view(
                triple_iter(),
                # aggregation args
                **view_kwargs,
                )

        if view == dict:
            return result.to_dict_view(
                triple_iter(),
                # context
                len(self._guids) == 1,
                len(paths) == 1,
                unique_paths,
                # aggregation args
                **view_kwargs,
                )

        raise errors.UnreachableError() # view was already checked


    def __set(self, predicate: URI, value: typing.Any):
        """Set *predicate* to *value* on all nodes, without committing.

        Missing subject and target nodes are created via `_ensure_nodes`. The
        guids that are actually written are those returned by the access
        control checks.

        Raises:
        * KeyError: *predicate* is not in the schema
        * errors.ConsistencyError: the node type does not match the predicate's
          domain, or the value's node type does not match its range
        * errors.PermissionDeniedError: *predicate* is protected
        * TypeError: the predicate's range is a node but *value* is not a Nodes instance
        """
        # get normalized predicate. Raises KeyError if *pred* not in the schema.
        pred = self._backend.schema.predicate(predicate)

        # node_type must be a subclass of the predicate's domain
        node_type = self.node_type
        if not node_type <= pred.domain:
            raise errors.ConsistencyError(f'{node_type} must be a subclass of {pred.domain}')

        # check reserved predicates (access controls, metadata, internal structures)
        # FIXME: Needed? Could be integrated into other AC methods (by passing the predicate!)
        #        This could allow more fine-grained predicate control (e.g. based on ownership)
        #        rather than a global approach like this.
        if self._ac.is_protected_predicate(pred):
            raise errors.PermissionDeniedError(pred)

        # set operation affects all nodes (if possible)
        guids = set(self.guids)

        # ensure subject node existence; create nodes if need be
        guids = set(self._ensure_nodes(node_type, guids))

        # check value
        if isinstance(pred.range, bsc.Literal):
            # check write permissions on existing nodes
            # As long as the user has write permissions, we don't restrict
            # the creation or modification of literal values.
            guids = set(self._ac.write_literal(node_type, guids))

            # insert literals
            # TODO: Support passing iterators as values for non-unique predicates
            self._backend.set(
                node_type,
                guids,
                pred,
                [value],
                )

        elif isinstance(pred.range, bsc.Node):
            # check value type
            # FIXME: value could be a set of Nodes
            if not isinstance(value, Nodes):
                raise TypeError(value)
            # value's node_type must be a subclass of the predicate's range
            if not value.node_type <= pred.range:
                raise errors.ConsistencyError(f'{value.node_type} must be a subclass of {pred.range}')

            # check link permissions on source nodes
            # Link permissions cover adding and removing links on the source node.
            # Specifically, link permissions also allow to remove links to other
            # nodes if needed (e.g. for unique predicates).
            guids = set(self._ac.link_from_node(node_type, guids))

            # get link targets
            targets = set(value.guids)
            # ensure existence of value nodes; create nodes if need be
            targets = set(self._ensure_nodes(value.node_type, targets))
            # check link permissions on target nodes
            targets = set(self._ac.link_to_node(value.node_type, targets))

            # insert node links
            self._backend.set(
                node_type,
                guids,
                pred,
                targets,
                )

        else:
            raise errors.UnreachableError()

    def _ensure_nodes(self, node_type: bsc.Node, guids: typing.Iterable[URI]):
        """Create the nodes of *guids* that do not exist yet.

        Only the missing nodes that the access controls report as createable
        are created. They receive a creation timestamp (ns.bsm.t_created) and
        are then passed to the access controls' `create`.

        Returns the set of guids that existed already or have been created.
        """
        # check node existence
        guids = set(guids)
        existing = set(self._backend.exists(node_type, guids))
        # get nodes to be created
        missing = guids - existing
        # create nodes if need be
        if len(missing) > 0:
            # check which missing nodes can be created
            missing = set(self._ac.createable(node_type, missing))
            # create nodes
            self._backend.create(node_type, missing)
            # add bookkeeping triples
            self._backend.set(node_type, missing,
                self._backend.schema.predicate(ns.bsm.t_created), [time.time()])
            # add permission triples
            self._ac.create(node_type, missing)
        # return available nodes
        return existing | missing

## EOF ##