diff options
Diffstat (limited to 'bsfs/query/matcher.py')
-rw-r--r-- | bsfs/query/matcher.py | 361 |
1 files changed, 361 insertions, 0 deletions
diff --git a/bsfs/query/matcher.py b/bsfs/query/matcher.py new file mode 100644 index 0000000..17c9c8e --- /dev/null +++ b/bsfs/query/matcher.py @@ -0,0 +1,361 @@ + +# imports +from collections import defaultdict +from itertools import product +from time import time +import random +import threading +import typing + +# external imports +from hopcroftkarp import HopcroftKarp + +# bsfs imports +from bsfs.utils import errors, typename + +# inner-module imports +from . import ast + +# exports +__all__ : typing.Sequence[str] = ( + 'Filter', + ) + + +## code ## + +class Any(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match any ast class. + + Note that Any instances are unique, i.e. they do not compare, and + can hence be repeated in a set: + >>> Any() == Any() + False + >>> len({Any(), Any(), Any(), Any()}) + 4 + + """ + + # unique instance id + _uid: typing.Tuple[int, int, float, float] + + def __init__(self): + self._uid = ( + id(self), + id(threading.current_thread()), + time(), + random.random(), + ) + + def __eq__(self, other: typing.Any): + return super().__eq__(other) and self._uid == other._uid + + def __hash__(self): + return hash((super().__hash__(), self._uid)) + + +class Rest(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match the leftovers in a set of items to be compared. + + Rest can be used in junction with aggregating expressions such as ast.filter.And, + ast.filter.Or, ast.filter.OneOf. It controls childs expressions that were not yet + consumed by other matching rules. Rest may match to only a specific expression. + The expresssion defaults to Any(). + + For example, the following to ast structures would match since Rest + allows an arbitrary repetition of ast.filter.Equals statements. + + >>> And(Equals('hello'), Equals('world'), Equals('foobar')) + >>> And(Equals('world'), Rest(Partial(Equals))) + + """ + + # child expression for the Rest. + expr: typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression] + + def __init__( + self, + expr: typing.Optional[typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression]] = None, + ): + if expr is None: + expr = Any() + self.expr = expr + + def __repr__(self) -> str: + return f'{typename(self)}({self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.expr)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +class Partial(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match a partially defined ast expression. + + Literal values might be irrelevant or unknown when comparing two ast + structures. Partial allows to constrain the matcher to a certain + ast class, while leaving some of its members unspecified. + + Pass the class (not instance) and its members as keyword arguments + to Partial. Note that the arguments are not validated. + + For example, the following instance matches any ast.filter.Equals, + irrespective of its value: + + >>> Partial(ast.filter.Equals) + + Likewise, the following instance matches any ast.filter.LessThan + that has a strict bounds, but makes no claim about the threshold: + + >>> Partial(ast.filter.LessThan, strict=False) + + """ + + # target node type. + node: typing.Type + + # node construction args. + kwargs: typing.Dict[str, typing.Any] + + def __init__( + self, + node: typing.Type, + **kwargs, + ): + self.node = node + self.kwargs = kwargs + + def __repr__(self) -> str: + return f'{typename(self)}({self.node.__name__}, {self.kwargs})' + + def __hash__(self) -> int: + kwargs = tuple((key, self.kwargs[key]) for key in sorted(self.kwargs)) + return hash((super().__hash__(), self.node, kwargs)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) \ + and self.node == other.node \ + and self.kwargs == other.kwargs + + def match( + self, + name: str, + value: typing.Any, + ) -> bool: + """Return True if *name* is unspecified or matches *value*.""" + return name not in self.kwargs or self.kwargs[name] == value + + +T_ITEM_TYPE = typing.TypeVar('T_ITEM_TYPE') # pylint: disable=invalid-name + +def _set_matcher( + query: typing.Collection[T_ITEM_TYPE], + reference: typing.Collection[T_ITEM_TYPE], + cmp: typing.Callable[[T_ITEM_TYPE, T_ITEM_TYPE], bool], + ) -> bool: + """Compare two sets of child expressions. + + This check has a best-case complexity of O(|N|**2) and worst-case + complexity of O(|N|**3), with N the number of child expressions. + """ + # get reference items + r_items = list(reference) + # deal with Rest + r_rest = {itm for itm in r_items if isinstance(itm, Rest)} + if len(r_rest) > 1: + raise errors.BackendError(f'there must be at most one Rest instance per set, found {len(r_rest)}') + if len(r_rest) == 1: + # replace Rest by filling the reference up with rest's expression + # NOTE: convert r_items to list so that items can be repeated + expr = next(iter(r_rest)).expr # type: ignore [attr-defined] + r_items = [itm for itm in r_items if not isinstance(itm, Rest)] + r_items += [expr for _ in range(len(query) - len(r_items))] # type: ignore [misc] + # sanity check: cannot match if the item sizes differ: + # either a reference item is unmatched (len(r_items) > len(query)) + # or a query item is unmatched (len(r_items) < len(query)) + if len(query) != len(r_items): + return False + + # To have a positive match between the query and the reference, + # each query expr has to match any reference expr. + # However, each reference expr can only be "consumed" once even + # if it matches multiple query exprs (e.g., the Any expression matches + # every query expr). + # This is a bipartide matching problem (Hall's marriage problem) + # and the Hopcroft-Karp-Karzanov algorithm finds a maximum + # matching. While there might be multiple maximum matchings, + # we only need to know whether (at least) one complete matching + # exists. The hopcroftkarp module provides this functionality. + # The HKK algorithm has worst-case complexity of O(|N|**2 * sqrt(|N|)) + # and we also need to compare expressions pairwise, hence O(|N|**2). + num_items = len(r_items) + graph = defaultdict(set) + # build the bipartide graph as {lhs: {rhs}, ...} + # lhs and rhs must be disjoint identifiers. + for (ridx, ref), (nidx, node) in product(enumerate(r_items), enumerate(query)): + # add edges for equal expressions + if cmp(node, ref): + graph[ridx].add(num_items + nidx) + + # maximum_matching returns the matches for all nodes in the graph + # ({ref_itm: node_itm}), hence a complete matching's size is + # the number of reference's child expressions. + return len(HopcroftKarp(graph).maximum_matching(keys_only=True)) == num_items + + +class Filter(): + """Compare a bsfs.query.ast.filter` query's structure to a reference ast. + + The reference ast may include `Rest`, `Partial`, or `Any` to account for irrelevant + or unknown ast pieces. + + This is only a structural comparison, not a semantic one. For example, the + two following queries are semantically identical, but structurally different, + and would therefore not match: + + >>> ast.filter.OneOf(ast.filter.Predicate(ns.bse.name)) + >>> ast.filter.Predicate(ns.bse.name) + + """ + + def __call__(self, query: ast.filter.FilterExpression, reference: ast.filter.FilterExpression) -> bool: + """Compare a *query* to a *reference* ast structure. + Return True if both are structurally equivalent. + """ + if not isinstance(query, ast.filter.FilterExpression): + raise errors.BackendError(f'expected filter expression, found {query}') + if not isinstance(reference, ast.filter.FilterExpression): + raise errors.BackendError(f'expected filter expression, found {reference}') + return self._parse_filter_expression(query, reference) + + def _parse_filter_expression( + self, + node: ast.filter.FilterExpression, + reference: ast.filter.FilterExpression, + ) -> bool: + """Route *node* to the handler of the respective FilterExpression subclass.""" + # generic checks: reference type must be Any or match node type + if isinstance(reference, Any): + return True + # node-specific checks + if isinstance(node, ast.filter.Not): + return self._not(node, reference) + if isinstance(node, ast.filter.Has): + return self._has(node, reference) + if isinstance(node, ast.filter.Distance): + return self._distance(node, reference) + if isinstance(node, (ast.filter.Any, ast.filter.All)): + return self._branch(node, reference) + if isinstance(node, (ast.filter.And, ast.filter.Or)): + return self._agg(node, reference) + if isinstance(node, (ast.filter.Is, ast.filter.Equals, ast.filter.Substring, + ast.filter.StartsWith, ast.filter.EndsWith)): + return self._value(node, reference) + if isinstance(node, (ast.filter.LessThan, ast.filter.GreaterThan)): + return self._bounded(node, reference) + # invalid node + raise errors.BackendError(f'expected filter expression, found {node}') + + def _parse_predicate_expression( + self, + node: ast.filter.PredicateExpression, + reference: ast.filter.PredicateExpression, + ) -> bool: + """Route *node* to the handler of the respective PredicateExpression subclass.""" + if isinstance(reference, Any): + return True + if isinstance(node, ast.filter.Predicate): + return self._predicate(node, reference) + if isinstance(node, ast.filter.OneOf): + return self._one_of(node, reference) + # invalid node + raise errors.BackendError(f'expected predicate expression, found {node}') + + def _one_of(self, node: ast.filter.OneOf, reference: ast.filter.PredicateExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return _set_matcher(node, reference, self._parse_predicate_expression) + + def _predicate(self, node: ast.filter.Predicate, reference: ast.filter.PredicateExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('predicate', node.predicate) \ + and reference.match('reverse', node.reverse) + # full check + return node.predicate == reference.predicate \ + and node.reverse == reference.reverse + + def _branch(self, + node: typing.Union[ast.filter.Any, ast.filter.All], + reference: ast.filter.FilterExpression, + ) -> bool: + if not isinstance(reference, type(node)): + return False + if not self._parse_predicate_expression(node.predicate, reference.predicate): # type: ignore [attr-defined] + return False + if not self._parse_filter_expression(node.expr, reference.expr): # type: ignore [attr-defined] + return False + return True + + def _agg(self, node: typing.Union[ast.filter.And, ast.filter.Or], reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return _set_matcher(node, reference, self._parse_filter_expression) # type: ignore [arg-type] + + def _not(self, node: ast.filter.Not, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return self._parse_filter_expression(node.expr, reference.expr) + + def _has(self, node: ast.filter.Has, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return self._parse_predicate_expression(node.predicate, reference.predicate) \ + and self._parse_filter_expression(node.count, reference.count) + + def _distance(self, node: ast.filter.Distance, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('reference', node.reference) \ + and reference.match('threshold', node.threshold) \ + and reference.match('strict', node.strict) + # full check + return node.reference == reference.reference \ + and node.threshold == reference.threshold \ + and node.strict == reference.strict + + def _value(self, node: ast.filter._Value, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('value', node.value) + # full ckeck + return node.value == reference.value + + def _bounded(self, node: ast.filter._Bounded, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('threshold', node.threshold) \ + and reference.match('strict', node.strict) + # full check + return node.threshold == reference.threshold \ + and node.strict == reference.strict + +## EOF ## |