1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
/// @brief A static k-d tree over points of type std::array<T, D> that answers
/// nearest-neighbour queries under Euclidean distance.
///
/// The tree is built once from the constructor's point list and is not
/// modified afterwards. Each level splits on one coordinate, cycling through
/// dimensions 0, 1, ..., D-1, at the median of the current range. This keeps
/// the tree balanced (height O(log n)).
///
/// Coordinates are converted to double before any arithmetic. Unsigned or
/// narrow integer coordinate types therefore cannot wrap or overflow while
/// distances are computed.
///
/// @tparam T coordinate type; must be ordered by `<` and convertible to double.
/// @tparam D number of dimensions; must be greater than zero.
template <typename T, std::size_t D>
class KdTree {
  static_assert(D > 0, "KdTree requires at least one dimension");

 public:
  using Point = std::array<T, D>;

 private:
  /// @brief A tree node. Each node owns its children, so destroying the
  /// root releases the whole tree.
  struct Node {
    Point point{};
    std::unique_ptr<Node> left;
    std::unique_ptr<Node> right;
  };

  /// @brief Recursively builds a subtree from pts[begin, end).
  ///
  /// Partitions the range around its median on coordinate `dim` with
  /// std::nth_element (reordering `pts` in place). It stores that median in
  /// a new node and recurses on both halves using the next dimension.
  /// @param pts   working copy of the input points; reordered in place.
  /// @param begin first index of the range (inclusive).
  /// @param end   one past the last index of the range.
  /// @param dim   coordinate used to split this level.
  /// @return owner of the subtree root, or nullptr for an empty range.
  static std::unique_ptr<Node> build_tree(std::vector<Point>& pts,
                                          std::size_t begin, std::size_t end,
                                          std::size_t dim) {
    if (begin >= end) return nullptr;
    const std::size_t mid = begin + (end - begin) / 2;
    const auto first = pts.begin();
    std::nth_element(first + begin, first + mid, first + end,
                     [dim](const Point& a, const Point& b) {
                       return a[dim] < b[dim];
                     });
    const std::size_t next_dim = (dim + 1) % D;
    auto node = std::make_unique<Node>();
    // Copy the median first. The recursive calls only reorder [begin, mid)
    // and [mid + 1, end), never index mid itself.
    node->point = pts[mid];
    node->left = build_tree(pts, begin, mid, next_dim);
    node->right = build_tree(pts, mid + 1, end, next_dim);
    return node;
  }

  /// @brief Squared Euclidean distance between two points.
  ///
  /// Each coordinate is converted to double before subtracting. This avoids
  /// unsigned wrap-around and integer overflow for non-floating T.
  /// @return sum over all dimensions of (p1[i] - p2[i])^2.
  static double sqr_dis(const Point& p1, const Point& p2) {
    double ans = 0.0;
    for (std::size_t i = 0; i < D; ++i) {
      const double diff =
          static_cast<double>(p1[i]) - static_cast<double>(p2[i]);
      ans += diff * diff;
    }
    return ans;
  }

  /// @brief Depth-first nearest-neighbour search below `node`.
  ///
  /// It updates `best`/`min_dis` whenever a strictly closer point is found.
  /// It descends first into the child on the same side of the splitting
  /// plane as `tar`. It visits the other child only if the plane is no
  /// farther away than the current best distance.
  /// @param node    subtree root; nullptr is a no-op.
  /// @param tar     query point.
  /// @param dim     splitting coordinate at this level.
  /// @param best    in/out: closest point found so far.
  /// @param min_dis in/out: squared distance from `tar` to `best`.
  static void nearest(const Node* node, const Point& tar, std::size_t dim,
                      Point& best, double& min_dis) {
    if (!node) return;
    const double d = sqr_dis(node->point, tar);
    if (d < min_dis) {
      min_dis = d;
      best = node->point;
    }
    // Signed offset computed in double. An unsigned subtraction would wrap
    // and send the search down the wrong side, then prune the correct one.
    const double dx = static_cast<double>(node->point[dim]) -
                      static_cast<double>(tar[dim]);
    const std::size_t next_dim = (dim + 1) % D;
    const Node* near_side = dx > 0 ? node->left.get() : node->right.get();
    const Node* far_side = dx > 0 ? node->right.get() : node->left.get();
    nearest(near_side, tar, next_dim, best, min_dis);
    // The far side can only hold a closer point if the splitting plane is
    // within the current best distance.
    if (dx * dx > min_dis) return;
    nearest(far_side, tar, next_dim, best, min_dis);
  }

 public:
  /// Copy control: the tree is neither copyable nor movable.
  KdTree(const KdTree&) = delete;
  KdTree& operator=(const KdTree&) = delete;
  KdTree(KdTree&&) = delete;
  KdTree& operator=(KdTree&&) = delete;
  /// Nodes are released through their owning unique_ptrs. Recursion depth
  /// is bounded by the (balanced) tree height.
  ~KdTree() = default;

  /// @brief Builds the tree from a list of points.
  /// @param pts points to index; taken by value because construction
  ///            reorders them. An empty list yields an empty tree.
  KdTree(std::vector<Point> pts) : root{build_tree(pts, 0, pts.size(), 0)} {}

  /// @brief Finds the stored point closest to `tar`.
  /// @param tar query point.
  /// @return the nearest stored point and its Euclidean distance to `tar`.
  ///         For an empty tree, returns a value-initialized point and
  ///         +infinity, so callers cannot mistake it for an exact match.
  std::pair<Point, double> Nearest(const Point& tar) const {
    if (!root) return {Point{}, std::numeric_limits<double>::infinity()};
    // Seed with the root so the returned point always belongs to the tree.
    Point best = root->point;
    double min_dis = sqr_dis(best, tar);
    nearest(root.get(), tar, 0, best, min_dis);
    return {best, std::sqrt(min_dis)};
  }

 private:
  /// Root of the tree; nullptr when the tree is empty.
  std::unique_ptr<Node> root;
};
/// Test code
int main() {
std::vector<std::array<double, 2>> pts{
{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};
KdTree<double, 2> kdt{pts};
auto [pt, dis] = kdt.Nearest({9, 2});
std::cout << "nearest: (" << pt[0] << ", " << pt[1] << ") in " << dis
<< std::endl;
}
|