diff options
author | Matthias Baumgartner <dev@igsor.net> | 2023-01-30 14:35:32 +0100 |
---|---|---|
committer | Matthias Baumgartner <dev@igsor.net> | 2023-01-30 14:35:32 +0100 |
commit | 7e0987bcda136a17baea45b8eb22eb5ea668abc0 (patch) | |
tree | 39cdf798b2b1a66bef4bb72286ce8131ec68e136 | |
parent | 72e0bd78dc9cc1d74c3061b028040b64c0efcf9f (diff) | |
download | bsfs-7e0987bcda136a17baea45b8eb22eb5ea668abc0.tar.gz bsfs-7e0987bcda136a17baea45b8eb22eb5ea668abc0.tar.bz2 bsfs-7e0987bcda136a17baea45b8eb22eb5ea668abc0.zip |
filter ast comparison
-rw-r--r-- | bsfs/query/matcher.py | 366 | ||||
-rw-r--r-- | setup.py | 5 | ||||
-rw-r--r-- | test/query/test_matcher.py | 1182 |
3 files changed, 1552 insertions, 1 deletions
diff --git a/bsfs/query/matcher.py b/bsfs/query/matcher.py new file mode 100644 index 0000000..a910756 --- /dev/null +++ b/bsfs/query/matcher.py @@ -0,0 +1,366 @@ +""" + +Part of the BlackStar filesystem (bsfs) module. +A copy of the license is provided with the project. +Author: Matthias Baumgartner, 2022 +""" +# imports +from collections import defaultdict +from itertools import product +from time import time +import random +import threading +import typing + +# external imports +from hopcroftkarp import HopcroftKarp + +# bsfs imports +from bsfs.utils import errors, typename + +# inner-module imports +from . import ast + +# exports +__all__ : typing.Sequence[str] = ( + 'Filter', + ) + + +## code ## + +class Any(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match any ast class. + + Note that Any instances are unique, i.e. they do not compare, and + can hence be repeated in a set: + >>> Any() == Any() + False + >>> len({Any(), Any(), Any(), Any()}) + 4 + + """ + + # unique instance id + _uid: typing.Tuple[int, int, float, float] + + def __init__(self): + self._uid = ( + id(self), + id(threading.current_thread()), + time(), + random.random(), + ) + + def __eq__(self, other: typing.Any): + return super().__eq__(other) and self._uid == other._uid + + def __hash__(self): + return hash((super().__hash__(), self._uid)) + + +class Rest(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match the leftovers in a set of items to be compared. + + Rest can be used in junction with aggregating expressions such as ast.filter.And, + ast.filter.Or, ast.filter.OneOf. It controls childs expressions that were not yet + consumed by other matching rules. Rest may match to only a specific expression. + The expresssion defaults to Any(). + + For example, the following to ast structures would match since Rest + allows an arbitrary repetition of ast.filter.Equals statements. + + >>> And(Equals('hello'), Equals('world'), Equals('foobar')) + >>> And(Equals('world'), Rest(Partial(Equals))) + + """ + + # child expression for the Rest. + expr: typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression] + + def __init__( + self, + expr: typing.Optional[typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression]] = None, + ): + if expr is None: + expr = Any() + self.expr = expr + + def __repr__(self) -> str: + return f'{typename(self)}({self.expr})' + + def __hash__(self) -> int: + return hash((super().__hash__(), self.expr)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) and self.expr == other.expr + + +class Partial(ast.filter.FilterExpression, ast.filter.PredicateExpression): + """Match a partially defined ast expression. + + Literal values might be irrelevant or unknown when comparing two ast + structures. Partial allows to constrain the matcher to a certain + ast class, while leaving some of its members unspecified. + + Pass the class (not instance) and its members as keyword arguments + to Partial. Note that the arguments are not validated. + + For example, the following instance matches any ast.filter.Equals, + irrespective of its value: + + >>> Partial(ast.filter.Equals) + + Likewise, the following instance matches any ast.filter.LessThan + that has a strict bounds, but makes no claim about the threshold: + + >>> Partial(ast.filter.LessThan, strict=False) + + """ + + # target node type. + node: typing.Type + + # node construction args. + kwargs: typing.Dict[str, typing.Any] + + def __init__( + self, + node: typing.Type, + **kwargs, + ): + self.node = node + self.kwargs = kwargs + + def __repr__(self) -> str: + return f'{typename(self)}({self.node.__name__}, {self.kwargs})' + + def __hash__(self) -> int: + kwargs = tuple((key, self.kwargs[key]) for key in sorted(self.kwargs)) + return hash((super().__hash__(), self.node, kwargs)) + + def __eq__(self, other: typing.Any) -> bool: + return super().__eq__(other) \ + and self.node == other.node \ + and self.kwargs == other.kwargs + + def match( + self, + name: str, + value: typing.Any, + ) -> bool: + """Return True if *name* is unspecified or matches *value*.""" + return name not in self.kwargs or self.kwargs[name] == value + + +T_ITEM_TYPE = typing.TypeVar('T_ITEM_TYPE') # pylint: disable=invalid-name + +def _set_matcher( + query: typing.Collection[T_ITEM_TYPE], + reference: typing.Collection[T_ITEM_TYPE], + cmp: typing.Callable[[T_ITEM_TYPE, T_ITEM_TYPE], bool], + ) -> bool: + """Compare two sets of child expressions. + + This check has a best-case complexity of O(|N|**2) and worst-case + complexity of O(|N|**3), with N the number of child expressions. + """ + # get reference items + r_items = list(reference) + # deal with Rest + r_rest = {itm for itm in r_items if isinstance(itm, Rest)} + if len(r_rest) > 1: + raise errors.BackendError(f'there must be at most one Rest instance per set, found {len(r_rest)}') + if len(r_rest) == 1: + # replace Rest by filling the reference up with rest's expression + # NOTE: convert r_items to list so that items can be repeated + expr = next(iter(r_rest)).expr # type: ignore [attr-defined] + r_items = [itm for itm in r_items if not isinstance(itm, Rest)] + r_items += [expr for _ in range(len(query) - len(r_items))] # type: ignore [misc] + # sanity check: cannot match if the item sizes differ: + # either a reference item is unmatched (len(r_items) > len(query)) + # or a query item is unmatched (len(r_items) < len(query)) + if len(query) != len(r_items): + return False + + # To have a positive match between the query and the reference, + # each query expr has to match any reference expr. + # However, each reference expr can only be "consumed" once even + # if it matches multiple query exprs (e.g., the Any expression matches + # every query expr). + # This is a bipartide matching problem (Hall's marriage problem) + # and the Hopcroft-Karp-Karzanov algorithm finds a maximum + # matching. While there might be multiple maximum matchings, + # we only need to know whether (at least) one complete matching + # exists. The hopcroftkarp module provides this functionality. + # The HKK algorithm has worst-case complexity of O(|N|**2 * sqrt(|N|)) + # and we also need to compare expressions pairwise, hence O(|N|**2). + num_items = len(r_items) + graph = defaultdict(set) + # build the bipartide graph as {lhs: {rhs}, ...} + # lhs and rhs must be disjoint identifiers. + for (ridx, ref), (nidx, node) in product(enumerate(r_items), enumerate(query)): + # add edges for equal expressions + if cmp(node, ref): + graph[ridx].add(num_items + nidx) + + # maximum_matching returns the matches for all nodes in the graph + # ({ref_itm: node_itm}), hence a complete matching's size is + # the number of reference's child expressions. + return len(HopcroftKarp(graph).maximum_matching(keys_only=True)) == num_items + + +class Filter(): + """Compare a bsfs.query.ast.filter` query's structure to a reference ast. + + The reference ast may include `Rest`, `Partial`, or `Any` to account for irrelevant + or unknown ast pieces. + + This is only a structural comparison, not a semantic one. For example, the + two following queries are semantically identical, but structurally different, + and would therefore not match: + + >>> ast.filter.OneOf(ast.filter.Predicate(ns.bse.filename)) + >>> ast.filter.Predicate(ns.bse.filename) + + """ + + def __call__(self, query: ast.filter.FilterExpression, reference: ast.filter.FilterExpression) -> bool: + """Compare a *query* to a *reference* ast structure. + Return True if both are structurally equivalent. + """ + if not isinstance(query, ast.filter.FilterExpression): + raise errors.BackendError(f'expected filter expression, found {query}') + if not isinstance(reference, ast.filter.FilterExpression): + raise errors.BackendError(f'expected filter expression, found {reference}') + return self._parse_filter_expression(query, reference) + + def _parse_filter_expression( + self, + node: ast.filter.FilterExpression, + reference: ast.filter.FilterExpression, + ) -> bool: + """Route *node* to the handler of the respective FilterExpression subclass.""" + # generic checks: reference type must be Any or match node type + if isinstance(reference, Any): + return True + # node-specific checks + if isinstance(node, ast.filter.Not): + return self._not(node, reference) + if isinstance(node, ast.filter.Has): + return self._has(node, reference) + if isinstance(node, ast.filter.Distance): + return self._distance(node, reference) + if isinstance(node, (ast.filter.Any, ast.filter.All)): + return self._branch(node, reference) + if isinstance(node, (ast.filter.And, ast.filter.Or)): + return self._agg(node, reference) + if isinstance(node, (ast.filter.Is, ast.filter.Equals, ast.filter.Substring, + ast.filter.StartsWith, ast.filter.EndsWith)): + return self._value(node, reference) + if isinstance(node, (ast.filter.LessThan, ast.filter.GreaterThan)): + return self._bounded(node, reference) + # invalid node + raise errors.BackendError(f'expected filter expression, found {node}') + + def _parse_predicate_expression( + self, + node: ast.filter.PredicateExpression, + reference: ast.filter.PredicateExpression, + ) -> bool: + """Route *node* to the handler of the respective PredicateExpression subclass.""" + if isinstance(reference, Any): + return True + if isinstance(node, ast.filter.Predicate): + return self._predicate(node, reference) + if isinstance(node, ast.filter.OneOf): + return self._one_of(node, reference) + # invalid node + raise errors.BackendError(f'expected predicate expression, found {node}') + + def _one_of(self, node: ast.filter.OneOf, reference: ast.filter.PredicateExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return _set_matcher(node, reference, self._parse_predicate_expression) + + def _predicate(self, node: ast.filter.Predicate, reference: ast.filter.PredicateExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('predicate', node.predicate) \ + and reference.match('reverse', node.reverse) + # full check + return node.predicate == reference.predicate \ + and node.reverse == reference.reverse + + def _branch(self, + node: typing.Union[ast.filter.Any, ast.filter.All], + reference: ast.filter.FilterExpression, + ) -> bool: + if not isinstance(reference, type(node)): + return False + if not self._parse_predicate_expression(node.predicate, reference.predicate): # type: ignore [attr-defined] + return False + if not self._parse_filter_expression(node.expr, reference.expr): # type: ignore [attr-defined] + return False + return True + + def _agg(self, node: typing.Union[ast.filter.And, ast.filter.Or], reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return _set_matcher(node, reference, self._parse_filter_expression) # type: ignore [arg-type] + + def _not(self, node: ast.filter.Not, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return self._parse_filter_expression(node.expr, reference.expr) + + def _has(self, node: ast.filter.Has, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, type(node)): + return False + return self._parse_predicate_expression(node.predicate, reference.predicate) \ + and self._parse_filter_expression(node.count, reference.count) + + def _distance(self, node: ast.filter.Distance, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('reference', node.reference) \ + and reference.match('threshold', node.threshold) \ + and reference.match('strict', node.strict) + # full check + return node.reference == reference.reference \ + and node.threshold == reference.threshold \ + and node.strict == reference.strict + + def _value(self, node: ast.filter._Value, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('value', node.value) + # full ckeck + return node.value == reference.value + + def _bounded(self, node: ast.filter._Bounded, reference: ast.filter.FilterExpression) -> bool: + if not isinstance(reference, (Partial, type(node))): + return False + # partial check + if isinstance(reference, Partial): + if not isinstance(node, reference.node): + return False + return reference.match('threshold', node.threshold) \ + and reference.match('strict', node.strict) + # full check + return node.threshold == reference.threshold \ + and node.strict == reference.strict + +## EOF ## @@ -14,7 +14,10 @@ setup( url='https://www.igsor.net/projects/blackstar/bsfs/', download_url='https://pip.igsor.net', packages=('bsfs', ), - install_requires=('rdflib', ), + install_requires=( + 'rdflib', # schema and sparql storage + 'hopcroftkarp', # ast matching + ), python_requires=">=3.7", ) diff --git a/test/query/test_matcher.py b/test/query/test_matcher.py new file mode 100644 index 0000000..e830cf8 --- /dev/null +++ b/test/query/test_matcher.py @@ -0,0 +1,1182 @@ +""" + +Part of the tagit test suite. +A copy of the license is provided with the project. +Author: Matthias Baumgartner, 2022 +""" +# imports +import operator +import unittest + +# bsfs imports +from bsfs.namespace import ns +from bsfs.query import ast +from bsfs.utils import errors + +# objects to test +from bsfs.query.matcher import Any, Filter, Partial, Rest, _set_matcher + + +## code ## + +class TestAny(unittest.TestCase): + def test_essentials(self): + # comparison + a = Any() + b = Any() + self.assertNotEqual(Any(), Any()) + self.assertNotEqual(hash(Any()), hash(Any())) + self.assertNotEqual(a, Any()) + self.assertNotEqual(hash(a), hash(Any())) + self.assertNotEqual(a, b) + self.assertNotEqual(hash(a), hash(b)) + # comparison within sets + self.assertEqual(len({Any(), Any(), Any(), Any()}), 4) + self.assertEqual(len({Any() for _ in range(1000)}), 1000) + # string representation + self.assertEqual(str(Any()), 'Any()') + self.assertEqual(repr(Any()), 'Any()') + + +class TestRest(unittest.TestCase): + def test_essentials(self): + expr = ast.filter.Equals('hello') + # comparison + self.assertEqual(Rest(expr), Rest(expr)) + self.assertEqual(hash(Rest(expr)), hash(Rest(expr))) + # comparison respects type + class Foo(): pass + self.assertNotEqual(Rest(expr), 1234) + self.assertNotEqual(hash(Rest(expr)), hash(1234)) + self.assertNotEqual(Rest(expr), Foo()) + self.assertNotEqual(hash(Rest(expr)), hash(Foo())) + # comparison respects expr + self.assertNotEqual(Rest(expr), Rest(ast.filter.Equals('world'))) + self.assertNotEqual(hash(Rest(expr)), hash(Rest(ast.filter.Equals('world')))) + # default constructor -> Any -> Not equal + self.assertNotEqual(Rest(), Rest()) + self.assertNotEqual(hash(Rest()), hash(Rest())) + # string representation + self.assertEqual(str(Rest()), 'Rest(Any())') + self.assertEqual(str(Rest(expr)), 'Rest(Equals(hello))') + self.assertEqual(repr(Rest()), 'Rest(Any())') + self.assertEqual(repr(Rest(expr)), 'Rest(Equals(hello))') + + + +class TestPartial(unittest.TestCase): + def test_match(self): + p0 = Partial(ast.filter.LessThan) + p1 = Partial(ast.filter.LessThan, threshold=3) + p2 = Partial(ast.filter.LessThan, strict=False) + p3 = Partial(ast.filter.LessThan, threshold=3, strict=False) + # match respects name + self.assertTrue(p0.match('foo', None)) + self.assertTrue(p1.match('foo', None)) + self.assertTrue(p2.match('foo', None)) + self.assertTrue(p3.match('foo', None)) + # match respects correct value + self.assertTrue(p0.match('threshold', 3)) + self.assertTrue(p1.match('threshold', 3)) + self.assertTrue(p2.match('threshold', 3)) + self.assertTrue(p3.match('threshold', 3)) + self.assertTrue(p0.match('strict', False)) + self.assertTrue(p1.match('strict', False)) + self.assertTrue(p2.match('strict', False)) + self.assertTrue(p3.match('strict', False)) + # match respects incorrect value + self.assertTrue(p0.match('threshold', 5)) + self.assertFalse(p1.match('threshold', 5)) + self.assertTrue(p2.match('threshold', 5)) + self.assertFalse(p3.match('threshold', 5)) + self.assertTrue(p0.match('strict', True)) + self.assertTrue(p1.match('strict', True)) + self.assertFalse(p2.match('strict', True)) + self.assertFalse(p3.match('strict', True)) + + def test_members(self): + # node returns expression + self.assertEqual(Partial(ast.filter.Equals).node, ast.filter.Equals) + self.assertEqual(Partial(ast.filter.LessThan).node, ast.filter.LessThan) + # kwargs returns arguments + self.assertDictEqual(Partial(ast.filter.Equals, value='hello').kwargs, + {'value': 'hello'}) + self.assertDictEqual(Partial(ast.filter.LessThan, threshold=3, strict=False).kwargs, + {'threshold': 3, 'strict': False}) + # Partial does not check about kwargs + self.assertDictEqual(Partial(ast.filter.LessThan, value='hello').kwargs, + {'value': 'hello'}) + self.assertDictEqual(Partial(ast.filter.Equals, threshold=3, strict=False).kwargs, + {'threshold': 3, 'strict': False}) + + def test_essentials(self): + # comparison respects type + class Foo(): pass + self.assertNotEqual(Partial(ast.filter.Equals), 1234) + self.assertNotEqual(hash(Partial(ast.filter.Equals)), hash(1234)) + self.assertNotEqual(Partial(ast.filter.Equals), Foo()) + self.assertNotEqual(hash(Partial(ast.filter.Equals)), hash(Foo())) + self.assertNotEqual(Partial(ast.filter.Equals), ast.filter.Equals) + self.assertNotEqual(hash(Partial(ast.filter.Equals)), hash(ast.filter.Equals)) + self.assertNotEqual(Partial(ast.filter.Equals), ast.filter.Equals('hello')) + self.assertNotEqual(hash(Partial(ast.filter.Equals)), hash(ast.filter.Equals('hello'))) + # comparison respects node + self.assertEqual(Partial(ast.filter.Equals), Partial(ast.filter.Equals)) + self.assertEqual(hash(Partial(ast.filter.Equals)), hash(Partial(ast.filter.Equals))) + self.assertEqual(Partial(ast.filter.LessThan), Partial(ast.filter.LessThan)) + self.assertEqual(hash(Partial(ast.filter.LessThan)), hash(Partial(ast.filter.LessThan))) + self.assertNotEqual(Partial(ast.filter.Equals), Partial(ast.filter.LessThan)) + self.assertNotEqual(hash(Partial(ast.filter.Equals)), hash(Partial(ast.filter.LessThan))) + # comparison respects kwargs + self.assertEqual( + Partial(ast.filter.Equals, value='hello'), + Partial(ast.filter.Equals, value='hello')) + self.assertEqual( + hash(Partial(ast.filter.Equals, value='hello')), + hash(Partial(ast.filter.Equals, value='hello'))) + self.assertEqual( + Partial(ast.filter.LessThan, threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=3, strict=False)) + self.assertEqual( + hash(Partial(ast.filter.LessThan, threshold=3, strict=False)), + hash(Partial(ast.filter.LessThan, threshold=3, strict=False))) + self.assertNotEqual( + Partial(ast.filter.Equals, value='hello'), + Partial(ast.filter.Equals)) + self.assertNotEqual( + hash(Partial(ast.filter.Equals, value='hello')), + hash(Partial(ast.filter.Equals))) + self.assertNotEqual( + Partial(ast.filter.Equals, value='hello'), + Partial(ast.filter.Equals, value='world')) + self.assertNotEqual( + hash(Partial(ast.filter.Equals, value='hello')), + hash(Partial(ast.filter.Equals, value='world'))) + self.assertNotEqual( + Partial(ast.filter.LessThan, threshold=3, strict=False), + Partial(ast.filter.LessThan)) + self.assertNotEqual( + hash(Partial(ast.filter.LessThan, threshold=3, strict=False)), + hash(Partial(ast.filter.LessThan))) + self.assertNotEqual( + Partial(ast.filter.LessThan, threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=5)) + self.assertNotEqual( + hash(Partial(ast.filter.LessThan, threshold=3, strict=False)), + hash(Partial(ast.filter.LessThan, threshold=5))) + self.assertNotEqual( + Partial(ast.filter.LessThan, threshold=3, strict=False), + Partial(ast.filter.LessThan, strict=False)) + self.assertNotEqual( + hash(Partial(ast.filter.LessThan, threshold=3, strict=False)), + hash(Partial(ast.filter.LessThan, strict=False))) + self.assertNotEqual( + Partial(ast.filter.LessThan, threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=3, strict=True)) + self.assertNotEqual( + hash(Partial(ast.filter.LessThan, threshold=3, strict=False)), + hash(Partial(ast.filter.LessThan, threshold=3, strict=True))) + self.assertNotEqual( + Partial(ast.filter.LessThan, threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=5, strict=False)) + self.assertNotEqual( + hash(Partial(ast.filter.LessThan, threshold=3, strict=False)), + hash(Partial(ast.filter.LessThan, threshold=5, strict=False))) + # string representation + self.assertEqual(str(Partial(ast.filter.Equals)), 'Partial(Equals, {})') + self.assertEqual(repr(Partial(ast.filter.Equals)), 'Partial(Equals, {})') + self.assertEqual(str(Partial(ast.filter.LessThan)), 'Partial(LessThan, {})') + self.assertEqual(repr(Partial(ast.filter.LessThan)), 'Partial(LessThan, {})') + self.assertEqual(str(Partial(ast.filter.Equals, value='hello')), "Partial(Equals, {'value': 'hello'})") + self.assertEqual(repr(Partial(ast.filter.Equals, value='hello')), "Partial(Equals, {'value': 'hello'})") + self.assertEqual(str(Partial(ast.filter.LessThan, threshold=3)), "Partial(LessThan, {'threshold': 3})") + self.assertEqual(repr(Partial(ast.filter.LessThan, threshold=3)), "Partial(LessThan, {'threshold': 3})") + self.assertEqual(str(Partial(ast.filter.LessThan, strict=False)), "Partial(LessThan, {'strict': False})") + self.assertEqual(repr(Partial(ast.filter.LessThan, strict=False)), "Partial(LessThan, {'strict': False})") + self.assertEqual(str(Partial(ast.filter.LessThan, threshold=3, strict=False)), "Partial(LessThan, {'threshold': 3, 'strict': False})") + self.assertEqual(repr(Partial(ast.filter.LessThan, threshold=3, strict=False)), "Partial(LessThan, {'threshold': 3, 'strict': False})") + + +class TestSetMatcher(unittest.TestCase): + def test_set_matcher(self): + # setup + A = ast.filter.Equals('A') + B = ast.filter.Equals('B') + C = ast.filter.Equals('C') + D = ast.filter.Equals('D') + matcher = Filter() + + # identical sets match + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, C), + matcher._parse_filter_expression, + )) + + # order is irrelevant + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(B, C, A), + matcher._parse_filter_expression, + )) + + # all reference items must be present + self.assertFalse(_set_matcher( + ast.filter.And(A, B), + ast.filter.And(A, B, C), + matcher._parse_filter_expression, + )) + + # all reference items must have a match + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(D, B, C), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, D, C), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, D), + matcher._parse_filter_expression, + )) + + # Any matches every item + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(Any(), B, C), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, Any(), C), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Any()), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, D), + ast.filter.And(A, B, Any()), + matcher._parse_filter_expression, + )) + + # there can be multiple Any's + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, Any(), Any()), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(Any(), B, Any()), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(Any(), Any(), C), + matcher._parse_filter_expression, + )) + + # Any covers exactly one element + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Any()), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, D), + ast.filter.And(A, B, Any()), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B), + ast.filter.And(A, B, Any()), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C, D), + ast.filter.And(A, B, Any()), + matcher._parse_filter_expression, + )) + + # each Any covers exactly one element + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(Any(), Any(), Any()), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(Any(), Any()), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B), + ast.filter.And(Any(), Any(), Any()), + matcher._parse_filter_expression, + )) + + # Rest captures remainder + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Rest()), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C, D), + ast.filter.And(A, B, Rest()), + matcher._parse_filter_expression, + )) + # remainder matches the empty set + self.assertTrue(_set_matcher( + ast.filter.And(A, B), + ast.filter.And(A, B, Rest()), + matcher._parse_filter_expression, + )) + # Rest does not absolve other refernce items from having a match + self.assertFalse(_set_matcher( + ast.filter.And(A, C, D), + ast.filter.And(A, B, Rest()), + matcher._parse_filter_expression, + )) + # Rest can be combined with Any ... + self.assertTrue(_set_matcher( + ast.filter.And(A, C, D), + ast.filter.And(A, Any(), Rest()), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, C, D), + ast.filter.And(A, Any(), Rest()), + matcher._parse_filter_expression, + )) + # ... explicit items still need to match + self.assertFalse(_set_matcher( + ast.filter.And(A, C, D), + ast.filter.And(B, Any(), Rest()), + matcher._parse_filter_expression, + )) + # ... Any still determines minimum element count + self.assertTrue(_set_matcher( + ast.filter.And(A, B), + ast.filter.And(A, Any(), Rest()), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B), + ast.filter.And(A, Any(), Any(), Rest()), + matcher._parse_filter_expression, + )) + # Rest cannot be repeated ... + self.assertRaises(errors.BackendError, _set_matcher, + ast.filter.And(A, B, C), + ast.filter.And(A, Rest(), Rest(ast.filter.Equals('hello'))), + matcher._parse_filter_expression, + ) + # ... unless they are identical + self.assertRaises(errors.BackendError, _set_matcher, + ast.filter.And(A, B, C), + ast.filter.And(A, Rest(), Rest()), # Any instances are different! + matcher._parse_filter_expression, + ) + # ... unless they are identical + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Rest(C), Rest(C)), + matcher._parse_filter_expression, + )) + # Rest can mandate a specific expression + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Rest(C)), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Rest(D)), + matcher._parse_filter_expression, + )) + # Rest can mandate a partial expression + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Rest(Partial(ast.filter.Equals))), + matcher._parse_filter_expression, + )) + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, Rest(Partial(ast.filter.Equals))), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Rest(Partial(ast.filter.Substring))), + matcher._parse_filter_expression, + )) + self.assertFalse(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(A, B, Rest(Partial(ast.filter.Equals, value='D'))), + matcher._parse_filter_expression, + )) + # Rest can be the only expression + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(Rest(Partial(ast.filter.Equals))), + matcher._parse_filter_expression, + )) + # Rest's expression defaults to Any + self.assertTrue(_set_matcher( + ast.filter.And(A, B, C), + ast.filter.And(Rest()), + matcher._parse_filter_expression, + )) + + +class TestFilter(unittest.TestCase): + def setUp(self): + self.match = Filter() + + def test_call(self): + # query must be a filter expression + self.assertRaises(errors.BackendError, self.match, 1234, Any()) + self.assertRaises(errors.BackendError, self.match, ast.filter.Predicate(ns.bse.filename), Any()) + # reference must be a filter expression + self.assertRaises(errors.BackendError, self.match, ast.filter.Equals('hello'), 1234) + self.assertRaises(errors.BackendError, self.match, ast.filter.Equals('hello'), ast.filter.Predicate(ns.bse.filename)) + # reference can be Any or Partial + self.assertTrue(self.match( + ast.filter.Equals('hello'), + Any(), + )) + self.assertTrue(self.match( + ast.filter.Equals('hello'), + Partial(ast.filter.Equals), + )) + # call parses expression + self.assertTrue(self.match( + # query + ast.filter.And( + ast.filter.Any(ns.bse.tag, + ast.filter.All(ns.bse.label, + ast.filter.Or( + ast.filter.Equals('hello'), + ast.filter.Equals('world'), + ast.filter.StartsWith('foo'), + ast.filter.EndsWith('bar'), + ) + ) + ), + ast.filter.Any(ns.bse.iso, + ast.filter.And( + ast.filter.GreaterThan(100, strict=True), + ast.filter.LessThan(200, strict=False), + ) + ), + ast.filter.Any(ast.filter.OneOf(ns.bse.featureA, ns.bse.featureB), + ast.filter.Distance([1,2,3], 1) + ), + ), + # reference + ast.filter.And( + ast.filter.Any(Any(), + ast.filter.All(Partial(ast.filter.Predicate, reverse=False), + ast.filter.Or( + Partial(ast.filter.StartsWith), + ast.filter.EndsWith('bar'), + Rest(Partial(ast.filter.Equals)), + ) + ) + ), + ast.filter.Any(ns.bse.iso, + ast.filter.And( + Partial(ast.filter.GreaterThan, strict=True), + Any(), + Rest(), + ) + ), + ast.filter.Any(ast.filter.OneOf(Rest()), + Partial(ast.filter.Distance) + ), + ), + )) + self.assertFalse(self.match( + # query + ast.filter.Any(ns.bse.tag, + ast.filter.And( + ast.filter.Any(ns.bse.label, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.collection, ast.filter.Is('http://example.com/col#123')), + ast.filter.Not(ast.filter.Has(ns.bse.label)), + ) + ), + # reference + ast.filter.Any(ns.bse.tag, + ast.filter.And( + Any(), + ast.filter.Any(Partial(ast.filter.Predicate, reverse=True), # reverse mismatch + Partial(ast.filter.Is)), + ast.filter.Not(ast.filter.Has(Any(), Any())), + ) + ) + )) + + def test_parse_filter_expression(self): + # Any matches every filter expression + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Not(ast.filter.FilterExpression()), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Has(ns.bse.filename), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Distance([1,2,3], 1.0), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.And(ast.filter.Equals('hello')), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Or(ast.filter.Equals('hello')), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Equals('hello'), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Substring('hello'), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.StartsWith('hello'), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.EndsWith('hello'), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.Is('hello'), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.LessThan(3), Any())) + self.assertTrue(self.match._parse_filter_expression( + ast.filter.GreaterThan(3), Any())) + # Any matches invalid filter expressions + self.assertTrue(self.match._parse_filter_expression( + ast.filter.FilterExpression(), Any())) + # node must be an appropriate filter expression + self.assertRaises(errors.BackendError, self.match._parse_filter_expression, + ast.filter.FilterExpression(), ast.filter.FilterExpression()) + self.assertRaises(errors.BackendError, self.match._parse_filter_expression, + 1234, ast.filter.FilterExpression()) + + def test_parse_predicate_expression(self): + # Any matches every predicate expression + self.assertTrue(self.match._parse_predicate_expression( + ast.filter.Predicate(ns.bse.filename), Any())) + self.assertTrue(self.match._parse_predicate_expression( + ast.filter.OneOf(ns.bse.filename), Any())) + # Any matches invalid predicate expression + self.assertTrue(self.match._parse_predicate_expression( + ast.filter.FilterExpression(), Any())) + # node must be an appropriate predicate expression + self.assertRaises(errors.BackendError, self.match._parse_predicate_expression, + ast.filter.PredicateExpression(), ast.filter.PredicateExpression()) + self.assertRaises(errors.BackendError, self.match._parse_predicate_expression, + 1234, ast.filter.PredicateExpression()) + + def test_predicate(self): + # identical expressions match + self.assertTrue(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + ast.filter.Predicate(ns.bse.filename, reverse=False), + )) + # _predicate respects type + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + ast.filter.FilterExpression(), + )) + # _predicate respects predicate + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + ast.filter.Predicate(ns.bse.filesize, reverse=False), + )) + # _predicate respects reverse + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + ast.filter.Predicate(ns.bse.filename, reverse=True), + )) + # Partial requires ast.filter.Predicate + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + Partial(ast.filter.Equals), + )) + # predicate and reverse can be specified + self.assertTrue(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + Partial(ast.filter.Predicate, predicate=ns.bse.filename, reverse=False), + )) + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + Partial(ast.filter.Predicate, predicate=ns.bse.filesize, reverse=False), + )) + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + Partial(ast.filter.Predicate, predicate=ns.bse.filename, reverse=True), + )) + # predicate can remain unspecified + self.assertTrue(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + Partial(ast.filter.Predicate, reverse=False), + )) + self.assertTrue(self.match._predicate( + ast.filter.Predicate(ns.bse.filesize, reverse=False), + Partial(ast.filter.Predicate, reverse=False), + )) + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filesize, reverse=False), + Partial(ast.filter.Predicate, reverse=True), + )) + # reverse can remain unspecified + self.assertTrue(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + Partial(ast.filter.Predicate, predicate=ns.bse.filename), + )) + self.assertTrue(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=True), + Partial(ast.filter.Predicate, predicate=ns.bse.filename), + )) + self.assertFalse(self.match._predicate( + ast.filter.Predicate(ns.bse.filename, reverse=False), + Partial(ast.filter.Predicate, predicate=ns.bse.filesize), + )) + + def test_one_of(self): + A = ast.filter.Predicate(ns.bse.filename) + B = ast.filter.Predicate(ns.bse.filesize) + C = ast.filter.Predicate(ns.bse.filename, reverse=True) + # identical expressions match + self.assertTrue(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.OneOf(A, B), + )) + # _one_of respects type + self.assertFalse(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.Predicate(ns.bse.filesize, reverse=True), + )) + # _one_of respects child expressions + self.assertFalse(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.OneOf(A, C), + )) + self.assertFalse(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.OneOf(A), + )) + self.assertFalse(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.OneOf(A, B, C), + )) + self.assertTrue(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.OneOf(B, A), + )) + self.assertTrue(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.OneOf(A, Any()), + )) + self.assertTrue(self.match._one_of( + ast.filter.OneOf(A, B), + ast.filter.OneOf(B, Rest()), + )) + + def test_branch(self): + # identical expressions match + self.assertTrue(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + )) + self.assertTrue(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + )) + # _agg respects type + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Equals('hello'), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Equals('hello'), + )) + # _agg respects predicate expression + self.assertTrue(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ast.filter.Predicate(ns.bse.filename), ast.filter.Equals('hello')), + )) + self.assertTrue(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ast.filter.Predicate(ns.bse.filename), ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filesize, ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filesize, ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ast.filter.OneOf(ns.bse.filename), ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ast.filter.OneOf(ns.bse.filename), ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ast.filter.Predicate(ns.bse.filename, reverse=True), ast.filter.Equals('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ast.filter.Predicate(ns.bse.filename, reverse=True), ast.filter.Equals('hello')), + )) + self.assertTrue(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(Any(), ast.filter.Equals('hello')), + )) + self.assertTrue(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(Any(), ast.filter.Equals('hello')), + )) + self.assertTrue(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(Partial(ast.filter.Predicate), ast.filter.Equals('hello')), + )) + self.assertTrue(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(Partial(ast.filter.Predicate), ast.filter.Equals('hello')), + )) + # _agg respects filter expression + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filename, ast.filter.Substring('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filename, ast.filter.Substring('hello')), + )) + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filename, ast.filter.Any(Any(), Any())), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filename, ast.filter.All(Any(), Any())), + )) + self.assertTrue(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filename, Any()), + )) + self.assertTrue(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filename, Any()), + )) + self.assertTrue(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filename, Partial(ast.filter.Equals)), + )) + self.assertTrue(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filename, Partial(ast.filter.Equals)), + )) + self.assertFalse(self.match._branch( + ast.filter.Any(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.Any(ns.bse.filename, Partial(ast.filter.Equals, value='world')), + )) + self.assertFalse(self.match._branch( + ast.filter.All(ns.bse.filename, ast.filter.Equals('hello')), + ast.filter.All(ns.bse.filename, Partial(ast.filter.Equals, value='world')), + )) + + def test_agg(self): + A = ast.filter.Equals('hello') + B = ast.filter.Equals('world') + C = ast.filter.Equals('foobar') + # identical expressions match + self.assertTrue(self.match._agg( + ast.filter.And(A, B), + ast.filter.And(A, B), + )) + self.assertTrue(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Or(A, B), + )) + # _agg respects type + self.assertFalse(self.match._agg( + ast.filter.And(A, B), + ast.filter.Or(A, B), + )) + self.assertFalse(self.match._agg( + ast.filter.Or(A, B), + ast.filter.And(A, B), + )) + self.assertFalse(self.match._agg( + ast.filter.And(A, B), + ast.filter.Equals('hello'), + )) + self.assertFalse(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Equals('hello'), + )) + # _agg respects child expressions + self.assertFalse(self.match._agg( + ast.filter.And(A, B), + ast.filter.And(A, ast.filter.Equals('bar')), + )) + self.assertFalse(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Or(A, ast.filter.Equals('bar')), + )) + self.assertFalse(self.match._agg( + ast.filter.And(A, B), + ast.filter.And(A), + )) + self.assertFalse(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Or(A), + )) + self.assertFalse(self.match._agg( + ast.filter.And(A, B), + ast.filter.And(A, B, C), + )) + self.assertFalse(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Or(A, B, C), + )) + self.assertTrue(self.match._agg( + ast.filter.And(A, B), + ast.filter.And(B, A), + )) + self.assertTrue(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Or(B, A), + )) + self.assertTrue(self.match._agg( + ast.filter.And(A, B), + ast.filter.And(A, Any()), + )) + self.assertTrue(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Or(A, Any()), + )) + self.assertTrue(self.match._agg( + ast.filter.And(A, B), + ast.filter.And(B, Rest()), + )) + self.assertTrue(self.match._agg( + ast.filter.Or(A, B), + ast.filter.Or(B, Rest()), + )) + + def test_not(self): + # identical expressions match + self.assertTrue(self.match._not( + ast.filter.Not(ast.filter.Equals('hello')), + ast.filter.Not(ast.filter.Equals('hello')), + )) + # _not respects type + self.assertFalse(self.match._not( + ast.filter.Not(ast.filter.Equals('hello')), + ast.filter.Equals('hello'), + )) + # _not respects child expression + self.assertFalse(self.match._not( + ast.filter.Not(ast.filter.Equals('hello')), + ast.filter.Not(ast.filter.Equals('world')), + )) + self.assertFalse(self.match._not( + ast.filter.Not(ast.filter.Equals('hello')), + ast.filter.Not(ast.filter.Substring('hello')), + )) + self.assertTrue(self.match._not( + ast.filter.Not(ast.filter.Equals('hello')), + ast.filter.Not(Any()), + )) + + def test_has(self): + # identical expressions match + self.assertTrue(self.match._has( + ast.filter.Has(ns.bse.filesize), + ast.filter.Has(ns.bse.filesize), + )) + self.assertTrue(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + )) + # _has respects type + self.assertFalse(self.match._has( + ast.filter.Has(ns.bse.filesize), + ast.filter.Equals('hello'), + )) + self.assertFalse(self.match._has( + ast.filter.Has(ns.bse.filesize), + ast.filter.Equals('hello'), + )) + # _has respects predicate + self.assertFalse(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(ns.bse.iso, ast.filter.LessThan(3)), + )) + self.assertTrue(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(Any(), ast.filter.LessThan(3)), + )) + self.assertTrue(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(Partial(ast.filter.Predicate), ast.filter.LessThan(3)), + )) + # _has respects count + self.assertFalse(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(ns.bse.filesize, ast.filter.GreaterThan(3)), + )) + self.assertFalse(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(5)), + )) + self.assertTrue(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(ns.bse.filesize, Any()), + )) + self.assertTrue(self.match._has( + ast.filter.Has(ns.bse.filesize, ast.filter.LessThan(3)), + ast.filter.Has(ns.bse.filesize, Partial(ast.filter.LessThan)), + )) + + def test_distance(self): + # identical expressions match + self.assertTrue(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + ast.filter.Distance([1,2,3], 5, True), + )) + # _distance respects type + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + ast.filter.Equals('hello'), + )) + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Equals), + )) + # _distance respects reference value + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + ast.filter.Distance([3,2,1], 5, True), + )) + self.assertTrue(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, threshold=5, strict=True), + )) + self.assertTrue(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[1,2,3], threshold=5, strict=True), + )) + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[3,2,1], threshold=5, strict=True), + )) + # _distance respects threshold + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + ast.filter.Distance([1,2,3], 8, True), + )) + self.assertTrue(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[1,2,3], strict=True), + )) + self.assertTrue(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[1,2,3], threshold=5, strict=True), + )) + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[1,2,3], threshold=8, strict=True), + )) + # _distance respects strict + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + ast.filter.Distance([1,2,3], 5, False), + )) + self.assertTrue(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[1,2,3], threshold=5), + )) + self.assertTrue(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[1,2,3], threshold=5, strict=True), + )) + self.assertFalse(self.match._distance( + ast.filter.Distance([1,2,3], 5, True), + Partial(ast.filter.Distance, reference=[1,2,3], threshold=5, strict=False), + )) + + def test_value(self): + # identical expressions match + self.assertTrue(self.match._value(ast.filter.Equals('hello'), ast.filter.Equals('hello'))) + self.assertTrue(self.match._value(ast.filter.Substring('hello'), ast.filter.Substring('hello'))) + self.assertTrue(self.match._value(ast.filter.StartsWith('hello'), ast.filter.StartsWith('hello'))) + self.assertTrue(self.match._value(ast.filter.EndsWith('hello'), ast.filter.EndsWith('hello'))) + self.assertTrue(self.match._value(ast.filter.Is('hello'), ast.filter.Is('hello'))) + # _value respects type + self.assertFalse(self.match._value(ast.filter.Equals('hello'), ast.filter.Is('hello'))) + self.assertFalse(self.match._value(ast.filter.Substring('hello'), ast.filter.Is('hello'))) + self.assertFalse(self.match._value(ast.filter.StartsWith('hello'), ast.filter.Is('hello'))) + self.assertFalse(self.match._value(ast.filter.EndsWith('hello'), ast.filter.Is('hello'))) + self.assertFalse(self.match._value(ast.filter.Is('hello'), ast.filter.Equals('hello'))) + # _value respects value + self.assertFalse(self.match._value(ast.filter.Equals('hello'), ast.filter.Equals('world'))) + self.assertFalse(self.match._value(ast.filter.Substring('hello'), ast.filter.Substring('world'))) + self.assertFalse(self.match._value(ast.filter.StartsWith('hello'), ast.filter.StartsWith('world'))) + self.assertFalse(self.match._value(ast.filter.EndsWith('hello'), ast.filter.EndsWith('world'))) + self.assertFalse(self.match._value(ast.filter.Is('hello'), ast.filter.Is('world'))) + # Partial requires correct type + self.assertFalse(self.match._value(ast.filter.Equals('hello'), Partial(ast.filter.Is))) + self.assertFalse(self.match._value(ast.filter.Substring('hello'), Partial(ast.filter.Is))) + self.assertFalse(self.match._value(ast.filter.StartsWith('hello'), Partial(ast.filter.Is))) + self.assertFalse(self.match._value(ast.filter.EndsWith('hello'), Partial(ast.filter.Is))) + self.assertFalse(self.match._value(ast.filter.Is('hello'), Partial(ast.filter.Equals))) + # value can be specified + self.assertTrue(self.match._value(ast.filter.Equals('hello'), Partial(ast.filter.Equals, value='hello'))) + self.assertFalse(self.match._value(ast.filter.Equals('hello'), Partial(ast.filter.Equals, value='world'))) + self.assertTrue(self.match._value(ast.filter.Substring('hello'), Partial(ast.filter.Substring, value='hello'))) + self.assertFalse(self.match._value(ast.filter.Substring('hello'), Partial(ast.filter.Substring, value='world'))) + self.assertTrue(self.match._value(ast.filter.StartsWith('hello'), Partial(ast.filter.StartsWith, value='hello'))) + self.assertFalse(self.match._value(ast.filter.StartsWith('hello'), Partial(ast.filter.StartsWith, value='world'))) + self.assertTrue(self.match._value(ast.filter.EndsWith('hello'), Partial(ast.filter.EndsWith, value='hello'))) + self.assertFalse(self.match._value(ast.filter.EndsWith('hello'), Partial(ast.filter.EndsWith, value='world'))) + self.assertTrue(self.match._value(ast.filter.Is('hello'), Partial(ast.filter.Is, value='hello'))) + self.assertFalse(self.match._value(ast.filter.Is('hello'), Partial(ast.filter.Is, value='world'))) + # value can remain unspecified + self.assertTrue(self.match._value(ast.filter.Equals('hello'), Partial(ast.filter.Equals))) + self.assertTrue(self.match._value(ast.filter.Substring('hello'), Partial(ast.filter.Substring))) + self.assertTrue(self.match._value(ast.filter.StartsWith('hello'), Partial(ast.filter.StartsWith))) + self.assertTrue(self.match._value(ast.filter.EndsWith('hello'), Partial(ast.filter.EndsWith))) + self.assertTrue(self.match._value(ast.filter.Is('hello'), Partial(ast.filter.Is))) + + def test_bounded(self): + # identical expressions match + self.assertTrue(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + ast.filter.LessThan(threshold=3, strict=False), + )) + self.assertTrue(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + ast.filter.GreaterThan(threshold=3, strict=False), + )) + # _bounded respects type + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + ast.filter.GreaterThan(threshold=3, strict=False), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + ast.filter.LessThan(threshold=3, strict=False), + )) + # _bounded respects threshold + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + ast.filter.LessThan(threshold=4, strict=False), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + ast.filter.GreaterThan(threshold=4, strict=False), + )) + # _bounded respects strict + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + ast.filter.LessThan(threshold=3, strict=True), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + ast.filter.GreaterThan(threshold=3, strict=True), + )) + # Partial requires correct type + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.LessThan), + )) + # threshold and strict can be specified + self.assertTrue(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=3, strict=False), + )) + self.assertTrue(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan, threshold=3, strict=False), + )) + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=4, strict=False), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan, threshold=4, strict=False), + )) + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=3, strict=True), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan, threshold=3, strict=True), + )) + # threshold can remain unspecified + self.assertTrue(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.LessThan, strict=False), + )) + self.assertTrue(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan, strict=False), + )) + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.LessThan, strict=True), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan, strict=True), + )) + # strict can remain unspecified + self.assertTrue(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=3), + )) + self.assertTrue(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan, threshold=3), + )) + self.assertFalse(self.match._bounded( + ast.filter.LessThan(threshold=3, strict=False), + Partial(ast.filter.LessThan, threshold=4), + )) + self.assertFalse(self.match._bounded( + ast.filter.GreaterThan(threshold=3, strict=False), + Partial(ast.filter.GreaterThan, threshold=4), + )) + + +## main ## + +if __name__ == '__main__': + unittest.main() + +## EOF ## |