K-d tree

# One node of the k-d tree.
struct Kd_node {
    d,          # the point (array of coordinates) stored at this node
    split,      # index of the coordinate this node splits on
    left,       # subtree built from the points sorted before the pivot (nil if none)
    right,      # subtree built from the points sorted after the pivot (nil if none)
}
 
# Axis-aligned box described by two arrays that are indexed by coordinate.
struct Orthotope {
    min,        # per-coordinate lower bounds
    max,        # per-coordinate upper bounds
}
 
# k-d tree built over a list of points.
#
# Constructor parameters:
#   n      - array of points (each an array of coordinates); during init it is
#            replaced by the root Kd_node of the built tree (nil for no points)
#   bounds - stored as given; the search passes it along as the root hyperrectangle
class Kd_tree(n, bounds) {

    # Replaces the raw point list held in `n` with the root of the built tree.
    method init {
        n = self.nk2(0, n);
    }

    # Recursively builds the subtree for the points in `e`, splitting on
    # coordinate index `split`. Returns a Kd_node, or nil when `e` is empty.
    method nk2(split, e) {
        return(nil) if (e.len <= 0);
        var exset = e.sort_by { _[split] }
        var m = (exset.len // 2);
        var d = exset[m];

        # Move the pivot to the last point sharing the median's split value, so
        # every point with an equal coordinate ends up in the left subtree.
        while ((m+1 < exset.len) && (exset[m+1][split] == d[split])) {
            ++m;
        }
        # Re-read the pivot: m may have advanced, and exset.first(m) now
        # contains the point that was originally picked.
        d = exset[m];

        var s2 = ((split + 1) % d.len);     # cycle coordinates
        Kd_node(d: d, split: split,
                left:  self.nk2(s2, exset.first(m)),
                # everything after the pivot: exset.len - m - 1 points
                right: self.nk2(s2, exset.last(exset.len - m - 1)));
    }
}
 
# Result record of a nearest-neighbour search.
struct T3 {
    nearest,                # the closest point found so far
    dist_sqd = Inf,         # squared distance to `nearest`; defaults to Inf
    nodes_visited = 0,      # number of tree nodes visited; defaults to 0
}
 
# Finds the point stored in the k-d tree `t` that is closest to point `p`.
#
# Parameters:
#   k - number of coordinates per point; only used to build the placeholder
#       `nearest` value ([0]*k) returned for an empty subtree
#   t - the tree; its root is read from t.n and its bounding box from t.bounds
#   p - the query point (array of coordinates)
#
# Returns a T3 holding the nearest point, the squared distance to it, and the
# number of nodes visited.
func find_nearest(k, t, p) {
    func nn(kd, target, hr, max_dist_sqd) {
        # Empty subtree: no candidate (T3 defaults: dist_sqd = Inf, 0 visits).
        kd || return T3(nearest: [0]*k);

        var nodes_visited = 1;
        var s = kd.split;
        var pivot = kd.d;

        # Cut the current hyperrectangle at the pivot. Only the array that is
        # modified on each side is copied, so the caller's box (ultimately
        # t.bounds) is never overwritten by a query.
        var left_max = hr.max.map { _ };
        var right_min = hr.min.map { _ };
        left_max[s] = pivot[s];
        right_min[s] = pivot[s];
        var left_hr = Orthotope(hr.min, left_max);
        var right_hr = Orthotope(right_min, hr.max);

        # Descend first into the side of the splitting plane that holds the target.
        var nearer_kd;
        var further_kd;
        var nearer_hr;
        var further_hr;
        if (target[s] <= pivot[s]) {
            (nearer_kd, nearer_hr) = (kd.left, left_hr);
            (further_kd, further_hr) = (kd.right, right_hr);
        }
        else {
            (nearer_kd, nearer_hr) = (kd.right, right_hr);
            (further_kd, further_hr) = (kd.left, left_hr);
        }

        var n1 = nn(nearer_kd, target, nearer_hr, max_dist_sqd);
        var nearest = n1.nearest;
        var dist_sqd = n1.dist_sqd;
        nodes_visited += n1.nodes_visited;

        if (dist_sqd < max_dist_sqd) {
            max_dist_sqd = dist_sqd;
        }

        # Squared distance from the target to the splitting plane: if it already
        # exceeds the best distance, neither the pivot nor the further side can win.
        var d = (pivot[s] - target[s] -> sqr);
        if (d > max_dist_sqd) {
            return T3(nearest: nearest, dist_sqd: dist_sqd, nodes_visited: nodes_visited);
        }

        # Squared distance from the target to the pivot itself.
        d = (pivot ~Z- target »sqr»() «+»);
        if (d < dist_sqd) {
            nearest = pivot;
            dist_sqd = d;
            max_dist_sqd = dist_sqd;
        }

        # The further side may still contain a closer point.
        var n2 = nn(further_kd, target, further_hr, max_dist_sqd);
        nodes_visited += n2.nodes_visited;
        if (n2.dist_sqd < dist_sqd) {
            nearest = n2.nearest;
            dist_sqd = n2.dist_sqd;
        }

        T3(nearest: nearest, dist_sqd: dist_sqd, nodes_visited: nodes_visited);
    }

    return nn(t.n, p, t.bounds, Inf);
}
 
# Prints a heading and the query point `p`, runs find_nearest(k, kd, p), then
# prints the nearest neighbour, the distance to it (square root of the squared
# distance returned by the search) and the number of nodes visited.
func show_nearest(k, heading, kd, p) {
    print <<-"END"
        #{heading}:
        Point:            [#{p.join(',')}]
        END
    var n = find_nearest(k, kd, p);
    print <<-"END"
        Nearest neighbor: [#{n.nearest.join(',')}]
        Distance:         #{sqrt(n.dist_sqd)}
        Nodes visited:    #{n.nodes_visited()}
 
        END
}
 
# Builds one random point: an array of k values, each produced by 1.rand.
func random_point(k) {
    var coords = k.of { 1.rand };
    return coords;
}
# Builds an array of n random points, each created by random_point(k).
func random_points(k, n) {
    var points = n.of { random_point(k) };
    return points;
}
 
# Demo 1: the six 2-D points of the Wikipedia k-d tree example, queried at [9, 2].
var kd1 = Kd_tree([[2, 3],[5, 4],[9, 6],[4, 7],[8, 1],[7, 2]],
              Orthotope(min: [0, 0], max: [10, 10]));
show_nearest(2, "Wikipedia example data", kd1, [9, 2]);

# Demo 2: N random 3-D points; the time to generate them and build the tree
# is measured with Time.micro.
var N = 1000
var t0 = Time.micro
var kd2 = Kd_tree(random_points(3, N), Orthotope(min: [0,0,0], max: [1,1,1]))

var t1 = Time.micro
# The points are 3-D, so the dimension passed to show_nearest is 3.
show_nearest(3,
    "k-d tree with #{N} random 3D points (generation time: #{t1 - t0}s)",
     kd2, random_point(3))

Output:

Wikipedia example data:
Point:            [9,2]
Nearest neighbor: [8,1]
Distance:         1.41421356237309504880168872420969807856967187537695
Nodes visited:    3

k-d tree with 1000 random 3D points (generation time: 0.25858s):
Point:            [0.28961099389483532140941908632585463245703951185091,0.53383735570157521548169561846919925542828568155241,0.06543742264584889854604603500785549847809104064543]
Nearest neighbor: [0.28344130897803248636407083474733485442276225490765,0.54224255944924304382130637584184680424540291089424,0.12494797195998536910186918458377312162395340543897]
Distance:         0.06041703353924870133411659885818573345011076349573
Nodes visited:    30

Last updated