diff options
Diffstat (limited to 'bsfs/query')
-rw-r--r-- | bsfs/query/__init__.py | 15 | ||||
-rw-r--r-- | bsfs/query/ast/__init__.py | 23 | ||||
-rw-r--r-- | bsfs/query/ast/fetch.py | 169 | ||||
-rw-r--r-- | bsfs/query/ast/filter_.py | 516 | ||||
-rw-r--r-- | bsfs/query/matcher.py | 361 | ||||
-rw-r--r-- | bsfs/query/validator.py | 351 |
6 files changed, 1435 insertions, 0 deletions
diff --git a/bsfs/query/__init__.py b/bsfs/query/__init__.py new file mode 100644 index 0000000..58ff03a --- /dev/null +++ b/bsfs/query/__init__.py @@ -0,0 +1,15 @@ + +# imports +import typing + +# inner-module imports +from . import ast +from . import validator as validate + +# exports +__all__: typing.Sequence[str] = ( + 'ast', + 'validate', + ) + +## EOF ## diff --git a/bsfs/query/ast/__init__.py b/bsfs/query/ast/__init__.py new file mode 100644 index 0000000..bceaac0 --- /dev/null +++ b/bsfs/query/ast/__init__.py @@ -0,0 +1,23 @@ +"""Query AST components. + +The query AST consists of a Filter and a Fetch syntax trees. + +Classes beginning with an underscore (_) represent internal type hierarchies +and should not be used for parsing. Note that the AST structures do not +(and cannot) check semantic validity or consistency with a given schema. + +""" +# imports +import typing + +# inner-module imports +from . import fetch +from . import filter_ as filter # pylint: disable=redefined-builtin + +# exports +__all__: typing.Sequence[str] = ( + 'fetch', + 'filter', + ) + +## EOF ## diff --git a/bsfs/query/ast/fetch.py b/bsfs/query/ast/fetch.py new file mode 100644 index 0000000..66d94e1 --- /dev/null +++ b/bsfs/query/ast/fetch.py @@ -0,0 +1,169 @@ + +# imports +from collections import abc +import typing + +# bsfs imports +from bsfs.utils import URI, typename, normalize_args + +# exports +__all__ : typing.Sequence[str] = ( + 'All', + 'Fetch', + 'FetchExpression', + 'Node', + 'This', + 'Value', + ) + + +## code ## + +class FetchExpression(abc.Hashable): + """Generic Fetch expression.""" + + def __repr__(self) -> str: + """Return the expressions's string representation.""" + return f'{typename(self)}()' + + def __hash__(self) -> int: + """Return the expression's integer representation.""" + return hash(type(self)) + + def __eq__(self, other: typing.Any) -> bool: + """Return True if *self* and *other* are equivalent.""" + return isinstance(other, type(self)) + + +class 
All(FetchExpression): + """Fetch all child expressions.""" + + # child expressions. + expr: typing.Set[FetchExpression] + + def __init__(self, *expr): + # unpack child expressions + unfolded = set(normalize_args(*expr)) + # check child expressions + if len(unfolded) == 0: + raise AttributeError('expected at least one expression, found none') + if not all(isinstance(itm, FetchExpression) for itm in unfolded): + raise TypeError(expr) + # initialize + super().__init__() + # assign members + self.expr = unfolded + + def __iter__(self) -> typing.Iterator[FetchExpression]: + return iter(self.expr) + + def __len__(self) -> int: + return len(self.expr) + + def __repr__(self) -> str: + return f'{typename(self)}({self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), tuple(sorted(self.expr, key=repr)))) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +class _Branch(FetchExpression): + """Branch along a predicate.""" + + # FIXME: Use a Predicate (like in ast.filter) so that we can also reverse them! + + # predicate to follow. + predicate: URI + + def __init__(self, predicate: URI): + if not isinstance(predicate, URI): + raise TypeError(predicate) + self.predicate = predicate + + def __repr__(self) -> str: + return f'{typename(self)}({self.predicate})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.predicate)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.predicate == other.predicate + + +class Fetch(_Branch): + """Follow a predicate before evaluating a child expression.""" + + # child expression.
+ expr: FetchExpression + + def __init__(self, predicate: URI, expr: FetchExpression): + # check child expressions + if not isinstance(expr, FetchExpression): + raise TypeError(expr) + # initialize + super().__init__(predicate) + # assign members + self.expr = expr + + def __repr__(self) -> str: + return f'{typename(self)}({self.predicate}, {self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.expr)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +class _Named(_Branch): + """Fetch a (named) symbol at a predicate.""" + + # symbol name. + name: str + + def __init__(self, predicate: URI, name: str): + super().__init__(predicate) + self.name = str(name) + + def __repr__(self) -> str: + return f'{typename(self)}({self.predicate}, {self.name})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.name)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.name == other.name + + +class Node(_Named): # pylint: disable=too-few-public-methods + """Fetch a Node at a predicate.""" + # FIXME: Is this actually needed? + + +class Value(_Named): # pylint: disable=too-few-public-methods + """Fetch a Literal at a predicate.""" + + +class This(FetchExpression): + """Fetch the current Node.""" + + # symbol name. + name: str + + def __init__(self, name: str): + super().__init__() + self.name = str(name) + + def __repr__(self) -> str: + return f'{typename(self)}({self.name})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.name)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.name == other.name + +## EOF ## diff --git a/bsfs/query/ast/filter_.py b/bsfs/query/ast/filter_.py new file mode 100644 index 0000000..610fdb4 --- /dev/null +++ b/bsfs/query/ast/filter_.py @@ -0,0 +1,516 @@ +"""Filter AST. 
+ +Note that it is easily possible to construct an AST that is inconsistent with +a given schema. Furthermore, it is possible to construct a semantically invalid +AST that cannot be parsed correctly or includes contradicting statements. +The AST nodes do not (and cannot) check such issues. + +For example, consider the following AST: + +>>> Any(ns.bse.collection, +... And( +... Equals('hello'), +... Is('hello world'), +... Any(ns.bse.tag, Equals('world')), +... Any(ns.bst.label, Equals('world')), +... All(ns.bst.label, Not(Equals('world'))), +... ) +... ) + +This AST has multiple issues that are not verified upon its creation: +* A condition on a non-literal. +* A Filter on a literal. +* Conditions exclude each other. +* The predicates along the branch have incompatible domains and ranges. + +""" +# imports +from collections import abc +import typing + +# bsfs imports +from bsfs.utils import URI, typename, normalize_args + +# exports +__all__ : typing.Sequence[str] = ( + # base classes + 'FilterExpression', + 'PredicateExpression', + # predicate expressions + 'OneOf', + 'Predicate', + # branching + 'All', + 'Any', + # aggregators + 'And', + 'Or', + # value matchers + 'Equals', + 'Substring', + 'EndsWith', + 'StartsWith', + # range matchers + 'GreaterThan', + 'LessThan', + # misc + 'Has', + 'Is', + 'Not', + ) + + +## code ## + +# pylint: disable=too-few-public-methods # Many expressions use mostly magic methods + +class _Expression(abc.Hashable): + def __repr__(self) -> str: + """Return the expression's string representation.""" + return f'{typename(self)}()' + + def __hash__(self) -> int: + """Return the expression's integer representation.""" + return hash(type(self)) + + def __eq__(self, other: typing.Any) -> bool: + """Return True if *self* and *other* are equivalent.""" + return isinstance(other, type(self)) + + +class FilterExpression(_Expression): + """Generic Filter expression.""" + + +class PredicateExpression(_Expression): + """Generic Predicate
expression.""" + + +class _Branch(FilterExpression): + """Branch the filter along a predicate.""" + + # predicate to follow. + predicate: PredicateExpression + + # child expression to evaluate. + expr: FilterExpression + + def __init__( + self, + predicate: typing.Union[PredicateExpression, URI], + expr: FilterExpression, + ): + # process predicate argument + if isinstance(predicate, URI): + predicate = Predicate(predicate) + elif not isinstance(predicate, PredicateExpression): + raise TypeError(predicate) + # process expression argument + if not isinstance(expr, FilterExpression): + raise TypeError(expr) + # assign members + self.predicate = predicate + self.expr = expr + + def __repr__(self) -> str: + return f'{typename(self)}({self.predicate}, {self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.predicate, self.expr)) + + def __eq__(self, other) -> bool: + return super().__eq__(other) \ + and self.predicate == other.predicate \ + and self.expr == other.expr + +class Any(_Branch): + """Any (and at least one) triple matches.""" + + +class All(_Branch): + """All (and at least one) triples match.""" + + +class _Agg(FilterExpression, abc.Collection): + """Combine multiple expressions.""" + + # child expressions + expr: typing.Set[FilterExpression] + + def __init__( + self, + *expr: typing.Union[FilterExpression, + typing.Iterable[FilterExpression], + typing.Iterator[FilterExpression]] + ): + # unfold arguments + unfolded = set(normalize_args(*expr)) + # check type + if not all(isinstance(e, FilterExpression) for e in unfolded): + raise TypeError(expr) + # FIXME: Require at least one child expression? 
+ # assign member + self.expr = unfolded + + def __contains__(self, expr: typing.Any) -> bool: + """Return True if *expr* is among the child expressions.""" + return expr in self.expr + + def __iter__(self) -> typing.Iterator[FilterExpression]: + """Iterator over child expressions.""" + return iter(self.expr) + + def __len__(self) -> int: + """Number of child expressions.""" + return len(self.expr) + + def __repr__(self) -> str: + return f'{typename(self)}({self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), tuple(sorted(self.expr, key=repr)))) + + def __eq__(self, other) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +class And(_Agg): + """All conditions match.""" + + +class Or(_Agg): + """At least one condition matches.""" + + +class Not(FilterExpression): + """Invert a statement.""" + + # child expression + expr: FilterExpression + + def __init__(self, expr: FilterExpression): + # check argument + if not isinstance(expr, FilterExpression): + raise TypeError(expr) + # assign member + self.expr = expr + + def __repr__(self) -> str: + return f'{typename(self)}({self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.expr)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +class Has(FilterExpression): + """Has predicate N times""" + + # predicate to follow. 
+ predicate: PredicateExpression + + # target count + count: FilterExpression + + def __init__( + self, + predicate: typing.Union[PredicateExpression, URI], + count: typing.Optional[typing.Union[FilterExpression, int]] = None, + ): + # check predicate + if isinstance(predicate, URI): + predicate = Predicate(predicate) + elif not isinstance(predicate, PredicateExpression): + raise TypeError(predicate) + # check count + if count is None: + count = GreaterThan(1, strict=False) + elif isinstance(count, int): + count = Equals(count) + elif not isinstance(count, FilterExpression): + raise TypeError(count) + # assign members + self.predicate = predicate + self.count = count + + def __repr__(self) -> str: + return f'{typename(self)}({self.predicate}, {self.count})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.predicate, self.count)) + + def __eq__(self, other) -> bool: + return super().__eq__(other) \ + and self.predicate == other.predicate \ + and self.count == other.count + + +class _Value(FilterExpression): + """Matches some value.""" + + # target value. + value: typing.Any + + def __init__(self, value: typing.Any): + self.value = value + + def __repr__(self) -> str: + return f'{typename(self)}({self.value})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.value)) + + def __eq__(self, other) -> bool: + return super().__eq__(other) and self.value == other.value + + +class Is(_Value): + """Match the URI of a node.""" + + +class Equals(_Value): + """Value matches exactly. + NOTE: Value must correspond to literal type. + """ + + +class Substring(_Value): + """Value matches a substring + NOTE: value must be a string. + """ + + +class StartsWith(_Value): + """Value begins with a given string.""" + + +class EndsWith(_Value): + """Value ends with a given string.""" + + +class Distance(FilterExpression): + """Distance to a reference is (strictly) below a threshold. 
Assumes a Feature literal.""" + + # FIXME: + # (a) pass a node/predicate as anchor instead of a value. + # Then we don't need to materialize the reference. + # (b) pass a FilterExpression (_Bounded) instead of a threshold. + # Then, we could also query values greater than a threshold. + + # reference value. + reference: typing.Any + + # distance threshold. + threshold: float + + # closed (True) or open (False) bound. + strict: bool + + def __init__( + self, + reference: typing.Any, + threshold: float, + strict: bool = False, + ): + self.reference = reference + self.threshold = float(threshold) + self.strict = bool(strict) + + def __repr__(self) -> str: + return f'{typename(self)}({self.reference}, {self.threshold}, {self.strict})' + + def __hash__(self) -> int: + return hash((super().__hash__(), tuple(self.reference), self.threshold, self.strict)) + + def __eq__(self, other) -> bool: + return super().__eq__(other) \ + and self.reference == other.reference \ + and self.threshold == other.threshold \ + and self.strict == other.strict + + +class _Bounded(FilterExpression): + """Value is bounded by a threshold. Assumes a Number literal.""" + + # bound. + threshold: float + + # closed (True) or open (False) bound. + strict: bool + + def __init__( + self, + threshold: float, + strict: bool = True, + ): + self.threshold = float(threshold) + self.strict = bool(strict) + + def __repr__(self) -> str: + return f'{typename(self)}({self.threshold}, {self.strict})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.threshold, self.strict)) + + def __eq__(self, other) -> bool: + return super().__eq__(other) \ + and self.threshold == other.threshold \ + and self.strict == other.strict + + + +class LessThan(_Bounded): + """Value is (strictly) smaller than threshold. Assumes a Number literal.""" + + +class GreaterThan(_Bounded): + """Value is (strictly) larger than threshold. 
Assumes a Number literal.""" + + +class Predicate(PredicateExpression): + """A single predicate.""" + + # predicate URI + predicate: URI + + # reverse the predicate's direction + reverse: bool + + def __init__( + self, + predicate: URI, + reverse: typing.Optional[bool] = False, + ): + # check arguments + if not isinstance(predicate, URI): + raise TypeError(predicate) + # assign members + self.predicate = predicate + self.reverse = bool(reverse) + + def __repr__(self) -> str: + return f'{typename(self)}({self.predicate}, {self.reverse})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.predicate, self.reverse)) + + def __eq__(self, other) -> bool: + return super().__eq__(other) \ + and self.predicate == other.predicate \ + and self.reverse == other.reverse + + +class OneOf(PredicateExpression, abc.Collection): + """A set of predicate alternatives. + + The predicates' domains must be ascendants or descendants of each other. + The overall domain is the most specific one. + + The predicates' ranges must be ascendants or descendants of each other. + The overall range is the most generic one. + """ + + # predicate alternatives + expr: typing.Set[PredicateExpression] + + def __init__(self, *expr: typing.Union[PredicateExpression, URI]): + # unfold arguments + unfolded = set(normalize_args(*expr)) # type: ignore [arg-type] # this is getting too complex...
+ # check arguments + if len(unfolded) == 0: + raise AttributeError('expected at least one expression, found none') + # ensure PredicateExpression + unfolded = {Predicate(e) if isinstance(e, URI) else e for e in unfolded} + # check type + if not all(isinstance(e, PredicateExpression) for e in unfolded): + raise TypeError(expr) + # assign member + self.expr = unfolded + + def __contains__(self, expr: typing.Any) -> bool: + """Return True if *expr* is among the child expressions.""" + return expr in self.expr + + def __iter__(self) -> typing.Iterator[PredicateExpression]: + """Iterator over child expressions.""" + return iter(self.expr) + + def __len__(self) -> int: + """Number of child expressions.""" + return len(self.expr) + + def __repr__(self) -> str: + return f'{typename(self)}({self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), tuple(sorted(self.expr, key=repr)))) + + def __eq__(self, other) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +# Helpers +# invalid-name is disabled since they explicitly mimic an expression + +def IsIn(*values) -> FilterExpression: # pylint: disable=invalid-name + """Match any of the given URIs.""" + args = normalize_args(*values) + if len(args) == 0: + raise AttributeError('expected at least one value, found none') + if len(args) == 1: + return Is(args[0]) + return Or(Is(value) for value in args) + +def IsNotIn(*values) -> FilterExpression: # pylint: disable=invalid-name + """Match none of the given URIs.""" + return Not(IsIn(*values)) + + +def Between( # pylint: disable=invalid-name + lo: float = float('-inf'), + hi: float = float('inf'), + lo_strict: bool = True, + hi_strict: bool = True, + ) -> FilterExpression : + """Match numerical values between *lo* and *hi*. 
Include bounds if strict is False.""" + if abs(lo) == hi == float('inf'): + raise ValueError('range cannot be INF on both sides') + if lo > hi: + raise ValueError(f'lower bound ({lo}) cannot be less than upper bound ({hi})') + if lo == hi and not lo_strict and not hi_strict: + return Equals(lo) + if lo == hi: # either bound is strict + raise ValueError('bounds cannot be equal when either is strict') + if lo != float('-inf') and hi != float('inf'): + return And(GreaterThan(lo, lo_strict), LessThan(hi, hi_strict)) + if lo != float('-inf'): + return GreaterThan(lo, lo_strict) + # hi != float('inf'): + return LessThan(hi, hi_strict) + + +def Includes(*values, approx: bool = False) -> FilterExpression: # pylint: disable=invalid-name + """Match any of the given *values*. Uses `Substring` if *approx* is set.""" + args = normalize_args(*values) + cls = Substring if approx else Equals + if len(args) == 0: + raise AttributeError('expected at least one value, found none') + if len(args) == 1: + return cls(args[0]) + return Or(cls(v) for v in args) + + +def Excludes(*values, approx: bool = False) -> FilterExpression: # pylint: disable=invalid-name + """Match none of the given *values*. Uses `Substring` if *approx* is set.""" + args = normalize_args(*values) + cls = Substring if approx else Equals + if len(args) == 0: + raise AttributeError('expected at least one value, found none') + if len(args) == 1: + return Not(cls(args[0])) + return Not(Or(cls(v) for v in args)) + + +## EOF ## diff --git a/bsfs/query/matcher.py b/bsfs/query/matcher.py new file mode 100644 index 0000000..17c9c8e --- /dev/null +++ b/bsfs/query/matcher.py @@ -0,0 +1,361 @@ + +# imports +from collections import defaultdict +from itertools import product +from time import time +import random +import threading +import typing + +# external imports +from hopcroftkarp import HopcroftKarp + +# bsfs imports +from bsfs.utils import errors, typename + +# inner-module imports +from . 
import ast + +# exports +__all__ : typing.Sequence[str] = ( + 'Filter', + ) + + +## code ## + +class Any(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match any ast class. + + Note that Any instances are unique, i.e. they do not compare equal, and + can hence be repeated in a set: + >>> Any() == Any() + False + >>> len({Any(), Any(), Any(), Any()}) + 4 + + """ + + # unique instance id + _uid: typing.Tuple[int, int, float, float] + + def __init__(self): + self._uid = ( + id(self), + id(threading.current_thread()), + time(), + random.random(), + ) + + def __eq__(self, other: typing.Any): + return super().__eq__(other) and self._uid == other._uid + + def __hash__(self): + return hash((super().__hash__(), self._uid)) + + +class Rest(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match the leftovers in a set of items to be compared. + + Rest can be used in conjunction with aggregating expressions such as ast.filter.And, + ast.filter.Or, ast.filter.OneOf. It controls child expressions that were not yet + consumed by other matching rules. Rest may match to only a specific expression. + The expression defaults to Any(). + + For example, the following two ast structures would match since Rest + allows an arbitrary repetition of ast.filter.Equals statements. + + >>> And(Equals('hello'), Equals('world'), Equals('foobar')) + >>> And(Equals('world'), Rest(Partial(Equals))) + + """ + + # child expression for the Rest.
+ expr: typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression] + + def __init__( + self, + expr: typing.Optional[typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression]] = None, + ): + if expr is None: + expr = Any() + self.expr = expr + + def __repr__(self) -> str: + return f'{typename(self)}({self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.expr)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +class Partial(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match a partially defined ast expression. + + Literal values might be irrelevant or unknown when comparing two ast + structures. Partial allows constraining the matcher to a certain + ast class, while leaving some of its members unspecified. + + Pass the class (not instance) and its members as keyword arguments + to Partial. Note that the arguments are not validated. + + For example, the following instance matches any ast.filter.Equals, + irrespective of its value: + + >>> Partial(ast.filter.Equals) + + Likewise, the following instance matches any ast.filter.LessThan + that has a non-strict bound, but makes no claim about the threshold: + + >>> Partial(ast.filter.LessThan, strict=False) + + """ + + # target node type. + node: typing.Type + + # node construction args.
+ kwargs: typing.Dict[str, typing.Any] + + def __init__( + self, + node: typing.Type, + **kwargs, + ): + self.node = node + self.kwargs = kwargs + + def __repr__(self) -> str: + return f'{typename(self)}({self.node.__name__}, {self.kwargs})' + + def __hash__(self) -> int: + kwargs = tuple((key, self.kwargs[key]) for key in sorted(self.kwargs)) + return hash((super().__hash__(), self.node, kwargs)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) \ + and self.node == other.node \ + and self.kwargs == other.kwargs + + def match( + self, + name: str, + value: typing.Any, + ) -> bool: + """Return True if *name* is unspecified or matches *value*.""" + return name not in self.kwargs or self.kwargs[name] == value + + +T_ITEM_TYPE = typing.TypeVar('T_ITEM_TYPE') # pylint: disable=invalid-name + +def _set_matcher( + query: typing.Collection[T_ITEM_TYPE], + reference: typing.Collection[T_ITEM_TYPE], + cmp: typing.Callable[[T_ITEM_TYPE, T_ITEM_TYPE], bool], + ) -> bool: + """Compare two sets of child expressions. + + This check has a best-case complexity of O(|N|**2) and worst-case + complexity of O(|N|**3), with N the number of child expressions. 
+ """ + # get reference items + r_items = list(reference) + # deal with Rest + r_rest = {itm for itm in r_items if isinstance(itm, Rest)} + if len(r_rest) > 1: + raise errors.BackendError(f'there must be at most one Rest instance per set, found {len(r_rest)}') + if len(r_rest) == 1: + # replace Rest by filling the reference up with rest's expression + # NOTE: convert r_items to list so that items can be repeated + expr = next(iter(r_rest)).expr # type: ignore [attr-defined] + r_items = [itm for itm in r_items if not isinstance(itm, Rest)] + r_items += [expr for _ in range(len(query) - len(r_items))] # type: ignore [misc] + # sanity check: cannot match if the item sizes differ: + # either a reference item is unmatched (len(r_items) > len(query)) + # or a query item is unmatched (len(r_items) < len(query)) + if len(query) != len(r_items): + return False + + # To have a positive match between the query and the reference, + # each query expr has to match any reference expr. + # However, each reference expr can only be "consumed" once even + # if it matches multiple query exprs (e.g., the Any expression matches + # every query expr). + # This is a bipartite matching problem (Hall's marriage problem) + # and the Hopcroft-Karp-Karzanov algorithm finds a maximum + # matching. While there might be multiple maximum matchings, + # we only need to know whether (at least) one complete matching + # exists. The hopcroftkarp module provides this functionality. + # The HKK algorithm has worst-case complexity of O(|N|**2 * sqrt(|N|)) + # and we also need to compare expressions pairwise, hence O(|N|**2). + num_items = len(r_items) + graph = defaultdict(set) + # build the bipartite graph as {lhs: {rhs}, ...} + # lhs and rhs must be disjoint identifiers.
+ for (ridx, ref), (nidx, node) in product(enumerate(r_items), enumerate(query)): + # add edges for equal expressions + if cmp(node, ref): + graph[ridx].add(num_items + nidx) + + # maximum_matching returns the matches for all nodes in the graph + # ({ref_itm: node_itm}), hence a complete matching's size is + # the number of reference's child expressions. + return len(HopcroftKarp(graph).maximum_matching(keys_only=True)) == num_items + + +class Filter(): + """Compare a `bsfs.query.ast.filter` query's structure to a reference ast. + + The reference ast may include `Rest`, `Partial`, or `Any` to account for irrelevant + or unknown ast pieces. + + This is only a structural comparison, not a semantic one. For example, the + two following queries are semantically identical, but structurally different, + and would therefore not match: + + >>> ast.filter.OneOf(ast.filter.Predicate(ns.bse.name)) + >>> ast.filter.Predicate(ns.bse.name) + + """ + + def __call__(self, query: ast.filter.FilterExpression, reference: ast.filter.FilterExpression) -> bool: + """Compare a *query* to a *reference* ast structure. + Return True if both are structurally equivalent.
+ """ + if not isinstance(query, ast.filter.FilterExpression): + raise errors.BackendError(f'expected filter expression, found {query}') + if not isinstance(reference, ast.filter.FilterExpression): + raise errors.BackendError(f'expected filter expression, found {reference}') + return self._parse_filter_expression(query, reference) + + def _parse_filter_expression( + self, + node: ast.filter.FilterExpression, + reference: ast.filter.FilterExpression, + ) -> bool: + """Route *node* to the handler of the respective FilterExpression subclass.""" + # generic checks: reference type must be Any or match node type + if isinstance(reference, Any): + return True + # node-specific checks + if isinstance(node, ast.filter.Not): + return self._not(node, reference) + if isinstance(node, ast.filter.Has): + return self._has(node, reference) + if isinstance(node, ast.filter.Distance): + return self._distance(node, reference) + if isinstance(node, (ast.filter.Any, ast.filter.All)): + return self._branch(node, reference) + if isinstance(node, (ast.filter.And, ast.filter.Or)): + return self._agg(node, reference) + if isinstance(node, (ast.filter.Is, ast.filter.Equals, ast.filter.Substring, + ast.filter.StartsWith, ast.filter.EndsWith)): + return self._value(node, reference) + if isinstance(node, (ast.filter.LessThan, ast.filter.GreaterThan)): + return self._bounded(node, reference) + # invalid node + raise errors.BackendError(f'expected filter expression, found {node}') + + def _parse_predicate_expression( + self, + node: ast.filter.PredicateExpression, + reference: ast.filter.PredicateExpression, + ) -> bool: + """Route *node* to the handler of the respective PredicateExpression subclass.""" + if isinstance(reference, Any): + return True + if isinstance(node, ast.filter.Predicate): + return self._predicate(node, reference) + if isinstance(node, ast.filter.OneOf): + return self._one_of(node, reference) + # invalid node + raise errors.BackendError(f'expected predicate expression, found 
{node}') + + def _one_of(self, node: ast.filter.OneOf, reference: ast.filter.PredicateExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return _set_matcher(node, reference, self._parse_predicate_expression) + + def _predicate(self, node: ast.filter.Predicate, reference: ast.filter.PredicateExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('predicate', node.predicate) \ + and reference.match('reverse', node.reverse) + # full check + return node.predicate == reference.predicate \ + and node.reverse == reference.reverse + + def _branch(self, + node: typing.Union[ast.filter.Any, ast.filter.All], + reference: ast.filter.FilterExpression, + ) -> bool: + if not isinstance(reference, type(node)): + return False + if not self._parse_predicate_expression(node.predicate, reference.predicate): # type: ignore [attr-defined] + return False + if not self._parse_filter_expression(node.expr, reference.expr): # type: ignore [attr-defined] + return False + return True + + def _agg(self, node: typing.Union[ast.filter.And, ast.filter.Or], reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return _set_matcher(node, reference, self._parse_filter_expression) # type: ignore [arg-type] + + def _not(self, node: ast.filter.Not, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return self._parse_filter_expression(node.expr, reference.expr) + + def _has(self, node: ast.filter.Has, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return self._parse_predicate_expression(node.predicate, reference.predicate) \ + and self._parse_filter_expression(node.count, reference.count) + + def _distance(self, node: 
ast.filter.Distance, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('reference', node.reference) \ + and reference.match('threshold', node.threshold) \ + and reference.match('strict', node.strict) + # full check + return node.reference == reference.reference \ + and node.threshold == reference.threshold \ + and node.strict == reference.strict + + def _value(self, node: ast.filter._Value, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('value', node.value) + # full ckeck + return node.value == reference.value + + def _bounded(self, node: ast.filter._Bounded, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('threshold', node.threshold) \ + and reference.match('strict', node.strict) + # full check + return node.threshold == reference.threshold \ + and node.strict == reference.strict + +## EOF ## diff --git a/bsfs/query/validator.py b/bsfs/query/validator.py new file mode 100644 index 0000000..10ca492 --- /dev/null +++ b/bsfs/query/validator.py @@ -0,0 +1,351 @@ + +# imports +import typing + +# bsfs imports +from bsfs import schema as bsc +from bsfs.namespace import ns +from bsfs.utils import errors, typename + +# inner-module imports +from . import ast + +# exports +__all__ : typing.Sequence[str] = ( + 'Filter', + ) + +# FIXME: Split into a submodule and the two classes into their own respective files. 
## code ##

class Filter():
    """Validate the structure and schema compliance of a `bsfs.query.ast.filter` query.

    The validator walks the expression tree and enforces that

    * Value and Bounded conditions are only applied to literals,
    * Branches, Is, and Has are only applied to nodes,
    * the current type is the followed predicate's domain or a subtype thereof,
    * predicates referenced by the query exist in the schema,
    * types encountered along the way are present in the schema.

    """

    # schema to validate against.
    schema: bsc.Schema

    def __init__(self, schema: bsc.Schema):
        self.schema = schema

    def __call__(self, root_type: bsc.Node, query: ast.filter.FilterExpression) -> bool:
        """Shorthand for `Filter.validate`."""
        return self.validate(root_type, query)

    def validate(self, root_type: bsc.Node, query: ast.filter.FilterExpression) -> bool:
        """Check the filter *query* under the assumption that its subject is of *root_type*.

        Returns True if no check fails.

        Raises a `TypeError` if *root_type* is not a `bsc.Node`.
        Raises a `bsfs.utils.errors.ConsistencyError` if the query violates the schema.
        Raises a `bsfs.utils.errors.BackendError` if the query structure is invalid.

        """
        # only schema nodes can be the subject of a query
        if not isinstance(root_type, bsc.Node):
            raise TypeError(f'expected a node, found {typename(root_type)}')
        # the subject type has to be declared in the schema
        known_nodes = self.schema.nodes()
        if root_type not in known_nodes:
            raise errors.ConsistencyError(f'{root_type} is not defined in the schema')
        # descend into the expression tree; violations surface as exceptions
        self._parse_filter_expression(root_type, query)
        return True


    ## shared checks

    def _check_node(self, type_: bsc.Vertex):
        """Raise a ConsistencyError unless *type_* is a Node that is part of the schema."""
        if not isinstance(type_, bsc.Node):
            raise errors.ConsistencyError(f'expected a Node, found {type_}')
        if type_ not in self.schema.nodes():
            raise errors.ConsistencyError(f'node {type_} is not in the schema')

    def _check_literal(self, type_: bsc.Vertex):
        """Raise a ConsistencyError unless *type_* is a Literal that is part of the schema."""
        if not isinstance(type_, bsc.Literal):
            raise errors.ConsistencyError(f'expected a Literal, found {type_}')
        if type_ not in self.schema.literals():
            raise errors.ConsistencyError(f'literal {type_} is not in the schema')


    ## routing methods

    def _parse_filter_expression(self, type_: bsc.Vertex, node: ast.filter.FilterExpression):
        """Dispatch *node* to the handler of its FilterExpression subclass."""
        # NOTE: entries are tested top to bottom; the first match wins.
        routes = (
            (ast.filter.Is, self._is),
            (ast.filter.Not, self._not),
            (ast.filter.Has, self._has),
            (ast.filter.Distance, self._distance),
            ((ast.filter.Any, ast.filter.All), self._branch),
            ((ast.filter.And, ast.filter.Or), self._agg),
            ((ast.filter.Equals, ast.filter.Substring, ast.filter.StartsWith, ast.filter.EndsWith), self._value),
            ((ast.filter.LessThan, ast.filter.GreaterThan), self._bounded),
            )
        for kinds, handler in routes:
            if isinstance(node, kinds):
                return handler(type_, node)
        # no handler matched: not a filter expression
        raise errors.BackendError(f'expected filter expression, found {node}')

    def _parse_predicate_expression(self, node: ast.filter.PredicateExpression) -> typing.Tuple[bsc.Vertex, bsc.Vertex]:
        """Dispatch *node* to the handler of its PredicateExpression subclass."""
        # NOTE: entries are tested top to bottom; the first match wins.
        routes = (
            (ast.filter.Predicate, self._predicate),
            (ast.filter.OneOf, self._one_of),
            )
        for kind, handler in routes:
            if isinstance(node, kind):
                return handler(node)
        # no handler matched: not a predicate expression
        raise errors.BackendError(f'expected predicate expression, found {node}')


    ## predicate expressions

    def _predicate(self, node: ast.filter.Predicate) -> typing.Tuple[bsc.Vertex, bsc.Vertex]:
        """Return the (domain, range) of the predicate named by *node*, swapped if *node* is reversed."""
        # the predicate has to be declared in the schema
        if not self.schema.has_predicate(node.predicate):
            raise errors.ConsistencyError(f'predicate {node.predicate} is not in the schema')
        pred = self.schema.predicate(node.predicate)
        # the range has to be a concrete vertex type
        if not isinstance(pred.range, (bsc.Node, bsc.Literal)):
            raise errors.BackendError(f'the range of predicate {pred} is undefined')
        # a reversed predicate is traversed from its range to its domain
        if node.reverse:
            return pred.range, pred.domain
        return pred.domain, pred.range

    def _one_of(self, node: ast.filter.OneOf) -> typing.Tuple[bsc.Vertex, bsc.Vertex]:
        """Return the combined (domain, range) of the predicate expressions in *node*.

        The domain is narrowed to the most specific one seen so far, the range is
        widened to the most generic one seen so far. A child whose domain (range)
        is neither a sub- nor a supertype of the current pick raises a ConsistencyError.

        """
        domain, range_ = None, None
        for child in node:
            child_dom, child_rng = self._parse_predicate_expression(child)
            # keep the most specific domain
            if domain is None or child_dom < domain:
                domain = child_dom
            # the child's domain must be comparable to the current pick
            if not (child_dom <= domain or child_dom >= domain):
                raise errors.ConsistencyError(f'domains {child_dom} and {domain} are not related')
            # keep the most generic range
            if range_ is None or child_rng > range_:
                range_ = child_rng
            # the child's range must be comparable to the current pick
            if not (child_rng <= range_ or child_rng >= range_):
                raise errors.ConsistencyError(f'ranges {child_rng} and {range_} are not related')
        # NOTE(review): presumably OneOf always holds at least one expression, so that
        # neither value is None here; mypy cannot infer this, hence the ignore.
        return domain, range_ # type: ignore [return-value]


    ## intermediates

    def _branch(self, type_: bsc.Vertex, node: ast.filter._Branch):
        """Check that *node* follows a predicate from *type_* and validate its child expression."""
        # FIXME: Is the schema membership check needed for anything but the root type?
        # Types stem from (a) the root type, which is checked in validate, (b) a predicate
        # that was found in the schema, or (c) a type set within this module (e.g. in _has).
        self._check_node(type_)
        # the predicate expression has to be valid
        dom, rng = self._parse_predicate_expression(node.predicate)
        # the predicate has to be applicable to type_
        if not type_ <= dom:
            raise errors.ConsistencyError(f'expected type {dom} or subtype thereof, found {type_}')
        # the child expression is checked against the predicate's range
        self._parse_filter_expression(rng, node.expr)

    def _agg(self, type_: bsc.Vertex, node: ast.filter._Agg):
        """Validate every child expression of *node* against *type_*."""
        for child in node:
            self._parse_filter_expression(type_, child)

    def _not(self, type_: bsc.Vertex, node: ast.filter.Not):
        """Validate the negated expression against *type_*."""
        self._parse_filter_expression(type_, node.expr)

    def _has(self, type_: bsc.Vertex, node: ast.filter.Has):
        """Check that *node* counts a predicate applicable to *type_* with a numerical condition."""
        self._check_node(type_)
        # the predicate expression has to be valid; only its domain matters here
        dom, _ = self._parse_predicate_expression(node.predicate)
        # the predicate has to be applicable to type_
        if not type_ <= dom:
            raise errors.ConsistencyError(f'expected type {dom}, found {type_}')
        # the count condition is checked against the schema's number literal
        number = self.schema.literal(ns.bsl.Number)
        self._parse_filter_expression(number, node.count)

    def _distance(self, type_: bsc.Vertex, node: ast.filter.Distance):
        """Check that *type_* is a schema Feature whose dimension fits the reference of *node*."""
        # type is a Feature
        if not isinstance(type_, bsc.Feature):
            raise errors.ConsistencyError(f'expected a Feature, found {type_}')
        # type exists in the schema
        if type_ not in self.schema.literals():
            raise errors.ConsistencyError(f'literal {type_} is not in the schema')
        # the reference must have as many entries as the feature has dimensions
        if len(node.reference) != type_.dimension:
            raise errors.ConsistencyError(f'reference has dimension {len(node.reference)}, expected {type_.dimension}')
        # FIXME: test dtype


    ## conditions

    def _is(self, type_: bsc.Vertex, node: ast.filter.Is): # pylint: disable=unused-argument # (node)
        """Check that *type_* is a Node of the schema."""
        self._check_node(type_)

    def _value(self, type_: bsc.Vertex, node: ast.filter._Value): # pylint: disable=unused-argument # (node)
        """Check that *type_* is a Literal of the schema."""
        self._check_literal(type_)
        # FIXME: Check if node.value corresponds to type_
        # FIXME: A specific literal might be requested (i.e., a numeric type when used in Has)

    def _bounded(self, type_: bsc.Vertex, node: ast.filter._Bounded): # pylint: disable=unused-argument # (node)
        """Check that *type_* is a numerical Literal of the schema."""
        self._check_literal(type_)
        # bounds are only defined on the number literal and its subtypes
        if not type_ <= self.schema.literal(ns.bsl.Number):
            raise errors.ConsistencyError(f'expected a number type, found {type_}')
        # FIXME: Check if the node's bound (threshold) corresponds to type_


class Fetch():
    """Validate the structure and schema compliance of a `bsfs.query.ast.fetch` query.

    The validator walks the expression tree and enforces that

    * Value is only applied to predicates with a literal range,
    * Node and Fetch are only applied to predicates with a node range,
    * names are non-empty,
    * the current type is the followed predicate's domain or a subtype thereof,
    * predicates and types referenced by the query exist in the schema.

    """

    # schema to validate against.
    schema: bsc.Schema

    def __init__(self, schema: bsc.Schema):
        self.schema = schema

    def __call__(self, root_type: bsc.Node, query: ast.fetch.FetchExpression) -> bool:
        """Shorthand for `Fetch.validate`."""
        return self.validate(root_type, query)

    def validate(self, root_type: bsc.Node, query: ast.fetch.FetchExpression) -> bool:
        """Check the fetch *query* under the assumption that its subject is of *root_type*.

        Returns True if no check fails.

        Raises a `TypeError` if *root_type* is not a `bsc.Node` or *query* is not a FetchExpression.
        Raises a `bsfs.utils.errors.ConsistencyError` if the query violates the schema.
        Raises a `bsfs.utils.errors.BackendError` if the query structure is invalid.

        """
        # only schema nodes can be the subject of a query
        if not isinstance(root_type, bsc.Node):
            raise TypeError(f'expected a node, found {typename(root_type)}')
        # the subject type has to be declared in the schema
        known_nodes = self.schema.nodes()
        if root_type not in known_nodes:
            raise errors.ConsistencyError(f'{root_type} is not defined in the schema')
        # the root has to be a fetch expression
        if not isinstance(query, ast.fetch.FetchExpression):
            raise TypeError(f'expected a fetch expression, found {typename(query)}')
        # descend into the expression tree; violations surface as exceptions
        self._parse_fetch_expression(root_type, query)
        return True

    def _check_node(self, type_: bsc.Vertex):
        """Raise a ConsistencyError unless *type_* is a Node that is part of the schema."""
        if not isinstance(type_, bsc.Node):
            raise errors.ConsistencyError(f'expected a Node, found {type_}')
        if type_ not in self.schema.nodes():
            raise errors.ConsistencyError(f'node {type_} is not in the schema')

    def _check_name(self, node):
        """Raise a BackendError if the name of *node* is empty or consists of whitespace only."""
        if node.name.strip() == '':
            raise errors.BackendError('node name cannot be empty')

    def _parse_fetch_expression(self, type_: bsc.Vertex, node: ast.fetch.FetchExpression):
        """Dispatch *node* to the handler of its FetchExpression subclass."""
        # checks shared by several expression types run first and do not end the dispatch
        if isinstance(node, (ast.fetch.Fetch, ast.fetch.Value, ast.fetch.Node)):
            self._branch(type_, node)
        if isinstance(node, (ast.fetch.Value, ast.fetch.Node)):
            self._named(type_, node)
        # NOTE: entries are tested top to bottom; the first match wins.
        routes = (
            (ast.fetch.All, self._all),
            (ast.fetch.Fetch, self._fetch),
            (ast.fetch.Value, self._value),
            (ast.fetch.Node, self._node),
            (ast.fetch.This, self._this),
            )
        for kind, handler in routes:
            if isinstance(node, kind):
                return handler(type_, node)
        # no handler matched: not a fetch expression
        raise errors.BackendError(f'expected fetch expression, found {node}')

    def _all(self, type_: bsc.Vertex, node: ast.fetch.All):
        """Validate every child expression of *node* against *type_*."""
        for child in node:
            self._parse_fetch_expression(type_, child)

    def _branch(self, type_: bsc.Vertex, node: ast.fetch._Branch):
        """Check that the predicate of *node* exists and is applicable to the schema Node *type_*."""
        self._check_node(type_)
        # predicate exists in the schema
        if not self.schema.has_predicate(node.predicate):
            raise errors.ConsistencyError(f'predicate {node.predicate} is not in the schema')
        pred = self.schema.predicate(node.predicate)
        # the predicate has to be applicable to type_
        if not type_ <= pred.domain:
            raise errors.ConsistencyError(
                f'expected type {pred.domain} or subtype thereof, found {type_}')

    def _fetch(self, type_: bsc.Vertex, node: ast.fetch.Fetch): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check that the predicate of *node* leads to a Node and validate the child expression."""
        target = self.schema.predicate(node.predicate).range
        # only nodes can be expanded further
        if not isinstance(target, bsc.Node):
            rng = target
            raise errors.ConsistencyError(
                f'expected the predicate\'s range to be a Node, found {rng}')
        # the child expression is checked against the predicate's range
        self._parse_fetch_expression(target, node.expr)

    def _named(self, type_: bsc.Vertex, node: ast.fetch._Named): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check that *node* carries a non-empty name."""
        self._check_name(node)
        # FIXME: check for double name use?

    def _node(self, type_: bsc.Vertex, node: ast.fetch.Node): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check that the predicate of *node* leads to a Node."""
        rng = self.schema.predicate(node.predicate).range
        if isinstance(rng, bsc.Node):
            return
        raise errors.ConsistencyError(
            f'expected the predicate\'s range to be a Node, found {rng}')

    def _value(self, type_: bsc.Vertex, node: ast.fetch.Value): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check that the predicate of *node* leads to a Literal."""
        rng = self.schema.predicate(node.predicate).range
        if isinstance(rng, bsc.Literal):
            return
        raise errors.ConsistencyError(
            f'expected the predicate\'s range to be a Literal, found {rng}')

    def _this(self, type_: bsc.Vertex, node: ast.fetch.This):
        """Check that *type_* is a schema Node and that *node* carries a non-empty name."""
        self._check_node(type_)
        self._check_name(node)

## EOF ##