"""
Part of the BlackStar filesystem (bsfs) module.
A copy of the license is provided with the project.
Author: Matthias Baumgartner, 2022
"""
# imports
from collections import defaultdict
from itertools import product
from time import time
import random
import threading
import typing

# external imports
from hopcroftkarp import HopcroftKarp

# bsfs imports
from bsfs.utils import errors, typename

# inner-module imports
from . import ast

# exports
__all__ : typing.Sequence[str] = (
    'Filter',
    )


## code ##

class Any(ast.filter.FilterExpression, ast.filter.PredicateExpression):
    """Wildcard that matches any ast class.

    Every Any instance carries its own unique id, so two instances never
    compare equal and several of them can coexist in one set:

    >>> Any() == Any()
    False

    >>> len({Any(), Any(), Any(), Any()})
    4

    """

    # unique instance id
    _uid: typing.Tuple[int, int, float, float]

    def __init__(self):
        # object id, thread id, creation time, and a random number
        # together form the instance's identity.
        self._uid = (id(self), id(threading.current_thread()), time(), random.random())

    def __eq__(self, other: typing.Any):
        same_kind = super().__eq__(other)
        return same_kind and self._uid == other._uid

    def __hash__(self):
        base = super().__hash__()
        return hash((base, self._uid))


class Rest(ast.filter.FilterExpression, ast.filter.PredicateExpression):
    """Match the leftovers in a set of items to be compared.

    Rest is used in conjunction with aggregating expressions such as
    ast.filter.And, ast.filter.Or, or ast.filter.OneOf. It accounts for the
    child expressions that were not consumed by any other matching rule.
    Rest can be restricted to a specific expression; if none is given,
    it defaults to Any().

    For example, the following two ast structures match because Rest allows
    an arbitrary repetition of ast.filter.Equals statements:

    >>> And(Equals('hello'), Equals('world'), Equals('foobar'))
    >>> And(Equals('world'), Rest(Partial(Equals)))

    """

    # child expression for the Rest.
    expr: typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression]

    def __init__(
            self,
            expr: typing.Optional[typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression]] = None,
            ):
        # without an explicit expression, the leftovers may be anything.
        self.expr = Any() if expr is None else expr

    def __repr__(self) -> str:
        return f'{typename(self)}({self.expr})'

    def __hash__(self) -> int:
        base = super().__hash__()
        return hash((base, self.expr))

    def __eq__(self, other: typing.Any) -> bool:
        same_kind = super().__eq__(other)
        return same_kind and self.expr == other.expr


class Partial(ast.filter.FilterExpression, ast.filter.PredicateExpression):
    """Match a partially defined ast expression.

    When comparing two ast structures, literal values may be irrelevant or
    unknown. Partial constrains the matcher to a certain ast class while
    leaving some (or all) of its members unspecified. Pass the class
    (not an instance) and the members to be checked as keyword arguments.
    Note that the arguments are not validated.

    For example, the following instance matches any ast.filter.Equals,
    irrespective of its value:

    >>> Partial(ast.filter.Equals)

    Likewise, the following instance matches any ast.filter.LessThan whose
    bound is non-strict, but makes no claim about the threshold:

    >>> Partial(ast.filter.LessThan, strict=False)

    """

    # target node type.
    node: typing.Type

    # node construction args.
    kwargs: typing.Dict[str, typing.Any]

    def __init__(
            self,
            node: typing.Type,
            **kwargs,
            ):
        self.node = node
        self.kwargs = kwargs

    def __repr__(self) -> str:
        return f'{typename(self)}({self.node.__name__}, {self.kwargs})'

    def __hash__(self) -> int:
        # order the members by name so that the hash does not depend
        # on the order in which the keyword arguments were passed.
        members = tuple(sorted(self.kwargs.items(), key=lambda item: item[0]))
        return hash((super().__hash__(), self.node, members))

    def __eq__(self, other: typing.Any) -> bool:
        return (
            super().__eq__(other)
            and self.node == other.node
            and self.kwargs == other.kwargs
            )

    def match(
            self,
            name: str,
            value: typing.Any,
            ) -> bool:
        """Return True if *name* is unspecified or matches *value*."""
        if name not in self.kwargs:
            # unspecified members match everything
            return True
        return self.kwargs[name] == value


T_ITEM_TYPE = typing.TypeVar('T_ITEM_TYPE') # pylint: disable=invalid-name

