diff options
author | Matthias Baumgartner <dev@igsor.net> | 2023-01-15 21:00:12 +0100 |
---|---|---|
committer | Matthias Baumgartner <dev@igsor.net> | 2023-01-15 21:00:12 +0100 |
commit | 80a97bfa9f22d0d6dd25928fe1754a3a0d1de78a (patch) | |
tree | 30d30fb669d7b43d7324ef8027306c24c1ec1ac2 /bsfs | |
parent | ccaee71e2b6135d3b324fe551c8652940b67aab3 (diff) | |
download | bsfs-80a97bfa9f22d0d6dd25928fe1754a3a0d1de78a.tar.gz bsfs-80a97bfa9f22d0d6dd25928fe1754a3a0d1de78a.tar.bz2 bsfs-80a97bfa9f22d0d6dd25928fe1754a3a0d1de78a.zip |
Distance filter ast node
Diffstat (limited to 'bsfs')
-rw-r--r-- | bsfs/graph/resolve.py | 5 | ||||
-rw-r--r-- | bsfs/query/ast/filter_.py | 59 | ||||
-rw-r--r-- | bsfs/query/validator.py | 16 | ||||
-rw-r--r-- | bsfs/triple_store/sparql/distance.py | 56 | ||||
-rw-r--r-- | bsfs/triple_store/sparql/parse_filter.py | 41 | ||||
-rw-r--r-- | bsfs/triple_store/sparql/sparql.py | 13 |
6 files changed, 176 insertions, 14 deletions
diff --git a/bsfs/graph/resolve.py b/bsfs/graph/resolve.py index b671204..00b778b 100644 --- a/bsfs/graph/resolve.py +++ b/bsfs/graph/resolve.py @@ -63,6 +63,8 @@ class Filter(): return self._and(type_, node) if isinstance(node, ast.filter.Or): return self._or(type_, node) + if isinstance(node, ast.filter.Distance): + return self._distance(type_, node) if isinstance(node, (ast.filter.Equals, ast.filter.Substring, \ ast.filter.StartsWith, ast.filter.EndsWith)): return self._value(type_, node) @@ -125,6 +127,9 @@ class Filter(): def _has(self, type_: bsc.Vertex, node: ast.filter.Has) -> ast.filter.Has: # pylint: disable=unused-argument return node + def _distance(self, type_: bsc.Vertex, node: ast.filter.Distance): # pylint: disable=unused-argument + return node + def _value(self, type_: bsc.Vertex, node: ast.filter._Value) -> ast.filter._Value: # pylint: disable=unused-argument return node diff --git a/bsfs/query/ast/filter_.py b/bsfs/query/ast/filter_.py index b129ded..2f0270c 100644 --- a/bsfs/query/ast/filter_.py +++ b/bsfs/query/ast/filter_.py @@ -252,8 +252,7 @@ class Has(FilterExpression): class _Value(FilterExpression): - """ - """ + """Matches some value.""" # target value. value: typing.Any @@ -277,13 +276,13 @@ class Is(_Value): class Equals(_Value): """Value matches exactly. - NOTE: Value format must correspond to literal type; can be a string, a number, or a Node + NOTE: Value must correspond to literal type. """ class Substring(_Value): """Value matches a substring - NOTE: value format must be a string + NOTE: value must be a string. """ @@ -295,9 +294,49 @@ class EndsWith(_Value): """Value ends with a given string.""" +class Distance(FilterExpression): + """Distance to a reference is (strictly) below a threshold. Assumes a Feature literal.""" + + # FIXME: + # (a) pass a node/predicate as anchor instead of a value. + # Then we don't need to materialize the reference. + # (b) pass a FilterExpression (_Bounded) instead of a threshold. + # Then, we could also query values greater than a threshold. + + # reference value. + reference: typing.Any + + # distance threshold. + threshold: float + + # closed (True) or open (False) bound. + strict: bool + + def __init__( + self, + reference: typing.Any, + threshold: float, + strict: bool = False, + ): + self.reference = reference + self.threshold = float(threshold) + self.strict = bool(strict) + + def __repr__(self) -> str: + return f'{typename(self)}({self.reference}, {self.threshold}, {self.strict})' + + def __hash__(self) -> int: + return hash((super().__hash__(), tuple(self.reference), self.threshold, self.strict)) + + def __eq__(self, other) -> bool: + return super().__eq__(other) \ + and self.reference == other.reference \ + and self.threshold == other.threshold \ + and self.strict == other.strict + + class _Bounded(FilterExpression): - """ - """ + """Value is bounded by a threshold. Assumes a Number literal.""" # bound. threshold: float @@ -327,15 +366,11 @@ class _Bounded(FilterExpression): class LessThan(_Bounded): - """Value is (strictly) smaller than threshold. - NOTE: only on numerical literals - """ + """Value is (strictly) smaller than threshold. Assumes a Number literal.""" class GreaterThan(_Bounded): - """Value is (strictly) larger than threshold - NOTE: only on numerical literals - """ + """Value is (strictly) larger than threshold. Assumes a Number literal.""" class Predicate(PredicateExpression): diff --git a/bsfs/query/validator.py b/bsfs/query/validator.py index ecea951..1b7f688 100644 --- a/bsfs/query/validator.py +++ b/bsfs/query/validator.py @@ -69,6 +69,8 @@ class Filter(): return self._not(type_, node) if isinstance(node, ast.filter.Has): return self._has(type_, node) + if isinstance(node, ast.filter.Distance): + return self._distance(type_, node) if isinstance(node, (ast.filter.Any, ast.filter.All)): return self._branch(type_, node) if isinstance(node, (ast.filter.And, ast.filter.Or)): @@ -177,6 +179,20 @@ class Filter(): # node.count is a numerical expression self._parse_filter_expression(self.schema.literal(ns.bsfs.Number), node.count) + def _distance(self, type_: bsc.Vertex, node: ast.filter.Distance): + # type is a Literal + if not isinstance(type_, bsc.Feature): + raise errors.ConsistencyError(f'expected a Feature, found {type_}') + # type exists in the schema + if type_ not in self.schema.literals(): + raise errors.ConsistencyError(f'literal {type_} is not in the schema') + # reference matches type_ + if len(node.reference) != type_.dimension: + raise errors.ConsistencyError(f'reference has dimension {len(node.reference)}, expected {type_.dimension}') + # FIXME: + #if node.reference.dtype != type_.dtype: + # raise errors.ConsistencyError(f'') + ## conditions diff --git a/bsfs/triple_store/sparql/distance.py b/bsfs/triple_store/sparql/distance.py new file mode 100644 index 0000000..2f5387a --- /dev/null +++ b/bsfs/triple_store/sparql/distance.py @@ -0,0 +1,56 @@ +""" + +Part of the BlackStar filesystem (bsfs) module. +A copy of the license is provided with the project. +Author: Matthias Baumgartner, 2022 +""" +# standard imports +import typing + +# external imports +import numpy as np + +# bsfs imports +from bsfs.namespace import ns + +# constants +EPS = 1e-9 + +# exports +__all__: typing.Sequence[str] = ( + 'DISTANCE_FU', + ) + + +## code ## + +def euclid(fst, snd) -> float: + """Euclidean distance (l2 norm).""" + fst = np.array(fst) + snd = np.array(snd) + return float(np.linalg.norm(fst - snd)) + +def cosine(fst, snd) -> float: + """Cosine distance.""" + fst = np.array(fst) + snd = np.array(snd) + if (fst == snd).all(): + return 0.0 + nrm0 = np.linalg.norm(fst) + nrm1 = np.linalg.norm(snd) + return float(1.0 - np.dot(fst, snd) / (nrm0 * nrm1 + EPS)) + +def manhatten(fst, snd) -> float: + """Manhatten (cityblock) distance (l1 norm).""" + fst = np.array(fst) + snd = np.array(snd) + return float(np.abs(fst - snd).sum()) + +# Known distance functions. +DISTANCE_FU = { + ns.bsfs.euclidean: euclid, + ns.bsfs.cosine: cosine, + ns.bsfs.manhatten: manhatten, +} + +## EOF ## diff --git a/bsfs/triple_store/sparql/parse_filter.py b/bsfs/triple_store/sparql/parse_filter.py index 5d8a2d9..8b6b976 100644 --- a/bsfs/triple_store/sparql/parse_filter.py +++ b/bsfs/triple_store/sparql/parse_filter.py @@ -5,19 +5,29 @@ A copy of the license is provided with the project. Author: Matthias Baumgartner, 2022 """ # imports +import operator import typing +# external imports +import rdflib + # bsfs imports from bsfs import schema as bsc from bsfs.namespace import ns from bsfs.query import ast from bsfs.utils import URI, errors +# inner-module imports +from .distance import DISTANCE_FU + # exports __all__: typing.Sequence[str] = ( 'Filter', ) + +## code ## + class _GenHopName(): """Generator that produces a new unique symbol name with each iteration.""" @@ -46,7 +56,8 @@ class Filter(): # Generator that produces unique symbol names. ngen: _GenHopName - def __init__(self, schema): + def __init__(self, graph, schema): + self.graph = graph self.schema = schema self.ngen = _GenHopName() @@ -84,6 +95,8 @@ class Filter(): return self._not(type_, node, head) if isinstance(node, ast.filter.Has): return self._has(type_, node, head) + if isinstance(node, ast.filter.Distance): + return self._distance(type_, node, head) if isinstance(node, ast.filter.Any): return self._any(type_, node, head) if isinstance(node, ast.filter.All): @@ -243,6 +256,32 @@ class Filter(): # combine return num_preds + ' . ' + count_bounds + def _distance(self, node_type: bsc.Vertex, node: ast.filter.Distance, head: str) -> str: + """ + """ + if not isinstance(node_type, bsc.Feature): + raise errors.BackendError(f'expected Feature, found {node_type}') + if len(node.reference) != node_type.dimension: + raise errors.ConsistencyError( + f'reference has dimension {len(node.reference)}, expected {node_type.dimension}') + # get distance metric + dist = DISTANCE_FU[node_type.distance] + # get operator + cmp = operator.lt if node.strict else operator.le + # get candidate values + candidates = { + f'"{cand}"^^<{node_type.uri}>' + for cand + in self.graph.objects() + if isinstance(cand, rdflib.Literal) + and cand.datatype == rdflib.URIRef(node_type.uri) + and cmp(dist(cand.value, node.reference), node.threshold) + } + # combine candidate values + values = ' '.join(candidates) if len(candidates) else f'"impossible value"^^<{ns.xsd.string}>' + # return sparql fragment + return f'VALUES {head} {{ {values} }}' + def _is(self, node_type: bsc.Vertex, node: ast.filter.Is, head: str) -> str: """ """ diff --git a/bsfs/triple_store/sparql/sparql.py b/bsfs/triple_store/sparql/sparql.py index 3877d1a..dfd9871 100644 --- a/bsfs/triple_store/sparql/sparql.py +++ b/bsfs/triple_store/sparql/sparql.py @@ -18,6 +18,7 @@ from bsfs.utils import errors, URI # inner-module imports from . import parse_filter from .. import base +from .distance import DISTANCE_FU # exports @@ -97,7 +98,7 @@ class SparqlStore(base.TripleStoreBase): self._transaction = _Transaction(self._graph) # NOTE: parsing bsfs.query.ast.filter.Has requires xsd:integer. self._schema = bsc.Schema(literals={bsc.ROOT_NUMBER.child(ns.xsd.integer)}) - self._filter_parser = parse_filter.Filter(self._schema) + self._filter_parser = parse_filter.Filter(self._graph, self._schema) # NOTE: mypy and pylint complain about the **kwargs not being listed (contrasting super) # However, not having it here is clearer since it's explicit that there are no arguments. @@ -123,6 +124,16 @@ class SparqlStore(base.TripleStoreBase): # check compatibility: No contradicting definitions if not self.schema.consistent_with(schema): raise errors.ConsistencyError(f'{schema} is inconsistent with {self.schema}') + # check distance functions of features + invalid = { + (cand.uri, cand.distance) + for cand + in schema.literals() + if isinstance(cand, bsc.Feature) and cand.distance not in DISTANCE_FU} + if len(invalid) > 0: + cand, dist = zip(*invalid) + raise ValueError( + f'unknown distance function {",".join(dist)} in feature {", ".join(cand)}') # commit the current transaction self.commit() |