pfe/src/kd_tree.cc

84 lines
2.3 KiB
C++

#include "kd_tree.h"
#include <algorithm>
#include <limits>
#include <stdexcept>
/// Constructs the point whose three coordinates are all zero.
KdTree::Point::Point()
    : Point(0.0, 0.0, 0.0) {}

/// Constructs a point from the coordinates `x`, `y` and `z`.
KdTree::Point::Point(double x, double y, double z)
    : x(x), y(y), z(z) {}

/// Constructs a point from the first three doubles stored at `p`.
/// `p` is indexed unconditionally, so it must address at least three values.
KdTree::Point::Point(double *p)
    : Point(p[0], p[1], p[2]) {}
/// Returns the coordinate selected by `i`: 0 yields x, 1 yields y, and every
/// other index yields z (no range check is performed).
double KdTree::Point::operator[](std::size_t i) const {
    switch (i) {
        case 0:
            return x;
        case 1:
            return y;
        default:
            return z;
    }
}
/// Returns the squared Euclidean distance between this point and `other`.
/// No square root is taken, so the result is only meaningful for comparing
/// distances or where the squared value itself is wanted.
double KdTree::Point::dist2(Point const &other) const {
    double const deltaX = x - other.x;
    double const deltaY = y - other.y;
    double const deltaZ = z - other.z;
    return deltaX * deltaX + deltaY * deltaY + deltaZ * deltaZ;
}
/// Writes `point` to `os` in the form "[x, y, z]" and returns `os` so that
/// insertions can be chained.
ostream& operator<<(ostream &os, KdTree::Point const &point) {
    return os << "[" << point.x << ", " << point.y << ", " << point.z << "]";
}
/// Creates a node that stores a copy of `position` together with its identifier `id`.
KdTree::Node::Node(Point const &position, vtkIdType id)
    : position(position),
      id(id) {}
/////////////////////////////////////////////////////////////////////
void KdTree::fill(std::vector<Tuple> &points) {
nodes.reserve(points.size());
for(std::size_t i = 0; i < points.size(); ++i)
nodes.push_back({points[i].first, points[i].second});
root = fillRec(0, points.size(), 0);
}
/// Recursively arranges `nodes[begin, end)` into a subtree and returns its root.
///
/// The element at the middle index is made the median along `axis` via
/// std::nth_element; elements before it do not compare greater and elements
/// after it do not compare less on that axis. The two halves are then built
/// the same way with the next axis (cycling 0 -> 1 -> 2 -> 0).
///
/// @param begin index of the first node of the range
/// @param end   index one past the last node of the range
/// @param axis  coordinate index (0, 1 or 2) used to split this level
/// @return pointer into `nodes` for the subtree root, or nullptr for an empty range
KdTree::Node *KdTree::fillRec(std::size_t begin, std::size_t end, int axis) {
    if (end <= begin) return nullptr;  // empty range: no subtree

    std::size_t const median = begin + (end - begin) / 2;
    auto const first = nodes.begin();

    // The comparator must accept const operands and must not modify them, so it
    // takes its arguments by reference-to-const. `axis` is captured by value.
    std::nth_element(first + begin, first + median, first + end,
                     [axis](Node const &lhs, Node const &rhs) {
                         return lhs.position[axis] < rhs.position[axis];
                     });

    int const childAxis = (axis + 1) % 3;

    // The recursive calls only permute the sub-ranges on either side of
    // `median`, so this element stays where it is while its children are built.
    Node &subtreeRoot = nodes[median];
    subtreeRoot.leftChild = fillRec(begin, median, childAxis);
    subtreeRoot.rightChild = fillRec(median + 1, end, childAxis);
    return &subtreeRoot;
}
/// Returns the position of the stored node closest to `point`.
///
/// The search state (`bestDist`, `bestNode`) is reset, the tree is traversed
/// from `root`, and the position of the best node found is returned by value.
///
/// @param point location to search around
/// @return position of the node with the smallest squared distance to `point`
/// @throws std::runtime_error if no node was selected. This happens when the
///         tree is empty, or when no computed distance compares below the
///         initial bound (for example NaN coordinates).
KdTree::Point KdTree::query(Point const &point) {
    bestDist = std::numeric_limits<double>::max();
    bestNode = nullptr;
    queryRec(root, point, 0);

    // bestNode is still null if the traversal never accepted a candidate;
    // dereferencing it would be undefined behaviour.
    if (bestNode == nullptr)
        throw std::runtime_error("KdTree::query: no nearest node found (empty tree or non-finite distance)");

    return bestNode->position;
}
/// Convenience overload: reads three coordinates from `position` and forwards
/// to the Point-based query. `position` must address at least three doubles.
KdTree::Point KdTree::query(double *position) {
    Point const target(position[0], position[1], position[2]);
    return query(target);
}
/// Recursive nearest-neighbour step for the subtree rooted at `node`.
///
/// Updates the members `bestDist` / `bestNode` whenever `node` is strictly
/// closer to `point` than the best candidate so far. It then descends first
/// into the child on the same side of the splitting plane as `point`, and
/// visits the other child only if the plane is closer than the current best.
///
/// @param node  subtree root; a null pointer ends the recursion
/// @param point location being searched for
/// @param axis  coordinate index (0, 1 or 2) that splits this level
void KdTree::queryRec(Node *node, Point const &point, int axis) {
    if (!node) return;

    double const candidateDist = point.dist2(node->position);
    if (candidateDist < bestDist) {
        bestDist = candidateDist;
        bestNode = node;
    }

    // Signed offset from the query point to this node's splitting plane.
    double const planeOffset = node->position[axis] - point[axis];
    int const childAxis = (axis + 1) % 3;

    // A positive offset means the query lies on the left side of the plane.
    Node *nearChild = node->rightChild;
    Node *farChild = node->leftChild;
    if (planeOffset > 0) {
        nearChild = node->leftChild;
        farChild = node->rightChild;
    }

    queryRec(nearChild, point, childAxis);

    // Nothing beyond the plane can beat the current best once the squared
    // offset to the plane is at least the best squared distance.
    if (planeOffset * planeOffset >= bestDist) return;
    queryRec(farChild, point, childAxis);
}