1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
|
# imports
from collections import defaultdict
from itertools import product
from time import time
import random
import threading
import typing
# external imports
from hopcroftkarp import HopcroftKarp
# bsfs imports
from bsfs.utils import errors, typename
# inner-module imports
from . import ast
# exports
__all__ : typing.Sequence[str] = (
'Filter',
)
## code ##
class Any(ast.filter.FilterExpression, ast.filter.PredicateExpression):
    """Match any ast class.

    Each Any instance carries its own unique id, so distinct instances do not
    compare equal and can be repeated within a set:

    >>> Any() == Any()
    False
    >>> len({Any(), Any(), Any(), Any()})
    4

    """

    # unique instance id: (object id, thread id, creation time, random salt)
    _uid: typing.Tuple[int, int, float, float]

    def __init__(self):
        object_id = id(self)
        thread_id = id(threading.current_thread())
        created_at = time()
        salt = random.random()
        self._uid = (object_id, thread_id, created_at, salt)

    def __eq__(self, other: typing.Any):
        # defer to the base class first; only then compare the unique ids
        base_equal = super().__eq__(other)
        if not base_equal:
            return base_equal
        return self._uid == other._uid

    def __hash__(self):
        base_hash = super().__hash__()
        return hash((base_hash, self._uid))
class Rest(ast.filter.FilterExpression, ast.filter.PredicateExpression):
    """Match the leftovers in a set of items to be compared.

    Rest can be used in conjunction with aggregating expressions such as
    ast.filter.And, ast.filter.Or, or ast.filter.OneOf. It matches the child
    expressions that were not yet consumed by other matching rules (possibly
    none at all). Rest may be restricted to match only a specific expression;
    the expression defaults to Any().

    For example, the following two ast structures would match since Rest
    allows an arbitrary repetition of ast.filter.Equals statements.

    >>> And(Equals('hello'), Equals('world'), Equals('foobar'))
    >>> And(Equals('world'), Rest(Partial(Equals)))

    """

    # child expression that each leftover item has to match.
    expr: typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression]

    def __init__(
            self,
            expr: typing.Optional[typing.Union[ast.filter.FilterExpression, ast.filter.PredicateExpression]] = None,
            ):
        # without an explicit expression, the leftovers may be anything
        self.expr = Any() if expr is None else expr

    def __repr__(self) -> str:
        return f'{typename(self)}({self.expr})'

    def __hash__(self) -> int:
        return hash((super().__hash__(), self.expr))

    def __eq__(self, other: typing.Any) -> bool:
        return super().__eq__(other) and self.expr == other.expr
class Partial(ast.filter.FilterExpression, ast.filter.PredicateExpression):
    """Match a partially defined ast expression.

    Literal values might be irrelevant or unknown when comparing two ast
    structures. Partial allows to constrain the matcher to a certain
    ast class, while leaving some of its members unspecified.

    Pass the class (not instance) and its members as keyword arguments
    to Partial. Note that the arguments are not validated.

    For example, the following instance matches any ast.filter.Equals,
    irrespective of its value:

    >>> Partial(ast.filter.Equals)

    Likewise, the following instance matches any ast.filter.LessThan
    with a non-strict bound (strict=False), but makes no claim about
    the threshold:

    >>> Partial(ast.filter.LessThan, strict=False)

    """

    # target node type.
    node: typing.Type

    # node construction args; members absent from here are unconstrained.
    kwargs: typing.Dict[str, typing.Any]

    def __init__(
            self,
            node: typing.Type,
            **kwargs,
            ):
        self.node = node
        self.kwargs = kwargs

    def __repr__(self) -> str:
        return f'{typename(self)}({self.node.__name__}, {self.kwargs})'

    def __hash__(self) -> int:
        # keys are unique strings, so sorting the items orders them by key only
        kwargs = tuple(sorted(self.kwargs.items()))
        return hash((super().__hash__(), self.node, kwargs))

    def __eq__(self, other: typing.Any) -> bool:
        return super().__eq__(other) \
            and self.node == other.node \
            and self.kwargs == other.kwargs

    def match(
            self,
            name: str,
            value: typing.Any,
            ) -> bool:
        """Return True if *name* is unspecified or matches *value*."""
        return name not in self.kwargs or self.kwargs[name] == value
T_ITEM_TYPE = typing.TypeVar('T_ITEM_TYPE') # pylint: disable=invalid-name

