kdtree trop bien omg

This commit is contained in:
CookieKastanie 2022-02-28 23:39:03 +01:00
parent 1a240beee3
commit 73b432d0c8
6 changed files with 222 additions and 2 deletions

View File

@ -47,7 +47,12 @@ target_sources(pfe PRIVATE
src/dihedral_angles_filter.cc src/dihedral_angles_filter.cc
src/dihedral_angles_filter.h src/dihedral_angles_filter.h
src/external_points_filter.cc src/external_points_filter.cc
src/external_points_filter.h) src/external_points_filter.h
src/kd_tree.cc
src/kd_tree.h
src/mesh_fit_filter.cc
src/mesh_fit_filter.h)
target_link_libraries(pfe PRIVATE ${VTK_COMPONENTS}) target_link_libraries(pfe PRIVATE ${VTK_COMPONENTS})

83
src/kd_tree.cc Normal file
View File

@ -0,0 +1,83 @@
#include "kd_tree.h"
#include <algorithm>
KdTree::Point::Point(): x{0}, y{0}, z{0} {}
KdTree::Point::Point(double x, double y, double z): x{x}, y{y}, z{z} {}
KdTree::Point::Point(double *p): x{p[0]}, y{p[1]}, z{p[2]} {}
double KdTree::Point::operator[](std::size_t i) const {
if(i == 0) return x;
if(i == 1) return y;
return z;
}
double KdTree::Point::dist2(Point const &other) const {
double x_ = x - other.x;
double y_ = y - other.y;
double z_ = z - other.z;
return x_ * x_ + y_ * y_ + z_ * z_;
}
ostream& operator<<(ostream& os, KdTree::Point const & point) {
os << "[" << point.x << ", " << point.y << ", " << point.z << "]";
return os;
}
/////////////////////////////////////////////////////////////////////
void KdTree::fill(std::vector<Tuple> &points) {
nodes.resize(points.size());
for(std::size_t i = 0; i < points.size(); ++i) {
nodes[i].position = points[i].first;
nodes[i].index = points[i].second;
}
root = fillRec(0, points.size(), 0);
}
KdTree::Node *KdTree::fillRec(std::size_t begin, std::size_t end, int axis) {
if(end <= begin) return nullptr;
size_t n = begin + (end - begin) / 2;
auto i = nodes.begin();
std::nth_element(i + begin, i + n, i + end, [&](Node &p1, Node &p2) {
return p1.position[axis] < p2.position[axis];
});
axis = (axis + 1) % 3;
nodes[n].leftChild = fillRec(begin, n, axis);
nodes[n].rightChild = fillRec(n + 1, end, axis);
return &nodes[n];
}
KdTree::Point KdTree::query(Point const &point) {
bestDist = 99999.;
bestNode = nullptr;
queryRec(root, point, 0);
return bestNode->position;
}
KdTree::Point KdTree::query(double *position) {
return query({position[0], position[1], position[2]});
}
void KdTree::queryRec(Node *node, Point const &point, int axis) {
if(node == nullptr) return;
double d = point.dist2(node->position);
if(d < bestDist) {
bestDist = d;
bestNode = node;
}
double dx = node->position[axis] - point[axis];
axis = (axis + 1) % 3;
queryRec(dx > 0 ? node->leftChild : node->rightChild, point, axis);
if(dx * dx >= bestDist) return;
queryRec(dx > 0 ? node->rightChild : node->leftChild, point, axis);
}

50
src/kd_tree.h Normal file
View File

@ -0,0 +1,50 @@
#ifndef KD_TREE_H
#define KD_TREE_H
#include <vector>
#include <utility>
#include <vtkIdList.h>
// https://rosettacode.org/wiki/K-d_tree#C.2B.2B
class KdTree {
public:
KdTree() = default;
struct Point {
Point();
Point(double x, double y, double z);
Point(double *position);
double x, y, z;
double operator[] (std::size_t i) const;
double dist2(Point const &other) const;
friend ostream& operator<<(ostream& os, Point const & point);
};
using Tuple = std::pair<Point, vtkIdType>;
void fill(std::vector<Tuple> &points);
Point query(Point const &point);
Point query(double *point);
private:
struct Node {
Point position;
vtkIdType index;
Node *leftChild;
Node *rightChild;
};
std::vector<Node> nodes;
Node *root;
double bestDist;
Node *bestNode;
Node *fillRec(std::size_t begin, std::size_t end, int axis);
void queryRec(Node *node, Point const &point, int axis);
};
#endif

