From 73b432d0c83d352746a33e745f5734c8f3481b97 Mon Sep 17 00:00:00 2001 From: CookieKastanie Date: Mon, 28 Feb 2022 23:39:03 +0100 Subject: [PATCH] kdtree trop bien omg --- CMakeLists.txt | 7 +++- src/kd_tree.cc | 83 ++++++++++++++++++++++++++++++++++++++++++ src/kd_tree.h | 50 +++++++++++++++++++++++++ src/main.cc | 10 ++++- src/mesh_fit_filter.cc | 55 ++++++++++++++++++++++++++++ src/mesh_fit_filter.h | 19 ++++++++++ 6 files changed, 222 insertions(+), 2 deletions(-) create mode 100644 src/kd_tree.cc create mode 100644 src/kd_tree.h create mode 100644 src/mesh_fit_filter.cc create mode 100644 src/mesh_fit_filter.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9ae03ef..6a7411d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,7 +47,12 @@ target_sources(pfe PRIVATE src/dihedral_angles_filter.cc src/dihedral_angles_filter.h src/external_points_filter.cc - src/external_points_filter.h) + src/external_points_filter.h + + src/kd_tree.cc + src/kd_tree.h + src/mesh_fit_filter.cc + src/mesh_fit_filter.h) target_link_libraries(pfe PRIVATE ${VTK_COMPONENTS}) diff --git a/src/kd_tree.cc b/src/kd_tree.cc new file mode 100644 index 0000000..c68e94b --- /dev/null +++ b/src/kd_tree.cc @@ -0,0 +1,83 @@ +#include "kd_tree.h" +#include + +KdTree::Point::Point(): x{0}, y{0}, z{0} {} +KdTree::Point::Point(double x, double y, double z): x{x}, y{y}, z{z} {} +KdTree::Point::Point(double *p): x{p[0]}, y{p[1]}, z{p[2]} {} + +double KdTree::Point::operator[](std::size_t i) const { + if(i == 0) return x; + if(i == 1) return y; + return z; +} + +double KdTree::Point::dist2(Point const &other) const { + double x_ = x - other.x; + double y_ = y - other.y; + double z_ = z - other.z; + + return x_ * x_ + y_ * y_ + z_ * z_; +} + +ostream& operator<<(ostream& os, KdTree::Point const & point) { + os << "[" << point.x << ", " << point.y << ", " << point.z << "]"; + return os; +} + +///////////////////////////////////////////////////////////////////// + +void KdTree::fill(std::vector &points) { + nodes.resize(points.size()); + + for(std::size_t i = 0; i < points.size(); ++i) { + nodes[i].position = points[i].first; + nodes[i].index = points[i].second; + } + + root = fillRec(0, points.size(), 0); +} + +KdTree::Node *KdTree::fillRec(std::size_t begin, std::size_t end, int axis) { + if(end <= begin) return nullptr; + + size_t n = begin + (end - begin) / 2; + auto i = nodes.begin(); + std::nth_element(i + begin, i + n, i + end, [&](Node &p1, Node &p2) { + return p1.position[axis] < p2.position[axis]; + }); + + axis = (axis + 1) % 3; + nodes[n].leftChild = fillRec(begin, n, axis); + nodes[n].rightChild = fillRec(n + 1, end, axis); + + return &nodes[n]; +} + +KdTree::Point KdTree::query(Point const &point) { + bestDist = 99999.; + bestNode = nullptr; + queryRec(root, point, 0); + return bestNode->position; +} + +KdTree::Point KdTree::query(double *position) { + return query({position[0], position[1], position[2]}); +} + +void KdTree::queryRec(Node *node, Point const &point, int axis) { + if(node == nullptr) return; + + double d = point.dist2(node->position); + if(d < bestDist) { + bestDist = d; + bestNode = node; + } + + double dx = node->position[axis] - point[axis]; + + axis = (axis + 1) % 3; + + queryRec(dx > 0 ? node->leftChild : node->rightChild, point, axis); + if(dx * dx >= bestDist) return; + queryRec(dx > 0 ? node->rightChild : node->leftChild, point, axis); +} diff --git a/src/kd_tree.h b/src/kd_tree.h new file mode 100644 index 0000000..6aebaa6 --- /dev/null +++ b/src/kd_tree.h @@ -0,0 +1,50 @@ +#ifndef KD_TREE_H +#define KD_TREE_H + +#include +#include + +#include + +// https://rosettacode.org/wiki/K-d_tree#C.2B.2B + +class KdTree { +public: + KdTree() = default; + + struct Point { + Point(); + Point(double x, double y, double z); + Point(double *position); + double x, y, z; + double operator[] (std::size_t i) const; + double dist2(Point const &other) const; + + friend ostream& operator<<(ostream& os, Point const & point); + }; + + using Tuple = std::pair; + void fill(std::vector &points); + Point query(Point const &point); + Point query(double *point); + +private: + struct Node { + Point position; + vtkIdType index; + + Node *leftChild; + Node *rightChild; + }; + + std::vector nodes; + Node *root; + + double bestDist; + Node *bestNode; + + Node *fillRec(std::size_t begin, std::size_t end, int axis); + void queryRec(Node *node, Point const &point, int axis); +}; + +#endif \ No newline at end of file diff --git a/src/main.cc b/src/main.cc index b2f7f35..8efe320 100644 --- a/src/main.cc +++ b/src/main.cc @@ -3,6 +3,8 @@ #include "dihedral_angles_filter.h" #include "external_points_filter.h" +#include "mesh_fit_filter.h" + #include #include #include @@ -89,12 +91,18 @@ int main(int argc, char **argv) { vtkNew externalPointsFilter; externalPointsFilter->SetInputConnection(dihedralAnglesFilter->GetOutputPort()); + + vtkNew meshFitFilter; + meshFitFilter->SetInputConnection(externalPointsFilter->GetOutputPort()); + + vtkNew writer; - writer->SetInputConnection(externalPointsFilter->GetOutputPort()); + writer->SetInputConnection(meshFitFilter->GetOutputPort()); writer->SetFileTypeToASCII(); writer->SetFileName("out.vtk"); writer->Write(); + #ifdef USE_VIEWER /* Volume rendering properties */ vtkNew volumeMapper; diff --git a/src/mesh_fit_filter.cc b/src/mesh_fit_filter.cc new file mode 100644 index 0000000..5cd6121 --- /dev/null +++ b/src/mesh_fit_filter.cc @@ -0,0 +1,55 @@ +#include "mesh_fit_filter.h" + +#include +#include +#include +#include +#include + +#include "kd_tree.h" + +vtkStandardNewMacro(MeshFitFilter); + +vtkTypeBool MeshFitFilter::RequestData( + vtkInformation *request, + vtkInformationVector **inputVector, + vtkInformationVector *outputVector) { + (void) request; + + vtkUnstructuredGrid* input = + vtkUnstructuredGrid::GetData(inputVector[0]); + vtkUnstructuredGrid* output = + vtkUnstructuredGrid::GetData(outputVector); + output->CopyStructure(input); + output->GetPointData()->PassData(input->GetPointData()); + + + std::vector points; + + auto it = input->NewCellIterator(); + for(it->InitTraversal(); !it->IsDoneWithTraversal(); it->GoToNextCell()) { + if(it->GetCellType() != VTK_TETRA) continue; + + vtkIdList *idList = it->GetPointIds(); + + for(int i = 0; i < 4; ++i) { + vtkIdType id = idList->GetId(i); + + double point[3]; + input->GetPoint(id, point); + + points.push_back({point, id}); + + //std::cout << "[" << point[0] << ", " << point[1] << ", " << point[2] << "] (" << id << ")\n"; + } + } + + KdTree kdTree; + kdTree.fill(points); + + std::cout << "[0.3, 1.5, -0.1] => " << kdTree.query({0.3, 1.5, -0.1}) << std::endl; + std::cout << "[5.1, 1.1, 0.5] => " << kdTree.query({5.1, 1.1, 0.5}) << std::endl; + std::cout << "[0, 0, 1] => " << kdTree.query({0, 0, 1}) << std::endl; + + return true; +} diff --git a/src/mesh_fit_filter.h b/src/mesh_fit_filter.h new file mode 100644 index 0000000..15e9d6d --- /dev/null +++ b/src/mesh_fit_filter.h @@ -0,0 +1,19 @@ +#ifndef MESH_FIT_FILTER_H +#define MESH_FIT_FILTER_H + +#include +#include + +class MeshFitFilter : public vtkUnstructuredGridAlgorithm { +public: + static MeshFitFilter *New(); + vtkTypeMacro(MeshFitFilter, vtkUnstructuredGridAlgorithm); + vtkTypeBool RequestData(vtkInformation *request, + vtkInformationVector **inputVector, + vtkInformationVector *outputVector) override; + +protected: + ~MeshFitFilter() override = default; +}; + +#endif