def _set_matcher(
        query: typing.Collection[T_ITEM_TYPE],
        reference: typing.Collection[T_ITEM_TYPE],
        cmp: typing.Callable[[T_ITEM_TYPE, T_ITEM_TYPE], bool],
        ) -> bool:
    """Compare two sets of child expressions.

    Return True if every item of *query* can be paired with a distinct item
    of *reference* such that *cmp(query_item, reference_item)* holds for each
    pair. At most one Rest instance is allowed in *reference*; it is expanded
    into as many copies of its expression as are needed to reach the size
    of *query*. A BackendError is raised if more than one Rest is found.

    All pairs of items are compared, i.e. *cmp* is called O(|N|**2) times,
    with N the number of child expressions. The subsequent matching adds
    the cost of the Hopcroft-Karp algorithm on top of that.

    """
    # get reference items
    candidates = list(reference)

    # deal with Rest
    rests = set(filter(lambda itm: isinstance(itm, Rest), candidates))
    if len(rests) > 1:
        raise errors.BackendError(f'there must be at most one Rest instance per set, found {len(rests)}')
    if len(rests) == 1:
        # replace Rest by filling the reference up with Rest's expression.
        # NOTE: candidates is a list, hence the expression can be repeated.
        (rest, ) = rests
        candidates = [itm for itm in candidates if not isinstance(itm, Rest)]
        missing = len(query) - len(candidates)
        # a non-positive count adds nothing
        candidates += [rest.expr] * missing # type: ignore [attr-defined]

    # sanity check: no match is possible if the sizes differ.
    # Either a reference item would remain unmatched (more candidates than
    # query items) or a query item would (fewer candidates than query items).
    if len(candidates) != len(query):
        return False

    # For a positive match, each query expr has to match some reference expr.
    # However, each reference expr can only be "consumed" once, even if it
    # matches multiple query exprs (e.g., the Any expression matches every
    # query expr). This is a bipartite matching problem (Hall's marriage
    # problem) and the Hopcroft-Karp-Karzanov algorithm finds a maximum
    # matching. There might be several maximum matchings, but we only need
    # to know whether (at least) one complete matching exists.
    # The hopcroftkarp module provides this functionality.
    # The HKK algorithm has a worst-case complexity of O(|N|**2 * sqrt(|N|)),
    # and the pairwise comparison of expressions takes O(|N|**2).
    size = len(candidates)
    # bipartite graph as {lhs: {rhs}, ...}.
    # lhs and rhs must be disjoint identifiers, hence the rhs ids are
    # shifted by the number of items.
    edges = defaultdict(set)
    for (left, ref), (right, item) in product(enumerate(candidates), enumerate(query)):
        # connect expressions that are considered equal
        if cmp(item, ref):
            edges[left].add(size + right)

    # maximum_matching with keys_only yields the partner of every matched
    # lhs node ({ref_itm: node_itm}). The matching is complete if it covers
    # all of the reference's child expressions.
    matching = HopcroftKarp(edges).maximum_matching(keys_only=True)
    return len(matching) == size