def _set_matcher(
        query: typing.Collection[T_ITEM_TYPE],
        reference: typing.Collection[T_ITEM_TYPE],
        cmp: typing.Callable[[T_ITEM_TYPE, T_ITEM_TYPE], bool],
        ) -> bool:
    """Compare two sets of child expressions.

    Return True if every item of *query* can be paired with a distinct item
    of *reference* such that ``cmp(query_item, reference_item)`` holds, with
    no item left unpaired on either side. A single `Rest` in *reference* is
    replaced by as many copies of its expression as are needed to pad
    *reference* to the size of *query* (possibly zero copies).

    *cmp* is called for every (reference, query) pair, i.e. O(|N|**2) times,
    with N the number of child expressions. The subsequent matching step
    has a worst-case complexity of O(|N|**2 * sqrt(|N|)).

    Raises errors.BackendError if *reference* holds more than one `Rest`.

    """
    # get reference items
    r_items = list(reference)
    # deal with Rest
    # NOTE: collect Rest items in a list rather than a set, so that several
    # Rest instances which compare equal are still counted individually and
    # Rest expressions need not be hashable.
    r_rest = [itm for itm in r_items if isinstance(itm, Rest)]
    if len(r_rest) > 1:
        raise errors.BackendError(f'there must be at most one Rest instance per set, found {len(r_rest)}')
    if len(r_rest) == 1:
        # replace Rest by filling the reference up with rest's expression
        # NOTE: r_items is a list so that items can be repeated
        expr = r_rest[0].expr # type: ignore [attr-defined]
        r_items = [itm for itm in r_items if not isinstance(itm, Rest)]
        r_items += [expr for _ in range(len(query) - len(r_items))] # type: ignore [misc]

    # sanity check: cannot match if the item sizes differ:
    # either a reference item is unmatched (len(r_items) > len(query))
    # or a query item is unmatched (len(r_items) < len(query))
    if len(query) != len(r_items):
        return False

    # To have a positive match between the query and the reference,
    # each query expr has to match a distinct reference expr.
    # However, each reference expr can only be "consumed" once even
    # if it matches multiple query exprs (e.g., the Any expression matches
    # every query expr).
    # This is a bipartite matching problem (Hall's marriage problem)
    # and the Hopcroft-Karp-Karzanov algorithm finds a maximum
    # matching. While there might be multiple maximum matchings,
    # we only need to know whether (at least) one complete matching
    # exists. The hopcroftkarp module provides this functionality.
    # The HKK algorithm has worst-case complexity of O(|N|**2 * sqrt(|N|))
    # and we also need to compare expressions pairwise, hence O(|N|**2).
    num_items = len(r_items)
    graph = defaultdict(set)
    # build the bipartite graph as {lhs: {rhs}, ...}
    # lhs and rhs must be disjoint identifiers, hence rhs is offset by num_items.
    for (ridx, ref), (nidx, node) in product(enumerate(r_items), enumerate(query)):
        # add edges for equal expressions
        if cmp(node, ref):
            graph[ridx].add(num_items + nidx)

    # maximum_matching returns the matches for all nodes in the graph
    # ({ref_itm: node_itm}), hence a complete matching's size is
    # the number of reference's child expressions.
    return len(HopcroftKarp(graph).maximum_matching(keys_only=True)) == num_items
