/* Created On: 22 Aug 2021 */
11 #include <Eigen/Geometry>
18 #include <flann/flann.hpp>
24 /********************************************************
26 *******************************************************/
28 typedef typename
Eigen::MatrixXd::Index Index
;
30 typedef Eigen::Matrix
<Index
, Eigen::Dynamic
, 1> Vectori
;
31 typedef Eigen::Matrix
<Index
, 2, 1> Vector2i
;
32 typedef Eigen::Matrix
<Index
, 3, 1> Vector3i
;
33 typedef Eigen::Matrix
<Index
, 4, 1> Vector4i
;
34 typedef Eigen::Matrix
<Index
, 5, 1> Vector5i
;
36 typedef Eigen::Matrix
<Index
, Eigen::Dynamic
, Eigen::Dynamic
> Matrixi
;
37 typedef Eigen::Matrix
<Index
, 2, 2> Matrix2i
;
38 typedef Eigen::Matrix
<Index
, 3, 3> Matrix3i
;
39 typedef Eigen::Matrix
<Index
, 4, 4> Matrix4i
;
40 typedef Eigen::Matrix
<Index
, 5, 5> Matrix5i
;
42 typedef Eigen::Matrix
<float, Eigen::Dynamic
, Eigen::Dynamic
> Matrixf
;
43 typedef Eigen::Matrix
<double, Eigen::Dynamic
, Eigen::Dynamic
> Matrixd
;
45 /********************************************************
47 *******************************************************/
49 /** Manhatten distance functor.
50 * This the same as the L1 minkowski distance but more efficient.
51 * @see EuclideanDistance, ChebyshevDistance, MinkowskiDistance */
52 template <typename Scalar
>
53 struct ManhattenDistance
55 /** Compute the unrooted distance between two vectors.
56 * @param lhs vector on left hand side
57 * @param rhs vector on right hand side */
58 template<typename DerivedA
, typename DerivedB
>
59 Scalar
operator()(const Eigen::MatrixBase
<DerivedA
> &lhs
,
60 const Eigen::MatrixBase
<DerivedB
> &rhs
) const
63 std::is_same
<typename
Eigen::MatrixBase
<DerivedA
>::Scalar
,Scalar
>::value
,
64 "distance scalar and input matrix A must have same type");
66 std::is_same
<typename
Eigen::MatrixBase
<DerivedB
>::Scalar
, Scalar
>::value
,
67 "distance scalar and input matrix B must have same type");
69 return (lhs
- rhs
).cwiseAbs().sum();
72 /** Compute the unrooted distance between two scalars.
73 * @param lhs scalar on left hand side
74 * @param rhs scalar on right hand side */
75 Scalar
operator()(const Scalar lhs
,
76 const Scalar rhs
) const
78 return std::abs(lhs
- rhs
);
81 /** Compute the root of a unrooted distance value.
82 * @param value unrooted distance value */
83 Scalar
operator()(const Scalar val
) const
89 /** Euclidean distance functor.
90 * This the same as the L2 minkowski distance but more efficient.
91 * @see ManhattenDistance, ChebyshevDistance, MinkowskiDistance */
92 template <typename Scalar
>
93 struct EuclideanDistance
95 /** Compute the unrooted distance between two vectors.
96 * @param lhs vector on left hand side
97 * @param rhs vector on right hand side */
98 template<typename DerivedA
, typename DerivedB
>
99 Scalar
operator()(const Eigen::MatrixBase
<DerivedA
> &lhs
,
100 const Eigen::MatrixBase
<DerivedB
> &rhs
) const
103 std::is_same
<typename
Eigen::MatrixBase
<DerivedA
>::Scalar
,Scalar
>::value
,
104 "distance scalar and input matrix A must have same type");
106 std::is_same
<typename
Eigen::MatrixBase
<DerivedB
>::Scalar
, Scalar
>::value
,
107 "distance scalar and input matrix B must have same type");
109 return (lhs
- rhs
).cwiseAbs2().sum();
112 /** Compute the unrooted distance between two scalars.
113 * @param lhs scalar on left hand side
114 * @param rhs scalar on right hand side */
115 Scalar
operator()(const Scalar lhs
,
116 const Scalar rhs
) const
118 Scalar diff
= lhs
- rhs
;
122 /** Compute the root of a unrooted distance value.
123 * @param value unrooted distance value */
124 Scalar
operator()(const Scalar val
) const
126 return std::sqrt(val
);
130 /** General minkowski distance functor.
131 * The infinite version is only available through the chebyshev distance.
132 * @see ManhattenDistance, EuclideanDistance, ChebyshevDistance */
133 template <typename Scalar
, int P
>
134 struct MinkowskiDistance
138 Scalar
operator()(const Scalar val
) const
141 for(int i
= 0; i
< P
; ++i
)
147 /** Compute the unrooted distance between two vectors.
148 * @param lhs vector on left hand side
149 * @param rhs vector on right hand side */
150 template<typename DerivedA
, typename DerivedB
>
151 Scalar
operator()(const Eigen::MatrixBase
<DerivedA
> &lhs
,
152 const Eigen::MatrixBase
<DerivedB
> &rhs
) const
155 std::is_same
<typename
Eigen::MatrixBase
<DerivedA
>::Scalar
,Scalar
>::value
,
156 "distance scalar and input matrix A must have same type");
158 std::is_same
<typename
Eigen::MatrixBase
<DerivedB
>::Scalar
, Scalar
>::value
,
159 "distance scalar and input matrix B must have same type");
161 return (lhs
- rhs
).cwiseAbs().unaryExpr(MinkowskiDistance::Pow()).sum();
164 /** Compute the unrooted distance between two scalars.
165 * @param lhs scalar on left hand side
166 * @param rhs scalar on right hand side */
167 Scalar
operator()(const Scalar lhs
,
168 const Scalar rhs
) const
170 return std::pow(std::abs(lhs
- rhs
), P
);;
173 /** Compute the root of a unrooted distance value.
174 * @param value unrooted distance value */
175 Scalar
operator()(const Scalar val
) const
177 return std::pow(val
, 1 / static_cast<Scalar
>(P
));
181 /** Chebyshev distance functor.
182 * This distance is the same as infinity minkowski distance.
183 * @see ManhattenDistance, EuclideanDistance, MinkowskiDistance */
184 template<typename Scalar
>
185 struct ChebyshevDistance
187 /** Compute the unrooted distance between two vectors.
188 * @param lhs vector on left hand side
189 * @param rhs vector on right hand side */
190 template<typename DerivedA
, typename DerivedB
>
191 Scalar
operator()(const Eigen::MatrixBase
<DerivedA
> &lhs
,
192 const Eigen::MatrixBase
<DerivedB
> &rhs
) const
195 std::is_same
<typename
Eigen::MatrixBase
<DerivedA
>::Scalar
,Scalar
>::value
,
196 "distance scalar and input matrix A must have same type");
198 std::is_same
<typename
Eigen::MatrixBase
<DerivedB
>::Scalar
, Scalar
>::value
,
199 "distance scalar and input matrix B must have same type");
201 return (lhs
- rhs
).cwiseAbs().maxCoeff();
204 /** Compute the unrooted distance between two scalars.
205 * @param lhs scalar on left hand side
206 * @param rhs scalar on right hand side */
207 Scalar
operator()(const Scalar lhs
,
208 const Scalar rhs
) const
210 return std::abs(lhs
- rhs
);
213 /** Compute the root of a unrooted distance value.
214 * @param value unrooted distance value */
215 Scalar
operator()(const Scalar val
) const
221 /** Hamming distance functor.
222 * The distance vectors have to be of integral type and should hold the
223 * information vectors as bitmasks.
224 * Performs a XOR operation on the vectors and counts the number of set
226 template<typename Scalar
>
227 struct HammingDistance
229 static_assert(std::is_integral
<Scalar
>::value
,
230 "HammingDistance requires integral Scalar type");
234 Scalar
operator()(const Scalar lhs
, const Scalar rhs
) const
242 Scalar
operator()(const Scalar lhs
) const
246 while(copy
!= static_cast<Scalar
>(0))
256 /** Compute the unrooted distance between two vectors.
257 * @param lhs vector on left hand side
258 * @param rhs vector on right hand side */
259 template<typename DerivedA
, typename DerivedB
>
260 Scalar
operator()(const Eigen::MatrixBase
<DerivedA
> &lhs
,
261 const Eigen::MatrixBase
<DerivedB
> &rhs
) const
264 std::is_same
<typename
Eigen::MatrixBase
<DerivedA
>::Scalar
,Scalar
>::value
,
265 "distance scalar and input matrix A must have same type");
267 std::is_same
<typename
Eigen::MatrixBase
<DerivedB
>::Scalar
, Scalar
>::value
,
268 "distance scalar and input matrix B must have same type");
271 binaryExpr(rhs
, XOR()).
272 unaryExpr(BitCount()).
276 /** Compute the unrooted distance between two scalars.
277 * @param lhs scalar on left hand side
278 * @param rhs scalar on right hand side */
279 Scalar
operator()(const Scalar lhs
,
280 const Scalar rhs
) const
284 return cnt(xOr(lhs
, rhs
));
287 /** Compute the root of a unrooted distance value.
288 * @param value unrooted distance value */
289 Scalar
operator()(const Scalar value
) const
295 /** Efficient heap structure to query nearest neighbours. */
296 template<typename Scalar
>
300 Index
*indices_
= nullptr;
301 Scalar
*distances_
= nullptr;
305 /** Creates a query heap with the given index and distance memory regions. */
306 QueryHeap(Index
*indices
, Scalar
*distances
, const size_t maxSize
)
307 : indices_(indices
), distances_(distances
), maxSize_(maxSize
)
310 /** Pushes a new query data set into the heap with the given
311 * index and distance.
312 * The index identifies the point for which the given distance
314 * @param idx index / ID of the query point
315 * @param dist distance that was computed for the query point*/
316 void push(const Index idx
, const Scalar dist
)
320 // add new value at the end
321 indices_
[size_
] = idx
;
322 distances_
[size_
] = dist
;
326 size_t k
= size_
- 1;
327 size_t tmp
= (k
- 1) / 2;
328 while(k
> 0 && distances_
[tmp
] < dist
)
330 distances_
[k
] = distances_
[tmp
];
331 indices_
[k
] = indices_
[tmp
];
335 distances_
[k
] = dist
;
339 /** Removes the element at the front of the heap and restores
345 // replace first element with last
347 distances_
[0] = distances_
[size_
];
348 indices_
[0] = indices_
[size_
];
353 Scalar dist
= distances_
[0];
354 Index idx
= indices_
[0];
355 while(2 * k
+ 1 < size_
)
358 if(j
+ 1 < size_
&& distances_
[j
+1] > distances_
[j
])
360 // j references now greatest child
361 if(dist
>= distances_
[j
])
363 distances_
[k
] = distances_
[j
];
364 indices_
[k
] = indices_
[j
];
367 distances_
[k
] = dist
;
371 /** Returns the distance of the element in front of the heap. */
375 return distances_
[0];
378 /** Determines if this query heap is full.
379 * The heap is considered full if its number of elements
380 * has reached its max size.
381 * @return true if the heap is full, else false */
384 return size_
>= maxSize_
;
387 /** Determines if this query heap is empty.
388 * @return true if the heap contains no elements, else false */
394 /** Returns the number of elements within the query heap.
395 * @return number of elements in the heap */
401 /** Clears the query heap. */
407 /** Sorts the elements within the heap according to
412 for(size_t i
= 0; i
< cnt
; ++i
)
414 Index idx
= indices_
[0];
415 Scalar dist
= distances_
[0];
417 indices_
[cnt
- i
- 1] = idx
;
418 distances_
[cnt
- i
- 1] = dist
;
423 /** Class for performing brute force knn search. */
424 template<typename Scalar
,
425 typename Distance
=EuclideanDistance
<Scalar
>>
429 typedef Eigen::Matrix
<Scalar
, Eigen::Dynamic
, Eigen::Dynamic
> Matrix
;
430 typedef Eigen::Matrix
<Scalar
, Eigen::Dynamic
, 1> Vector
;
431 typedef knncpp::Matrixi Matrixi
;
433 Distance distance_
= Distance();
434 Matrix dataCopy_
= Matrix();
435 const Matrix
*data_
= nullptr;
438 bool takeRoot_
= true;
444 BruteForce() = default;
446 /** Constructs a brute force instance with the given data.
447 * @param data NxM matrix, M points of dimension N
448 * @param copy if true copies the data, otherwise assumes static data */
449 BruteForce(const Matrix
&data
, const bool copy
= false)
455 /** Set if the points returned by the queries should be sorted
456 * according to their distance to the query points.
457 * @param sorted sort query results */
458 void setSorted(const bool sorted
)
463 /** Set if the distances after the query should be rooted or not.
464 * Taking the root of the distances increases query time, but the
465 * function will return true distances instead of their powered
467 * @param takeRoot set true if root should be taken else false */
468 void setTakeRoot(const bool takeRoot
)
470 takeRoot_
= takeRoot
;
473 /** Set the amount of threads that should be used for querying.
474 * OpenMP has to be enabled for this to work.
475 * @param threads amount of threads, 0 for optimal choice */
476 void setThreads(const unsigned int threads
)
481 /** Set the maximum distance for querying the tree.
482 * The search will be pruned if the maximum distance is set to any
484 * @param maxDist maximum distance, <= 0 for no limit */
485 void setMaxDistance(const Scalar maxDist
)
490 /** Set the data points used for this tree.
491 * This does not build the tree.
492 * @param data NxM matrix, M points of dimension N
493 * @param copy if true data is copied, assumes static data otherwise */
494 void setData(const Matrix
&data
, const bool copy
= false)
507 void setDistance(const Distance
&distance
)
509 distance_
= distance
;
515 template<typename Derived
>
516 void query(const Eigen::MatrixBase
<Derived
> &queryPoints
,
519 Matrix
&distances
) const
522 throw std::runtime_error("cannot query BruteForce: data not set");
523 if(data_
->size() == 0)
524 throw std::runtime_error("cannot query BruteForce: data is empty");
525 if(queryPoints
.rows() != dimension())
526 throw std::runtime_error("cannot query BruteForce: data and query descriptors do not have same dimension");
528 const Matrix
&dataPoints
= *data_
;
530 indices
.setConstant(knn
, queryPoints
.cols(), -1);
531 distances
.setConstant(knn
, queryPoints
.cols(), -1);
533 #pragma omp parallel for num_threads(threads_)
534 for(Index i
= 0; i
< queryPoints
.cols(); ++i
)
536 Index
*idxPoint
= &indices
.data()[i
* knn
];
537 Scalar
*distPoint
= &distances
.data()[i
* knn
];
539 QueryHeap
<Scalar
> heap(idxPoint
, distPoint
, knn
);
541 for(Index j
= 0; j
< dataPoints
.cols(); ++j
)
543 Scalar dist
= distance_(queryPoints
.col(i
), dataPoints
.col(j
));
545 // check if point is in range if max distance was set
546 bool isInRange
= maxDist_
<= 0 || dist
<= maxDist_
;
547 // check if this node was an improvement if heap is already full
548 bool isImprovement
= !heap
.full() ||
550 if(isInRange
&& isImprovement
)
563 for(size_t j
= 0; j
< knn
; ++j
)
567 distPoint
[j
] = distance_(distPoint
[j
]);
573 /** Returns the amount of data points stored in the search index.
574 * @return number of data points */
577 return data_
== nullptr ? 0 : data_
->cols();
580 /** Returns the dimension of the data points in the search index.
581 * @return dimension of data points */
582 Index
dimension() const
584 return data_
== nullptr ? 0 : data_
->rows();
// Disabled draft of a mean midpoint splitting rule:
// template<typename Scalar>
// struct MeanMidpointRule
//     typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
//     typedef knncpp::Matrixi Matrixi;
//     void operator(const Matrix &data, const Matrixi &indices, Index split)
597 /** Class for performing k nearest neighbour searches with minkowski distances.
598 * This kdtree only works reliably with the minkowski distance and its
599 * special cases like manhatten or euclidean distance.
600 * @see ManhattenDistance, EuclideanDistance, ChebyshevDistance, MinkowskiDistance*/
601 template<typename _Scalar
, int _Dimension
, typename _Distance
>
602 class KDTreeMinkowski
605 typedef _Scalar Scalar
;
606 typedef _Distance Distance
;
607 typedef Eigen::Matrix
<Scalar
, Eigen::Dynamic
, Eigen::Dynamic
> Matrix
;
608 typedef Eigen::Matrix
<Scalar
, _Dimension
, Eigen::Dynamic
> DataMatrix
;
609 typedef Eigen::Matrix
<Scalar
, _Dimension
, 1> DataVector
;
610 typedef knncpp::Matrixi Matrixi
;
612 typedef Eigen::Matrix
<Scalar
, 2, 1> Bounds
;
613 typedef Eigen::Matrix
<Scalar
, 2, _Dimension
> BoundingBox
;
615 /** Struct representing a node in the KDTree.
616 * It can be either a inner node or a leaf node. */
619 /** Indices of data points in this leaf node. */
623 /** Left child of this inner node. */
625 /** Right child of this inner node. */
627 /** Axis of the axis aligned splitting hyper plane. */
628 Index splitaxis
= -1;
629 /** Translation of the axis aligned splitting hyper plane. */
630 Scalar splitpoint
= 0;
631 /** Lower end of the splitpoint range */
632 Scalar splitlower
= 0;
633 /** Upper end of the splitpoint range */
634 Scalar splitupper
= 0;
639 /** Constructor for leaf nodes */
640 Node(const Index startIdx
, const Index length
)
641 : startIdx(startIdx
), length(length
)
644 /** Constructor for inner nodes */
645 Node(const Index splitaxis
, const Scalar splitpoint
,
646 const Index left
, const Index right
)
647 : left(left
), right(right
),
648 splitaxis(splitaxis
), splitpoint(splitpoint
)
653 return !hasLeft() && !hasRight();
658 return hasLeft() && hasRight();
666 bool hasRight() const
672 DataMatrix dataCopy_
= DataMatrix();
673 const DataMatrix
*data_
= nullptr;
674 std::vector
<Index
> indices_
= std::vector
<Index
>();
675 std::vector
<Node
> nodes_
= std::vector
<Node
>();
677 Index bucketSize_
= 16;
679 bool compact_
= false;
680 bool balanced_
= false;
681 bool takeRoot_
= true;
685 Distance distance_
= Distance();
687 BoundingBox bbox_
= BoundingBox();
689 Index
buildLeafNode(const Index startIdx
,
693 nodes_
.push_back(Node(startIdx
, length
));
694 calculateBoundingBox(startIdx
, length
, bbox
);
695 return static_cast<Index
>(nodes_
.size() - 1);
698 /** Finds the minimum and maximum values of each dimension (row) in the
699 * data matrix. Only respects the columns specified by the index
701 * @param startIdx starting index within indices data structure to search for bounding box
702 * @param length length of the block of indices*/
703 void calculateBoundingBox(const Index startIdx
,
705 BoundingBox
&bbox
) const
708 assert(startIdx
>= 0);
709 assert(static_cast<size_t>(startIdx
+ length
) <= indices_
.size());
710 assert(data_
->rows() == bbox
.cols());
712 const DataMatrix
&data
= *data_
;
714 // initialize bounds of the bounding box
715 Index first
= indices_
[startIdx
];
716 for(Index i
= 0; i
< bbox
.cols(); ++i
)
718 bbox(0, i
) = data(i
, first
);
719 bbox(1, i
) = data(i
, first
);
722 // search for min / max values in data
723 for(Index i
= 1; i
< length
; ++i
)
725 // retrieve data index
726 Index col
= indices_
[startIdx
+ i
];
727 assert(col
>= 0 && col
< data
.cols());
729 // check min and max for each dimension individually
730 for(Index j
= 0; j
< data
.rows(); ++j
)
732 bbox(0, j
) = std::min(bbox(0, j
), data(j
, col
));
733 bbox(1, j
) = std::max(bbox(1, j
), data(j
, col
));
738 /** Calculates the bounds (min / max values) for the given dimension and block of data. */
739 void calculateBounds(const Index startIdx
,
742 Bounds
&bounds
) const
745 assert(startIdx
>= 0);
746 assert(static_cast<size_t>(startIdx
+ length
) <= indices_
.size());
748 const DataMatrix
&data
= *data_
;
750 bounds(0) = data(dim
, indices_
[startIdx
]);
751 bounds(1) = data(dim
, indices_
[startIdx
]);
753 for(Index i
= 1; i
< length
; ++i
)
755 Index col
= indices_
[startIdx
+ i
];
756 assert(col
>= 0 && col
< data
.cols());
758 bounds(0) = std::min(bounds(0), data(dim
, col
));
759 bounds(1) = std::max(bounds(1), data(dim
, col
));
763 void calculateSplittingMidpoint(const Index startIdx
,
765 const BoundingBox
&bbox
,
770 const DataMatrix
&data
= *data_
;
772 // search for axis with longest distance
774 Scalar splitsize
= static_cast<Scalar
>(0);
775 for(Index i
= 0; i
< data
.rows(); ++i
)
777 Scalar diff
= bbox(1, i
) - bbox(0, i
);
785 // calculate the bounds in this axis and update our data
788 calculateBounds(startIdx
, length
, splitaxis
, bounds
);
789 splitsize
= bounds(1) - bounds(0);
791 const Index origSplitaxis
= splitaxis
;
792 for(Index i
= 0; i
< data
.rows(); ++i
)
794 // skip the dimension of the previously found splitaxis
795 if(i
== origSplitaxis
)
797 Scalar diff
= bbox(1, i
) - bbox(0, i
);
798 // check if the split for this dimension would be potentially larger
802 // update the bounds to their actual current value
803 calculateBounds(startIdx
, length
, splitaxis
, newBounds
);
804 diff
= newBounds(1) - newBounds(0);
814 // use the sliding midpoint rule
815 splitpoint
= (bounds(0) + bounds(1)) / static_cast<Scalar
>(2);
817 Index leftIdx
= startIdx
;
818 Index rightIdx
= startIdx
+ length
- 1;
820 // first loop checks left < splitpoint and right >= splitpoint
821 while(leftIdx
<= rightIdx
)
823 // increment left as long as left has not reached right and
824 // the value of the left element is less than the splitpoint
825 while(leftIdx
<= rightIdx
&& data(splitaxis
, indices_
[leftIdx
]) < splitpoint
)
828 // decrement right as long as left has not reached right and
829 // the value of the right element is greater than the splitpoint
830 while(leftIdx
<= rightIdx
&& data(splitaxis
, indices_
[rightIdx
]) >= splitpoint
)
833 if(leftIdx
<= rightIdx
)
835 std::swap(indices_
[leftIdx
], indices_
[rightIdx
]);
841 // remember this offset from starting index
842 const Index offset1
= leftIdx
- startIdx
;
844 rightIdx
= startIdx
+ length
- 1;
845 // second loop checks left <= splitpoint and right > splitpoint
846 while(leftIdx
<= rightIdx
)
848 // increment left as long as left has not reached right and
849 // the value of the left element is less than the splitpoint
850 while(leftIdx
<= rightIdx
&& data(splitaxis
, indices_
[leftIdx
]) <= splitpoint
)
853 // decrement right as long as left has not reached right and
854 // the value of the right element is greater than the splitpoint
855 while(leftIdx
<= rightIdx
&& data(splitaxis
, indices_
[rightIdx
]) > splitpoint
)
858 if(leftIdx
<= rightIdx
)
860 std::swap(indices_
[leftIdx
], indices_
[rightIdx
]);
866 // remember this offset from starting index
867 const Index offset2
= leftIdx
- startIdx
;
869 const Index halfLength
= length
/ static_cast<Index
>(2);
871 // find a separation of points such that is best balanced
872 // offset1 denotes separation where equal points are all on the right
873 // offset2 denots separation where equal points are all on the left
874 if (offset1
> halfLength
)
875 splitoffset
= offset1
;
876 else if (offset2
< halfLength
)
877 splitoffset
= offset2
;
878 // when we get here offset1 < halflength and offset2 > halflength
879 // so simply split the equal elements in the middle
881 splitoffset
= halfLength
;
884 Index
buildInnerNode(const Index startIdx
,
889 assert(startIdx
>= 0);
890 assert(static_cast<size_t>(startIdx
+ length
) <= indices_
.size());
891 assert(data_
->rows() == bbox
.cols());
894 const Index nodeIdx
= nodes_
.size();
895 nodes_
.push_back(Node());
900 calculateSplittingMidpoint(startIdx
, length
, bbox
, splitaxis
, splitpoint
, splitoffset
);
902 nodes_
[nodeIdx
].splitaxis
= splitaxis
;
903 nodes_
[nodeIdx
].splitpoint
= splitpoint
;
905 const Index leftStart
= startIdx
;
906 const Index leftLength
= splitoffset
;
907 const Index rightStart
= startIdx
+ splitoffset
;
908 const Index rightLength
= length
- splitoffset
;
910 BoundingBox bboxLeft
= bbox
;
911 BoundingBox bboxRight
= bbox
;
914 bboxLeft(1, splitaxis
) = splitpoint
;
915 Index left
= buildR(leftStart
, leftLength
, bboxLeft
);
916 nodes_
[nodeIdx
].left
= left
;
919 bboxRight(0, splitaxis
) = splitpoint
;
920 Index right
= buildR(rightStart
, rightLength
, bboxRight
);
921 nodes_
[nodeIdx
].right
= right
;
923 // extract the range of the splitpoint
924 nodes_
[nodeIdx
].splitlower
= bboxLeft(1, splitaxis
);
925 nodes_
[nodeIdx
].splitupper
= bboxRight(0, splitaxis
);
927 // update the bounding box to the values of the new bounding boxes
928 for(Index i
= 0; i
< bbox
.cols(); ++i
)
930 bbox(0, i
) = std::min(bboxLeft(0, i
), bboxRight(0, i
));
931 bbox(1, i
) = std::max(bboxLeft(1, i
), bboxRight(1, i
));
937 Index
buildR(const Index startIdx
,
941 // check for base case
942 if(length
<= bucketSize_
)
943 return buildLeafNode(startIdx
, length
, bbox
);
945 return buildInnerNode(startIdx
, length
, bbox
);
948 bool isDistanceInRange(const Scalar dist
) const
950 return maxDist_
<= 0 || dist
<= maxDist_
;
953 bool isDistanceImprovement(const Scalar dist
, const QueryHeap
<Scalar
> &dataHeap
) const
955 return !dataHeap
.full() || dist
< dataHeap
.front();
958 template<typename Derived
>
959 void queryLeafNode(const Node
&node
,
960 const Eigen::MatrixBase
<Derived
> &queryPoint
,
961 QueryHeap
<Scalar
> &dataHeap
) const
963 assert(node
.isLeaf());
965 const DataMatrix
&data
= *data_
;
967 // go through all points in this leaf node and do brute force search
968 for(Index i
= 0; i
< node
.length
; ++i
)
970 const Index idx
= node
.startIdx
+ i
;
971 assert(idx
>= 0 && idx
< static_cast<Index
>(indices_
.size()));
973 // retrieve index of the current data point
974 const Index dataIdx
= indices_
[idx
];
975 const Scalar dist
= distance_(queryPoint
, data
.col(dataIdx
));
977 // check if point is within max distance and if the value would be
979 if(isDistanceInRange(dist
) && isDistanceImprovement(dist
, dataHeap
))
983 dataHeap
.push(dataIdx
, dist
);
988 template<typename Derived
>
989 void queryInnerNode(const Node
&node
,
990 const Eigen::MatrixBase
<Derived
> &queryPoint
,
991 QueryHeap
<Scalar
> &dataHeap
,
992 DataVector
&splitdists
,
993 const Scalar mindist
) const
995 assert(node
.isInner());
997 const Index splitaxis
= node
.splitaxis
;
998 const Scalar splitval
= queryPoint(splitaxis
, 0);
1002 // check if right or left child should be visited
1003 const bool visitLeft
= (splitval
- node
.splitlower
+ splitval
- node
.splitupper
) < 0;
1006 firstNode
= node
.left
;
1007 secondNode
= node
.right
;
1008 splitdist
= distance_(splitval
, node
.splitupper
);
1012 firstNode
= node
.right
;
1013 secondNode
= node
.left
;
1014 splitdist
= distance_(splitval
, node
.splitlower
);
1017 queryR(nodes_
[firstNode
], queryPoint
, dataHeap
, splitdists
, mindist
);
1019 const Scalar mindistNew
= mindist
+ splitdist
- splitdists(splitaxis
);
1021 // check if node is in range if max distance was set
1022 // check if this node was an improvement if heap is already full
1023 if(isDistanceInRange(mindistNew
) && isDistanceImprovement(mindistNew
, dataHeap
))
1025 const Scalar splitdistOld
= splitdists(splitaxis
);
1026 splitdists(splitaxis
) = splitdist
;
1027 queryR(nodes_
[secondNode
], queryPoint
, dataHeap
, splitdists
, mindistNew
);
1028 splitdists(splitaxis
) = splitdistOld
;
1032 template<typename Derived
>
1033 void queryR(const Node
&node
,
1034 const Eigen::MatrixBase
<Derived
> &queryPoint
,
1035 QueryHeap
<Scalar
> &dataHeap
,
1036 DataVector
&splitdists
,
1037 const Scalar mindist
) const
1040 queryLeafNode(node
, queryPoint
, dataHeap
);
1042 queryInnerNode(node
, queryPoint
, dataHeap
, splitdists
, mindist
);
1045 /** Recursively computes the depth for the given node. */
1046 Index
depthR(const Node
&node
) const
1052 Index left
= depthR(nodes_
[node
.left
]);
1053 Index right
= depthR(nodes_
[node
.right
]);
1054 return std::max(left
, right
) + 1;
1060 /** Constructs an empty KDTree. */
1064 /** Constructs KDTree with the given data. This does not build the
1065 * the index of the tree.
1066 * @param data NxM matrix, M points of dimension N
1067 * @param copy if true copies the data, otherwise assumes static data */
1068 KDTreeMinkowski(const DataMatrix
&data
, const bool copy
=false)
1070 setData(data
, copy
);
1073 /** Set the maximum amount of data points per leaf in the tree (aka
1075 * @param bucketSize amount of points per leaf. */
1076 void setBucketSize(const Index bucketSize
)
1078 bucketSize_
= bucketSize
;
1081 /** Set if the points returned by the queries should be sorted
1082 * according to their distance to the query points.
1083 * @param sorted sort query results */
1084 void setSorted(const bool sorted
)
1089 /** Set if the tree should be built as balanced as possible.
1090 * This increases build time, but decreases search time.
1091 * @param balanced set true to build a balanced tree */
1092 void setBalanced(const bool balanced
)
1094 balanced_
= balanced
;
1097 /** Set if the distances after the query should be rooted or not.
1098 * Taking the root of the distances increases query time, but the
1099 * function will return true distances instead of their powered
1101 * @param takeRoot set true if root should be taken else false */
1102 void setTakeRoot(const bool takeRoot
)
1104 takeRoot_
= takeRoot
;
1107 /** Set if the tree should be built with compact leaf nodes.
1108 * This increases build time, but makes leaf nodes denser (more)
1109 * points. Thus less visits are necessary.
1110 * @param compact set true ti build a tree with compact leafs */
1111 void setCompact(const bool compact
)
1116 /** Set the amount of threads that should be used for building and
1117 * querying the tree.
1118 * OpenMP has to be enabled for this to work.
1119 * @param threads amount of threads, 0 for optimal choice */
1120 void setThreads(const unsigned int threads
)
1125 /** Set the maximum distance for querying the tree.
1126 * The search will be pruned if the maximum distance is set to any
1128 * @param maxDist maximum distance, <= 0 for no limit */
1129 void setMaxDistance(const Scalar maxDist
)
1134 /** Set the data points used for this tree.
1135 * This does not build the tree.
1136 * @param data NxM matrix, M points of dimension N
1137 * @param copy if true data is copied, assumes static data otherwise */
1138 void setData(const DataMatrix
&data
, const bool copy
= false)
1152 void setDistance(const Distance
&distance
)
1154 distance_
= distance
;
1157 /** Builds the search index of the tree.
1158 * Data has to be set and must be non-empty. */
1161 if(data_
== nullptr)
1162 throw std::runtime_error("cannot build KDTree; data not set");
1164 if(data_
->size() == 0)
1165 throw std::runtime_error("cannot build KDTree; data is empty");
1168 nodes_
.reserve((data_
->cols() / bucketSize_
) + 1);
1170 // initialize indices in simple sequence
1171 indices_
.resize(data_
->cols());
1172 for(size_t i
= 0; i
< indices_
.size(); ++i
)
1175 bbox_
.resize(2, data_
->rows());
1177 Index length
= data_
->cols();
1179 calculateBoundingBox(startIdx
, length
, bbox_
);
1181 buildR(startIdx
, length
, bbox_
);
1184 /** Queries the tree for the nearest neighbours of the given query
1187 * The tree has to be built before it can be queried.
1189 * The query points have to have the same dimension as the data points
1192 * The result matrices will be resized appropriatley.
1193 * Indices and distances will be set to -1 if less than knn neighbours
1196 * @param queryPoints NxM matrix, M points of dimension N
1197 * @param knn amount of neighbours to be found
1198 * @param indices KNNxM matrix, indices of neighbours in the data set
1199 * @param distances KNNxM matrix, distance between query points and
1201 template<typename Derived
>
1202 void query(const Eigen::MatrixBase
<Derived
> &queryPoints
,
1205 Matrix
&distances
) const
1207 if(nodes_
.size() == 0)
1208 throw std::runtime_error("cannot query KDTree; not built yet");
1210 if(queryPoints
.rows() != dimension())
1211 throw std::runtime_error("cannot query KDTree; data and query points do not have same dimension");
1213 distances
.setConstant(knn
, queryPoints
.cols(), -1);
1214 indices
.setConstant(knn
, queryPoints
.cols(), -1);
1216 Index
*indicesRaw
= indices
.data();
1217 Scalar
*distsRaw
= distances
.data();
1219 #pragma omp parallel for num_threads(threads_)
1220 for(Index i
= 0; i
< queryPoints
.cols(); ++i
)
1223 Scalar
*distPoint
= &distsRaw
[i
* knn
];
1224 Index
*idxPoint
= &indicesRaw
[i
* knn
];
1226 // create heap to find nearest neighbours
1227 QueryHeap
<Scalar
> dataHeap(idxPoint
, distPoint
, knn
);
1229 Scalar mindist
= static_cast<Scalar
>(0);
1230 DataVector
splitdists(queryPoints
.rows());
1232 for(Index j
= 0; j
< splitdists
.rows(); ++j
)
1234 const Scalar value
= queryPoints(j
, i
);
1235 const Scalar lower
= bbox_(0, j
);
1236 const Scalar upper
= bbox_(1, j
);
1239 splitdists(j
) = distance_(value
, lower
);
1241 else if(value
> upper
)
1243 splitdists(j
) = distance_(value
, upper
);
1247 splitdists(j
) = static_cast<Scalar
>(0);
1250 mindist
+= splitdists(j
);
1253 queryR(nodes_
[0], queryPoints
.col(i
), dataHeap
, splitdists
, mindist
);
1260 for(size_t j
= 0; j
< knn
; ++j
)
1262 if(distPoint
[j
] < 0)
1264 distPoint
[j
] = distance_(distPoint
[j
]);
1270 /** Clears the tree. */
1276 /** Returns the amount of data points stored in the search index.
1277 * @return number of data points */
1280 return data_
== nullptr ? 0 : data_
->cols();
1283 /** Returns the dimension of the data points in the search index.
1284 * @return dimension of data points */
1285 Index
dimension() const
1287 return data_
== nullptr ? 0 : data_
->rows();
1290 /** Returns the maxximum depth of the tree.
1291 * @return maximum depth of the tree */
1294 return nodes_
.size() == 0 ? 0 : depthR(nodes_
.front());
1298 template<typename _Scalar
, typename _Distance
= EuclideanDistance
<_Scalar
>> using KDTreeMinkowski2
= KDTreeMinkowski
<_Scalar
, 2, _Distance
>;
1299 template<typename _Scalar
, typename _Distance
= EuclideanDistance
<_Scalar
>> using KDTreeMinkowski3
= KDTreeMinkowski
<_Scalar
, 3, _Distance
>;
1300 template<typename _Scalar
, typename _Distance
= EuclideanDistance
<_Scalar
>> using KDTreeMinkowski4
= KDTreeMinkowski
<_Scalar
, 4, _Distance
>;
1301 template<typename _Scalar
, typename _Distance
= EuclideanDistance
<_Scalar
>> using KDTreeMinkowski5
= KDTreeMinkowski
<_Scalar
, 5, _Distance
>;
1302 template<typename _Scalar
, typename _Distance
= EuclideanDistance
<_Scalar
>> using KDTreeMinkowskiX
= KDTreeMinkowski
<_Scalar
, Eigen::Dynamic
, _Distance
>;
1304 /** Class for performing KNN search in hamming space by multi-index hashing. */
1305 template<typename Scalar
>
1306 class MultiIndexHashing
1309 static_assert(std::is_integral
<Scalar
>::value
, "MultiIndexHashing Scalar has to be integral");
1311 typedef Eigen::Matrix
<Scalar
, Eigen::Dynamic
, Eigen::Dynamic
> Matrix
;
1312 typedef Eigen::Matrix
<Scalar
, Eigen::Dynamic
, 1> Vector
;
1313 typedef knncpp::Matrixi Matrixi
;
1316 HammingDistance
<Scalar
> distance_
;
1319 const Matrix
*data_
;
1325 std::vector
<std::map
<Scalar
, std::vector
<Index
>>> buckets_
;
1327 template<typename Derived
>
1328 Scalar
extractCode(const Eigen::MatrixBase
<Derived
> &data
,
1330 const Index offset
) const
1332 Index leftShift
= std::max
<Index
>(0, static_cast<Index
>(sizeof(Scalar
)) - offset
- substrLen_
);
1333 Index rightShift
= leftShift
+ offset
;
1335 Scalar code
= (data(idx
, 0) << (leftShift
* 8)) >> (rightShift
* 8);
1337 if(static_cast<Index
>(sizeof(Scalar
)) - offset
< substrLen_
&& idx
+ 1 < data
.rows())
1339 Index shift
= 2 * static_cast<Index
>(sizeof(Scalar
)) - substrLen_
- offset
;
1340 code
|= data(idx
+1, 0) << (shift
* 8);
1347 : distance_(), dataCopy_(), data_(nullptr), sorted_(true),
1348 maxDist_(0), substrLen_(1), threads_(1)
1351 /** Constructs an index with the given data.
1352 * This does not build the the index.
1353 * @param data NxM matrix, M points of dimension N
1354 * @param copy if true copies the data, otherwise assumes static data */
1355 MultiIndexHashing(const Matrix
&data
, const bool copy
=false)
1356 : MultiIndexHashing()
1358 setData(data
, copy
);
1361 /** Set the maximum distance for querying the index.
1362 * Note that if no maximum distance is used, this algorithm performs
1363 * basically a brute force search.
1364 * @param maxDist maximum distance, <= 0 for no limit */
1365 void setMaxDistance(const Scalar maxDist
)
1370 /** Set if the points returned by the queries should be sorted
1371 * according to their distance to the query points.
1372 * @param sorted sort query results */
1373 void setSorted(const bool sorted
)
1378 /** Set the amount of threads that should be used for building and
1379 * querying the tree.
1380 * OpenMP has to be enabled for this to work.
1381 * @param threads amount of threads, 0 for optimal choice */
1382 void setThreads(const unsigned int threads
)
1387 /** Set the length of substrings (in bytes) used for multi index hashing.
1388 * @param len lentth of bucket substrings in bytes*/
1389 void setSubstringLength(const Index len
)
1394 /** Set the data points used for the KNN search.
1395 * @param data NxM matrix, M points of dimension N
1396 * @param copy if true data is copied, assumes static data otherwise */
1397 void setData(const Matrix
&data
, const bool copy
= false)
1413 if(data_
== nullptr)
1414 throw std::runtime_error("cannot build MultiIndexHashing; data not set");
1415 if(data_
->size() == 0)
1416 throw std::runtime_error("cannot build MultiIndexHashing; data is empty");
1418 const Matrix
&data
= *data_
;
1419 const Index bytesPerVec
= data
.rows() * static_cast<Index
>(sizeof(Scalar
));
1420 if(bytesPerVec
% substrLen_
!= 0)
1421 throw std::runtime_error("cannot build MultiIndexHashing; cannot divide byte count per vector by substring length without remainings");
1424 buckets_
.resize(bytesPerVec
/ substrLen_
);
1426 for(size_t i
= 0; i
< buckets_
.size(); ++i
)
1428 Index start
= static_cast<Index
>(i
) * substrLen_
;
1429 Index idx
= start
/ static_cast<Index
>(sizeof(Scalar
));
1430 Index offset
= start
% static_cast<Index
>(sizeof(Scalar
));
1431 std::map
<Scalar
, std::vector
<Index
>> &map
= buckets_
[i
];
1433 for(Index c
= 0; c
< data
.cols(); ++c
)
1435 Scalar code
= extractCode(data
.col(c
), idx
, offset
);
1436 if(map
.find(code
) == map
.end())
1437 map
[code
] = std::vector
<Index
>();
1438 map
[code
].push_back(c
);
1443 template<typename Derived
>
1444 void query(const Eigen::MatrixBase
<Derived
> &queryPoints
,
1447 Matrix
&distances
) const
1449 if(buckets_
.size() == 0)
1450 throw std::runtime_error("cannot query MultiIndexHashing; not built yet");
1451 if(queryPoints
.rows() != dimension())
1452 throw std::runtime_error("cannot query MultiIndexHashing; data and query points do not have same dimension");
1454 const Matrix
&data
= *data_
;
1456 indices
.setConstant(knn
, queryPoints
.cols(), -1);
1457 distances
.setConstant(knn
, queryPoints
.cols(), -1);
1459 Index
*indicesRaw
= indices
.data();
1460 Scalar
*distsRaw
= distances
.data();
1462 Scalar maxDistPart
= maxDist_
/ buckets_
.size();
1464 #pragma omp parallel for num_threads(threads_)
1465 for(Index c
= 0; c
< queryPoints
.cols(); ++c
)
1467 std::set
<Index
> candidates
;
1468 for(size_t i
= 0; i
< buckets_
.size(); ++i
)
1470 Index start
= static_cast<Index
>(i
) * substrLen_
;
1471 Index idx
= start
/ static_cast<Index
>(sizeof(Scalar
));
1472 Index offset
= start
% static_cast<Index
>(sizeof(Scalar
));
1473 const std::map
<Scalar
, std::vector
<Index
>> &map
= buckets_
[i
];
1475 Scalar code
= extractCode(queryPoints
.col(c
), idx
, offset
);
1476 for(const auto &x
: map
)
1478 Scalar dist
= distance_(x
.first
, code
);
1479 if(maxDistPart
<= 0 || dist
<= maxDistPart
)
1481 for(size_t j
= 0; j
< x
.second
.size(); ++j
)
1482 candidates
.insert(x
.second
[j
]);
1487 Scalar
*distPoint
= &distsRaw
[c
* knn
];
1488 Index
*idxPoint
= &indicesRaw
[c
* knn
];
1489 // create heap to find nearest neighbours
1490 QueryHeap
<Scalar
> dataHeap(idxPoint
, distPoint
, knn
);
1492 for(Index idx
: candidates
)
1494 Scalar dist
= distance_(data
.col(idx
), queryPoints
.col(c
));
1496 bool isInRange
= maxDist_
<= 0 || dist
<= maxDist_
;
1497 bool isImprovement
= !dataHeap
.full() ||
1498 dist
< dataHeap
.front();
1499 if(isInRange
&& isImprovement
)
1503 dataHeap
.push(idx
, dist
);
1512 /** Returns the amount of data points stored in the search index.
1513 * @return number of data points */
1516 return data_
== nullptr ? 0 : data_
->cols();
1519 /** Returns the dimension of the data points in the search index.
1520 * @return dimension of data points */
1521 Index
dimension() const
1523 return data_
== nullptr ? 0 : data_
->rows();
1529 dataCopy_
.resize(0, 0);
1537 /** Wrapper class of FLANN kdtrees for the use with Eigen3. */
1538 template<typename Scalar
,
1539 typename Distance
=flann::L2_Simple
<Scalar
>>
1543 typedef Eigen::Matrix
<Scalar
, Eigen::Dynamic
, Eigen::Dynamic
> Matrix
;
1544 typedef Eigen::Matrix
<Scalar
, Eigen::Dynamic
, 1> Vector
;
1545 typedef Eigen::Matrix
<int, Eigen::Dynamic
, Eigen::Dynamic
> Matrixi
;
1548 typedef flann::Index
<Distance
> FlannIndex
;
1551 Matrix
*dataPoints_
;
1554 flann::SearchParams searchParams_
;
1555 flann::IndexParams indexParams_
;
1560 : dataCopy_(), dataPoints_(nullptr), index_(nullptr),
1561 searchParams_(32, 0, false),
1562 indexParams_(flann::KDTreeSingleIndexParams(15)),
1567 KDTreeFlann(Matrix
&data
, const bool copy
= false)
1570 setData(data
, copy
);
1578 void setIndexParams(const flann::IndexParams
¶ms
)
1580 indexParams_
= params
;
1583 void setChecks(const int checks
)
1585 searchParams_
.checks
= checks
;
1588 void setSorted(const bool sorted
)
1590 searchParams_
.sorted
= sorted
;
1593 void setThreads(const int threads
)
1595 searchParams_
.cores
= threads
;
1598 void setEpsilon(const float eps
)
1600 searchParams_
.eps
= eps
;
1603 void setMaxDistance(const Scalar dist
)
1608 void setData(Matrix
&data
, const bool copy
= false)
1613 dataPoints_
= &dataCopy_
;
1617 dataPoints_
= &data
;
1625 if(dataPoints_
== nullptr)
1626 throw std::runtime_error("cannot build KDTree; data not set");
1627 if(dataPoints_
->size() == 0)
1628 throw std::runtime_error("cannot build KDTree; data is empty");
1630 if(index_
!= nullptr)
1633 flann::Matrix
<Scalar
> dataPts(
1634 dataPoints_
->data(),
1635 dataPoints_
->cols(),
1636 dataPoints_
->rows());
1638 index_
= new FlannIndex(dataPts
, indexParams_
);
1639 index_
->buildIndex();
1642 void query(Matrix
&queryPoints
,
1645 Matrix
&distances
) const
1647 if(index_
== nullptr)
1648 throw std::runtime_error("cannot query KDTree; not built yet");
1649 if(dataPoints_
->rows() != queryPoints
.rows())
1650 throw std::runtime_error("cannot query KDTree; KDTree has different dimension than query data");
1652 // resize result matrices
1653 distances
.resize(knn
, queryPoints
.cols());
1654 indices
.resize(knn
, queryPoints
.cols());
1656 // wrap matrices into flann matrices
1657 flann::Matrix
<Scalar
> queryPts(
1660 queryPoints
.rows());
1661 flann::Matrix
<int> indicesF(
1665 flann::Matrix
<Scalar
> distancesF(
1670 // if maximum distance was set then use radius search
1672 index_
->radiusSearch(queryPts
, indicesF
, distancesF
, maxDist_
, searchParams_
);
1674 index_
->knnSearch(queryPts
, indicesF
, distancesF
, knn
, searchParams_
);
1676 // make result matrices compatible to API
1677 #pragma omp parallel for num_threads(searchParams_.cores)
1678 for(Index i
= 0; i
< indices
.cols(); ++i
)
1681 for(Index j
= 0; j
< indices
.rows(); ++j
)
1683 if(indices(j
, i
) == -1)
1689 distances(j
, i
) = -1;
1697 return dataPoints_
== nullptr ? 0 : dataPoints_
->cols();
1700 Index
dimension() const
1702 return dataPoints_
== nullptr ? 0 : dataPoints_
->rows();
1707 if(index_
!= nullptr)
1714 FlannIndex
&flannIndex()
1720 typedef KDTreeFlann
<double> KDTreeFlannd
;
1721 typedef KDTreeFlann
<float> KDTreeFlannf
;