# path: root/bsfs/triple_store/sparql/distance.py
# blob: 9b5808805c35a2dc93431a2ea77c3f8a4d59f686 (plain)

# standard imports
import typing

# external imports
import numpy as np

# bsfs imports
from bsfs.namespace import ns

# constants
# Small positive constant for guarding numerical operations (e.g. divisions).
EPS = 1e-9

# exports
# Only the lookup table is public; the individual distance functions are
# reached through it.
__all__: typing.Sequence[str] = (
    'DISTANCE_FU',
    )


## code ##

def euclid(fst, snd) -> float:
    """Return the Euclidean distance (l2 norm of the difference) of two vectors.

    Both inputs are converted to float64 arrays before they are subtracted.
    Subtracting in the input dtype would wrap around for unsigned integers
    (e.g. uint8 ``1 - 2`` yields 255) and is rejected by numpy for booleans.

    Args:
        fst: First vector; a real-valued array-like (list, tuple, ndarray).
        snd: Second vector; must be broadcastable against *fst*.

    Returns:
        The norm of ``fst - snd`` as a Python float (0.0 for empty inputs).
    """
    difference = np.asarray(fst, dtype=float) - np.asarray(snd, dtype=float)
    return float(np.linalg.norm(difference))

def cosine(fst, snd) -> float:
    """Return the cosine distance (one minus the cosine similarity) of two vectors.

    The result depends only on the direction of the vectors, not on their
    magnitude. Each vector is normalised by its own norm instead of adding a
    fixed offset to the product of the norms; such an offset dominates the
    denominator for small-magnitude vectors and distorts the result.

    Args:
        fst: First vector; a real-valued array-like (list, tuple, ndarray).
        snd: Second vector; same length as *fst*.

    Returns:
        A Python float in the range [0.0, 2.0]:
        0.0 if the inputs are element-wise identical (including two all-zero
        vectors), 1.0 if exactly one of them has zero norm, and
        ``1 - cos(angle)`` otherwise.
    """
    # Work in float64 so integer inputs cannot overflow in the dot product.
    fst = np.asarray(fst, dtype=float)
    snd = np.asarray(snd, dtype=float)
    # Identical inputs have distance zero by definition. This also covers the
    # case of two all-zero vectors, for which the angle is undefined.
    if (fst == snd).all():
        return 0.0
    nrm0 = np.linalg.norm(fst)
    nrm1 = np.linalg.norm(snd)
    # A zero vector has no direction; treat it as orthogonal to everything.
    if nrm0 == 0.0 or nrm1 == 0.0:
        return 1.0
    # Normalise before the dot product so the product of the norms can neither
    # underflow nor overflow.
    similarity = np.dot(fst / nrm0, snd / nrm1)
    # Rounding may push the similarity marginally outside [-1, 1]; clamp it so
    # the distance never becomes negative.
    return float(1.0 - np.clip(similarity, -1.0, 1.0))

def manhatten(fst, snd) -> float:
    """Return the Manhattan (cityblock) distance (l1 norm of the difference).

    Both inputs are converted to float64 arrays before they are subtracted.
    Subtracting in the input dtype would wrap around for unsigned integers
    (e.g. uint8 ``1 - 2`` yields 255) and is rejected by numpy for booleans.

    Args:
        fst: First vector; a real-valued array-like (list, tuple, ndarray).
        snd: Second vector; must be broadcastable against *fst*.

    Returns:
        The sum of the absolute element-wise differences as a Python float
        (0.0 for empty inputs).
    """
    difference = np.asarray(fst, dtype=float) - np.asarray(snd, dtype=float)
    return float(np.abs(difference).sum())

# Known distance functions.
# Lookup table from a distance identifier in the ``ns.bsfs`` namespace to the
# callable implementing it. Every callable takes two vectors (fst, snd) and
# returns the distance between them as a float.
DISTANCE_FU = {
    ns.bsfs.euclidean: euclid,
    ns.bsfs.cosine: cosine,
    ns.bsfs.manhatten: manhatten,
}

## EOF ##