class Filter:
    """Compare a `bsfs.query.ast.filter` query's structure to a reference ast.

    The reference ast may include `Rest`, `Partial`, or `Any` to account
    for irrelevant or unknown ast pieces.

    This is only a structural comparison, not a semantic one. For example,
    the following two queries are semantically identical but structurally
    different, and would therefore not match:

    >>> ast.filter.OneOf(ast.filter.Predicate(ns.bse.filename))
    >>> ast.filter.Predicate(ns.bse.filename)

    """

    def __call__(self, query: ast.filter.FilterExpression, reference: ast.filter.FilterExpression) -> bool:
        """Compare a *query* to a *reference* ast structure.
        Return True if both are structurally equivalent.
        Raise a BackendError if either argument is not a filter expression.
        """
        if not isinstance(query, ast.filter.FilterExpression):
            raise errors.BackendError(f'expected filter expression, found {query}')
        if not isinstance(reference, ast.filter.FilterExpression):
            raise errors.BackendError(f'expected filter expression, found {reference}')
        return self._parse_filter_expression(query, reference)

    def _parse_filter_expression(
            self,
            node: ast.filter.FilterExpression,
            reference: ast.filter.FilterExpression,
            ) -> bool:
        """Route *node* to the handler of the respective FilterExpression subclass."""
        # generic check: an Any reference accepts every node
        if isinstance(reference, Any):
            return True
        # node-specific checks, tried in the listed order
        handlers: typing.Sequence[typing.Tuple[typing.Any, typing.Callable[..., bool]]] = (
            (ast.filter.Not, self._not),
            (ast.filter.Has, self._has),
            (ast.filter.Distance, self._distance),
            ((ast.filter.Any, ast.filter.All), self._branch),
            ((ast.filter.And, ast.filter.Or), self._agg),
            ((ast.filter.Is, ast.filter.Equals, ast.filter.Substring,
              ast.filter.StartsWith, ast.filter.EndsWith), self._value),
            ((ast.filter.LessThan, ast.filter.GreaterThan), self._bounded),
            )
        for node_types, handler in handlers:
            if isinstance(node, node_types):
                return handler(node, reference)
        # invalid node
        raise errors.BackendError(f'expected filter expression, found {node}')

    def _parse_predicate_expression(
            self,
            node: ast.filter.PredicateExpression,
            reference: ast.filter.PredicateExpression,
            ) -> bool:
        """Route *node* to the handler of the respective PredicateExpression subclass."""
        # an Any reference accepts every node
        if isinstance(reference, Any):
            return True
        if isinstance(node, ast.filter.Predicate):
            return self._predicate(node, reference)
        if isinstance(node, ast.filter.OneOf):
            return self._one_of(node, reference)
        # invalid node
        raise errors.BackendError(f'expected predicate expression, found {node}')

    def _one_of(self, node: ast.filter.OneOf, reference: ast.filter.PredicateExpression) -> bool:
        # the reference must be of the node's type and its children must match.
        return isinstance(reference, type(node)) \
           and _set_matcher(node, reference, self._parse_predicate_expression)

    def _predicate(self, node: ast.filter.Predicate, reference: ast.filter.PredicateExpression) -> bool:
        # partial check: only compare the members specified by the Partial
        if isinstance(reference, Partial):
            return isinstance(node, reference.node) \
               and reference.match('predicate', node.predicate) \
               and reference.match('reverse', node.reverse)
        # full check: all members must agree
        if isinstance(reference, type(node)):
            return node.predicate == reference.predicate \
               and node.reverse == reference.reverse
        # reference is of an unrelated type
        return False

    def _branch(self,
                node: typing.Union[ast.filter.Any, ast.filter.All],
                reference: ast.filter.FilterExpression,
                ) -> bool:
        if not isinstance(reference, type(node)):
            return False
        # both the predicate and the child expression have to match.
        return bool(
            self._parse_predicate_expression(node.predicate, reference.predicate) # type: ignore [attr-defined]
            and self._parse_filter_expression(node.expr, reference.expr) # type: ignore [attr-defined]
            )

    def _agg(self, node: typing.Union[ast.filter.And, ast.filter.Or], reference: ast.filter.FilterExpression) -> bool:
        # the reference must be of the node's type and its children must match.
        return isinstance(reference, type(node)) \
           and _set_matcher(node, reference, self._parse_filter_expression) # type: ignore [arg-type]

    def _not(self, node: ast.filter.Not, reference: ast.filter.FilterExpression) -> bool:
        # the reference must be of the node's type and the negated expression must match.
        return isinstance(reference, type(node)) \
           and self._parse_filter_expression(node.expr, reference.expr)

    def _has(self, node: ast.filter.Has, reference: ast.filter.FilterExpression) -> bool:
        # the reference must be of the node's type,
        # and both predicate and count expression must match.
        return isinstance(reference, type(node)) \
           and self._parse_predicate_expression(node.predicate, reference.predicate) \
           and self._parse_filter_expression(node.count, reference.count)

    def _distance(self, node: ast.filter.Distance, reference: ast.filter.FilterExpression) -> bool:
        # partial check: only compare the members specified by the Partial
        if isinstance(reference, Partial):
            return isinstance(node, reference.node) \
               and reference.match('reference', node.reference) \
               and reference.match('threshold', node.threshold) \
               and reference.match('strict', node.strict)
        # full check: all members must agree
        if isinstance(reference, type(node)):
            return node.reference == reference.reference \
               and node.threshold == reference.threshold \
               and node.strict == reference.strict
        # reference is of an unrelated type
        return False

    def _value(self, node: ast.filter._Value, reference: ast.filter.FilterExpression) -> bool:
        # partial check: only compare the value if the Partial specifies it
        if isinstance(reference, Partial):
            return isinstance(node, reference.node) \
               and reference.match('value', node.value)
        # full check: the values must agree
        if isinstance(reference, type(node)):
            return node.value == reference.value
        # reference is of an unrelated type
        return False

    def _bounded(self, node: ast.filter._Bounded, reference: ast.filter.FilterExpression) -> bool:
        # partial check: only compare the members specified by the Partial
        if isinstance(reference, Partial):
            return isinstance(node, reference.node) \
               and reference.match('threshold', node.threshold) \
               and reference.match('strict', node.strict)
        # full check: all members must agree
        if isinstance(reference, type(node)):
            return node.threshold == reference.threshold \
               and node.strict == reference.strict
        # reference is of an unrelated type
        return False

## EOF ##