View File

@ -3,6 +3,8 @@
#include "dihedral_angles_filter.h" #include "dihedral_angles_filter.h"
#include "external_points_filter.h" #include "external_points_filter.h"
#include "mesh_fit_filter.h"
#include <vtkCellData.h> #include <vtkCellData.h>
#include <vtkUnstructuredGrid.h> #include <vtkUnstructuredGrid.h>
#include <vtkUnstructuredGridReader.h> #include <vtkUnstructuredGridReader.h>
@ -89,12 +91,18 @@ int main(int argc, char **argv) {
vtkNew<ExternalPointsFilter> externalPointsFilter; vtkNew<ExternalPointsFilter> externalPointsFilter;
externalPointsFilter->SetInputConnection(dihedralAnglesFilter->GetOutputPort()); externalPointsFilter->SetInputConnection(dihedralAnglesFilter->GetOutputPort());
vtkNew<MeshFitFilter> meshFitFilter;
meshFitFilter->SetInputConnection(externalPointsFilter->GetOutputPort());
vtkNew<vtkUnstructuredGridWriter> writer; vtkNew<vtkUnstructuredGridWriter> writer;
writer->SetInputConnection(externalPointsFilter->GetOutputPort()); writer->SetInputConnection(meshFitFilter->GetOutputPort());
writer->SetFileTypeToASCII(); writer->SetFileTypeToASCII();
writer->SetFileName("out.vtk"); writer->SetFileName("out.vtk");
writer->Write(); writer->Write();
#ifdef USE_VIEWER #ifdef USE_VIEWER
/* Volume rendering properties */ /* Volume rendering properties */
vtkNew<vtkOpenGLProjectedTetrahedraMapper> volumeMapper; vtkNew<vtkOpenGLProjectedTetrahedraMapper> volumeMapper;

55
src/mesh_fit_filter.cc Normal file
View File

@ -0,0 +1,55 @@
#include "mesh_fit_filter.h"
#include <vtkUnstructuredGrid.h>
#include <vtkPointData.h>
#include <vtkCellData.h>
#include <vtkDoubleArray.h>
#include <vtkCellIterator.h>
#include "kd_tree.h"
vtkStandardNewMacro(MeshFitFilter);
vtkTypeBool MeshFitFilter::RequestData(
vtkInformation *request,
vtkInformationVector **inputVector,
vtkInformationVector *outputVector) {
(void) request;
vtkUnstructuredGrid* input =
vtkUnstructuredGrid::GetData(inputVector[0]);
vtkUnstructuredGrid* output =
vtkUnstructuredGrid::GetData(outputVector);
output->CopyStructure(input);
output->GetPointData()->PassData(input->GetPointData());
std::vector<KdTree::Tuple> points;
auto it = input->NewCellIterator();
for(it->InitTraversal(); !it->IsDoneWithTraversal(); it->GoToNextCell()) {
if(it->GetCellType() != VTK_TETRA) continue;
vtkIdList *idList = it->GetPointIds();
for(int i = 0; i < 4; ++i) {
vtkIdType id = idList->GetId(i);
double point[3];
input->GetPoint(id, point);
points.push_back({point, id});
//std::cout << "[" << point[0] << ", " << point[1] << ", " << point[2] << "] (" << id << ")\n";
}
}
KdTree kdTree;
kdTree.fill(points);
std::cout << "[0.3, 1.5, -0.1] => " << kdTree.query({0.3, 1.5, -0.1}) << std::endl;
std::cout << "[5.1, 1.1, 0.5] => " << kdTree.query({5.1, 1.1, 0.5}) << std::endl;
std::cout << "[0, 0, 1] => " << kdTree.query({0, 0, 1}) << std::endl;
return true;
}

19
src/mesh_fit_filter.h Normal file
View File

@ -0,0 +1,19 @@
#ifndef MESH_FIT_FILTER_H
#define MESH_FIT_FILTER_H
#include <vtkUnstructuredGridAlgorithm.h>
#include <vtkIdList.h>
class MeshFitFilter : public vtkUnstructuredGridAlgorithm {
public:
static MeshFitFilter *New();
vtkTypeMacro(MeshFitFilter, vtkUnstructuredGridAlgorithm);
vtkTypeBool RequestData(vtkInformation *request,
vtkInformationVector **inputVector,
vtkInformationVector *outputVector) override;
protected:
~MeshFitFilter() override = default;
};
#endif