class Filter():
    """Compare a `bsfs.query.ast.filter` query's structure to a reference ast.

    The reference ast may include `Rest`, `Partial`, or `Any` to account for irrelevant
    or unknown ast pieces.

    This is only a structural comparison, not a semantic one. For example, the
    two following queries are semantically identical, but structurally different,
    and would therefore not match:

    >>> ast.filter.OneOf(ast.filter.Predicate(ns.bse.filename))
    >>> ast.filter.Predicate(ns.bse.filename)

    """

    def __call__(self, query: ast.filter.FilterExpression, reference: ast.filter.FilterExpression) -> bool:
        """Compare a *query* to a *reference* ast structure.

        Return True if both are structurally equivalent.

        Raise errors.BackendError if *query* or *reference* is not a filter
        expression, if *query* contains an unsupported expression, or if an
        aggregate in *reference* holds more than one `Rest`.

        """
        if not isinstance(query, ast.filter.FilterExpression):
            raise errors.BackendError(f'expected filter expression, found {query}')
        if not isinstance(reference, ast.filter.FilterExpression):
            raise errors.BackendError(f'expected filter expression, found {reference}')
        return self._parse_filter_expression(query, reference)

    def _parse_filter_expression(
            self,
            node: ast.filter.FilterExpression,
            reference: ast.filter.FilterExpression,
            ) -> bool:
        """Route *node* to the handler of the respective FilterExpression subclass."""
        # an Any reference matches every node, irrespective of its type
        if isinstance(reference, Any):
            return True
        # node-specific checks; each handler verifies the reference's type itself
        if isinstance(node, ast.filter.Not):
            return self._not(node, reference)
        if isinstance(node, ast.filter.Has):
            return self._has(node, reference)
        if isinstance(node, ast.filter.Distance):
            return self._distance(node, reference)
        if isinstance(node, (ast.filter.Any, ast.filter.All)):
            return self._branch(node, reference)
        if isinstance(node, (ast.filter.And, ast.filter.Or)):
            return self._agg(node, reference)
        if isinstance(node, (ast.filter.Is, ast.filter.Equals, ast.filter.Substring,
                             ast.filter.StartsWith, ast.filter.EndsWith)):
            return self._value(node, reference)
        if isinstance(node, (ast.filter.LessThan, ast.filter.GreaterThan)):
            return self._bounded(node, reference)
        # invalid node
        raise errors.BackendError(f'expected filter expression, found {node}')

    def _parse_predicate_expression(
            self,
            node: ast.filter.PredicateExpression,
            reference: ast.filter.PredicateExpression,
            ) -> bool:
        """Route *node* to the handler of the respective PredicateExpression subclass."""
        # an Any reference matches every node, irrespective of its type
        if isinstance(reference, Any):
            return True
        if isinstance(node, ast.filter.Predicate):
            return self._predicate(node, reference)
        if isinstance(node, ast.filter.OneOf):
            return self._one_of(node, reference)
        # invalid node
        raise errors.BackendError(f'expected predicate expression, found {node}')

    @staticmethod
    def _match_members(
            node: typing.Any,
            reference: typing.Any,
            members: typing.Sequence[str],
            ) -> bool:
        """Compare the *members* of a leaf *node* to a `Partial` or same-typed *reference*.

        A `Partial` reference must target *node*'s class (or a base class of it)
        and only constrains the members it specifies. Any other reference must
        be an instance of *node*'s type and agree on all *members*.

        """
        # partial check: members the Partial does not specify match anything
        if isinstance(reference, Partial):
            return isinstance(node, reference.node) \
                and all(reference.match(name, getattr(node, name)) for name in members)
        if not isinstance(reference, type(node)):
            return False
        # full check: every member must compare equal
        return all(getattr(node, name) == getattr(reference, name) for name in members)

    def _one_of(self, node: ast.filter.OneOf, reference: ast.filter.PredicateExpression) -> bool:
        """Match the child predicates of a OneOf, irrespective of their order."""
        if not isinstance(reference, type(node)):
            return False
        return _set_matcher(node, reference, self._parse_predicate_expression)

    def _predicate(self, node: ast.filter.Predicate, reference: ast.filter.PredicateExpression) -> bool:
        """Match a Predicate by its `predicate` and `reverse` members."""
        return self._match_members(node, reference, ('predicate', 'reverse'))

    def _branch(self,
            node: typing.Union[ast.filter.Any, ast.filter.All],
            reference: ast.filter.FilterExpression,
            ) -> bool:
        """Match an Any/All branch by its predicate and its child expression."""
        if not isinstance(reference, type(node)):
            return False
        if not self._parse_predicate_expression(node.predicate, reference.predicate): # type: ignore [attr-defined]
            return False
        if not self._parse_filter_expression(node.expr, reference.expr): # type: ignore [attr-defined]
            return False
        return True

    def _agg(self, node: typing.Union[ast.filter.And, ast.filter.Or], reference: ast.filter.FilterExpression) -> bool:
        """Match the child expressions of an And/Or, irrespective of their order."""
        if not isinstance(reference, type(node)):
            return False
        return _set_matcher(node, reference, self._parse_filter_expression) # type: ignore [arg-type]

    def _not(self, node: ast.filter.Not, reference: ast.filter.FilterExpression) -> bool:
        """Match a Not by its child expression."""
        if not isinstance(reference, type(node)):
            return False
        return self._parse_filter_expression(node.expr, reference.expr)

    def _has(self, node: ast.filter.Has, reference: ast.filter.FilterExpression) -> bool:
        """Match a Has by its predicate and its count expression."""
        if not isinstance(reference, type(node)):
            return False
        return self._parse_predicate_expression(node.predicate, reference.predicate) \
            and self._parse_filter_expression(node.count, reference.count)

    def _distance(self, node: ast.filter.Distance, reference: ast.filter.FilterExpression) -> bool:
        """Match a Distance by its `reference`, `threshold`, and `strict` members."""
        return self._match_members(node, reference, ('reference', 'threshold', 'strict'))

    def _value(self, node: ast.filter._Value, reference: ast.filter.FilterExpression) -> bool:
        """Match a value expression (Is, Equals, Substring, ...) by its `value` member."""
        return self._match_members(node, reference, ('value', ))

    def _bounded(self, node: ast.filter._Bounded, reference: ast.filter.FilterExpression) -> bool:
        """Match a bounded expression (LessThan, GreaterThan) by its `threshold` and `strict` members."""
        return self._match_members(node, reference, ('threshold', 'strict'))
## EOF ##
|