Fix typo
[nngd.git] / src / knncpp.h
CommitLineData
762721a5
BA
1/* knncpp.h
2 *
3 * Author: Fabian Meyer
4 * Created On: 22 Aug 2021
5 * License: MIT
6 */
7
8#ifndef KNNCPP_H_
9#define KNNCPP_H_
10
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <set>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

#include <Eigen/Geometry>

#ifdef KNNCPP_FLANN

#include <flann/flann.hpp>

#endif
21
22namespace knncpp
23{
24 /********************************************************
25 * Matrix Definitions
26 *******************************************************/
27
28 typedef typename Eigen::MatrixXd::Index Index;
29
30 typedef Eigen::Matrix<Index, Eigen::Dynamic, 1> Vectori;
31 typedef Eigen::Matrix<Index, 2, 1> Vector2i;
32 typedef Eigen::Matrix<Index, 3, 1> Vector3i;
33 typedef Eigen::Matrix<Index, 4, 1> Vector4i;
34 typedef Eigen::Matrix<Index, 5, 1> Vector5i;
35
36 typedef Eigen::Matrix<Index, Eigen::Dynamic, Eigen::Dynamic> Matrixi;
37 typedef Eigen::Matrix<Index, 2, 2> Matrix2i;
38 typedef Eigen::Matrix<Index, 3, 3> Matrix3i;
39 typedef Eigen::Matrix<Index, 4, 4> Matrix4i;
40 typedef Eigen::Matrix<Index, 5, 5> Matrix5i;
41
42 typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> Matrixf;
43 typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> Matrixd;
44
45 /********************************************************
46 * Distance Functors
47 *******************************************************/
48
49 /** Manhatten distance functor.
50 * This the same as the L1 minkowski distance but more efficient.
51 * @see EuclideanDistance, ChebyshevDistance, MinkowskiDistance */
52 template <typename Scalar>
53 struct ManhattenDistance
54 {
55 /** Compute the unrooted distance between two vectors.
56 * @param lhs vector on left hand side
57 * @param rhs vector on right hand side */
58 template<typename DerivedA, typename DerivedB>
59 Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs,
60 const Eigen::MatrixBase<DerivedB> &rhs) const
61 {
62 static_assert(
63 std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value,
64 "distance scalar and input matrix A must have same type");
65 static_assert(
66 std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value,
67 "distance scalar and input matrix B must have same type");
68
69 return (lhs - rhs).cwiseAbs().sum();
70 }
71
72 /** Compute the unrooted distance between two scalars.
73 * @param lhs scalar on left hand side
74 * @param rhs scalar on right hand side */
75 Scalar operator()(const Scalar lhs,
76 const Scalar rhs) const
77 {
78 return std::abs(lhs - rhs);
79 }
80
81 /** Compute the root of a unrooted distance value.
82 * @param value unrooted distance value */
83 Scalar operator()(const Scalar val) const
84 {
85 return val;
86 }
87 };
88
89 /** Euclidean distance functor.
90 * This the same as the L2 minkowski distance but more efficient.
91 * @see ManhattenDistance, ChebyshevDistance, MinkowskiDistance */
92 template <typename Scalar>
93 struct EuclideanDistance
94 {
95 /** Compute the unrooted distance between two vectors.
96 * @param lhs vector on left hand side
97 * @param rhs vector on right hand side */
98 template<typename DerivedA, typename DerivedB>
99 Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs,
100 const Eigen::MatrixBase<DerivedB> &rhs) const
101 {
102 static_assert(
103 std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value,
104 "distance scalar and input matrix A must have same type");
105 static_assert(
106 std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value,
107 "distance scalar and input matrix B must have same type");
108
109 return (lhs - rhs).cwiseAbs2().sum();
110 }
111
112 /** Compute the unrooted distance between two scalars.
113 * @param lhs scalar on left hand side
114 * @param rhs scalar on right hand side */
115 Scalar operator()(const Scalar lhs,
116 const Scalar rhs) const
117 {
118 Scalar diff = lhs - rhs;
119 return diff * diff;
120 }
121
122 /** Compute the root of a unrooted distance value.
123 * @param value unrooted distance value */
124 Scalar operator()(const Scalar val) const
125 {
126 return std::sqrt(val);
127 }
128 };
129
130 /** General minkowski distance functor.
131 * The infinite version is only available through the chebyshev distance.
132 * @see ManhattenDistance, EuclideanDistance, ChebyshevDistance */
133 template <typename Scalar, int P>
134 struct MinkowskiDistance
135 {
136 struct Pow
137 {
138 Scalar operator()(const Scalar val) const
139 {
140 Scalar result = 1;
141 for(int i = 0; i < P; ++i)
142 result *= val;
143 return result;
144 }
145 };
146
147 /** Compute the unrooted distance between two vectors.
148 * @param lhs vector on left hand side
149 * @param rhs vector on right hand side */
150 template<typename DerivedA, typename DerivedB>
151 Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs,
152 const Eigen::MatrixBase<DerivedB> &rhs) const
153 {
154 static_assert(
155 std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value,
156 "distance scalar and input matrix A must have same type");
157 static_assert(
158 std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value,
159 "distance scalar and input matrix B must have same type");
160
161 return (lhs - rhs).cwiseAbs().unaryExpr(MinkowskiDistance::Pow()).sum();
162 }
163
164 /** Compute the unrooted distance between two scalars.
165 * @param lhs scalar on left hand side
166 * @param rhs scalar on right hand side */
167 Scalar operator()(const Scalar lhs,
168 const Scalar rhs) const
169 {
170 return std::pow(std::abs(lhs - rhs), P);;
171 }
172
173 /** Compute the root of a unrooted distance value.
174 * @param value unrooted distance value */
175 Scalar operator()(const Scalar val) const
176 {
177 return std::pow(val, 1 / static_cast<Scalar>(P));
178 }
179 };
180
181 /** Chebyshev distance functor.
182 * This distance is the same as infinity minkowski distance.
183 * @see ManhattenDistance, EuclideanDistance, MinkowskiDistance */
184 template<typename Scalar>
185 struct ChebyshevDistance
186 {
187 /** Compute the unrooted distance between two vectors.
188 * @param lhs vector on left hand side
189 * @param rhs vector on right hand side */
190 template<typename DerivedA, typename DerivedB>
191 Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs,
192 const Eigen::MatrixBase<DerivedB> &rhs) const
193 {
194 static_assert(
195 std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value,
196 "distance scalar and input matrix A must have same type");
197 static_assert(
198 std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value,
199 "distance scalar and input matrix B must have same type");
200
201 return (lhs - rhs).cwiseAbs().maxCoeff();
202 }
203
204 /** Compute the unrooted distance between two scalars.
205 * @param lhs scalar on left hand side
206 * @param rhs scalar on right hand side */
207 Scalar operator()(const Scalar lhs,
208 const Scalar rhs) const
209 {
210 return std::abs(lhs - rhs);
211 }
212
213 /** Compute the root of a unrooted distance value.
214 * @param value unrooted distance value */
215 Scalar operator()(const Scalar val) const
216 {
217 return val;
218 }
219 };
220
221 /** Hamming distance functor.
222 * The distance vectors have to be of integral type and should hold the
223 * information vectors as bitmasks.
224 * Performs a XOR operation on the vectors and counts the number of set
225 * ones. */
226 template<typename Scalar>
227 struct HammingDistance
228 {
229 static_assert(std::is_integral<Scalar>::value,
230 "HammingDistance requires integral Scalar type");
231
232 struct XOR
233 {
234 Scalar operator()(const Scalar lhs, const Scalar rhs) const
235 {
236 return lhs ^ rhs;
237 }
238 };
239
240 struct BitCount
241 {
242 Scalar operator()(const Scalar lhs) const
243 {
244 Scalar copy = lhs;
245 Scalar result = 0;
246 while(copy != static_cast<Scalar>(0))
247 {
248 ++result;
249 copy &= (copy - 1);
250 }
251
252 return result;
253 }
254 };
255
256 /** Compute the unrooted distance between two vectors.
257 * @param lhs vector on left hand side
258 * @param rhs vector on right hand side */
259 template<typename DerivedA, typename DerivedB>
260 Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs,
261 const Eigen::MatrixBase<DerivedB> &rhs) const
262 {
263 static_assert(
264 std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value,
265 "distance scalar and input matrix A must have same type");
266 static_assert(
267 std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value,
268 "distance scalar and input matrix B must have same type");
269
270 return lhs.
271 binaryExpr(rhs, XOR()).
272 unaryExpr(BitCount()).
273 sum();
274 }
275
276 /** Compute the unrooted distance between two scalars.
277 * @param lhs scalar on left hand side
278 * @param rhs scalar on right hand side */
279 Scalar operator()(const Scalar lhs,
280 const Scalar rhs) const
281 {
282 BitCount cnt;
283 XOR xOr;
284 return cnt(xOr(lhs, rhs));
285 }
286
287 /** Compute the root of a unrooted distance value.
288 * @param value unrooted distance value */
289 Scalar operator()(const Scalar value) const
290 {
291 return value;
292 }
293 };
294
295 /** Efficient heap structure to query nearest neighbours. */
296 template<typename Scalar>
297 class QueryHeap
298 {
299 private:
300 Index *indices_ = nullptr;
301 Scalar *distances_ = nullptr;
302 size_t maxSize_ = 0;
303 size_t size_ = 0;
304 public:
305 /** Creates a query heap with the given index and distance memory regions. */
306 QueryHeap(Index *indices, Scalar *distances, const size_t maxSize)
307 : indices_(indices), distances_(distances), maxSize_(maxSize)
308 { }
309
310 /** Pushes a new query data set into the heap with the given
311 * index and distance.
312 * The index identifies the point for which the given distance
313 * was computed.
314 * @param idx index / ID of the query point
315 * @param dist distance that was computed for the query point*/
316 void push(const Index idx, const Scalar dist)
317 {
318 assert(!full());
319
320 // add new value at the end
321 indices_[size_] = idx;
322 distances_[size_] = dist;
323 ++size_;
324
325 // upheap
326 size_t k = size_ - 1;
327 size_t tmp = (k - 1) / 2;
328 while(k > 0 && distances_[tmp] < dist)
329 {
330 distances_[k] = distances_[tmp];
331 indices_[k] = indices_[tmp];
332 k = tmp;
333 tmp = (k - 1) / 2;
334 }
335 distances_[k] = dist;
336 indices_[k] = idx;
337 }
338
339 /** Removes the element at the front of the heap and restores
340 * the heap order. */
341 void pop()
342 {
343 assert(!empty());
344
345 // replace first element with last
346 --size_;
347 distances_[0] = distances_[size_];
348 indices_[0] = indices_[size_];
349
350 // downheap
351 size_t k = 0;
352 size_t j;
353 Scalar dist = distances_[0];
354 Index idx = indices_[0];
355 while(2 * k + 1 < size_)
356 {
357 j = 2 * k + 1;
358 if(j + 1 < size_ && distances_[j+1] > distances_[j])
359 ++j;
360 // j references now greatest child
361 if(dist >= distances_[j])
362 break;
363 distances_[k] = distances_[j];
364 indices_[k] = indices_[j];
365 k = j;
366 }
367 distances_[k] = dist;
368 indices_[k] = idx;
369 }
370
371 /** Returns the distance of the element in front of the heap. */
372 Scalar front() const
373 {
374 assert(!empty());
375 return distances_[0];
376 }
377
378 /** Determines if this query heap is full.
379 * The heap is considered full if its number of elements
380 * has reached its max size.
381 * @return true if the heap is full, else false */
382 bool full() const
383 {
384 return size_ >= maxSize_;
385 }
386
387 /** Determines if this query heap is empty.
388 * @return true if the heap contains no elements, else false */
389 bool empty() const
390 {
391 return size_ == 0;
392 }
393
394 /** Returns the number of elements within the query heap.
395 * @return number of elements in the heap */
396 size_t size() const
397 {
398 return size_;
399 }
400
401 /** Clears the query heap. */
402 void clear()
403 {
404 size_ = 0;
405 }
406
407 /** Sorts the elements within the heap according to
408 * their distance. */
409 void sort()
410 {
411 size_t cnt = size_;
412 for(size_t i = 0; i < cnt; ++i)
413 {
414 Index idx = indices_[0];
415 Scalar dist = distances_[0];
416 pop();
417 indices_[cnt - i - 1] = idx;
418 distances_[cnt - i - 1] = dist;
419 }
420 }
421 };
422
423 /** Class for performing brute force knn search. */
424 template<typename Scalar,
425 typename Distance=EuclideanDistance<Scalar>>
426 class BruteForce
427 {
428 public:
429 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
430 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector;
431 typedef knncpp::Matrixi Matrixi;
432 private:
433 Distance distance_ = Distance();
434 Matrix dataCopy_ = Matrix();
435 const Matrix *data_ = nullptr;
436
437 bool sorted_ = true;
438 bool takeRoot_ = true;
439 Index threads_ = 1;
440 Scalar maxDist_ = 0;
441
442 public:
443
444 BruteForce() = default;
445
446 /** Constructs a brute force instance with the given data.
447 * @param data NxM matrix, M points of dimension N
448 * @param copy if true copies the data, otherwise assumes static data */
449 BruteForce(const Matrix &data, const bool copy = false)
450 : BruteForce()
451 {
452 setData(data, copy);
453 }
454
455 /** Set if the points returned by the queries should be sorted
456 * according to their distance to the query points.
457 * @param sorted sort query results */
458 void setSorted(const bool sorted)
459 {
460 sorted_ = sorted;
461 }
462
463 /** Set if the distances after the query should be rooted or not.
464 * Taking the root of the distances increases query time, but the
465 * function will return true distances instead of their powered
466 * versions.
467 * @param takeRoot set true if root should be taken else false */
468 void setTakeRoot(const bool takeRoot)
469 {
470 takeRoot_ = takeRoot;
471 }
472
473 /** Set the amount of threads that should be used for querying.
474 * OpenMP has to be enabled for this to work.
475 * @param threads amount of threads, 0 for optimal choice */
476 void setThreads(const unsigned int threads)
477 {
478 threads_ = threads;
479 }
480
481 /** Set the maximum distance for querying the tree.
482 * The search will be pruned if the maximum distance is set to any
483 * positive number.
484 * @param maxDist maximum distance, <= 0 for no limit */
485 void setMaxDistance(const Scalar maxDist)
486 {
487 maxDist_ = maxDist;
488 }
489
490 /** Set the data points used for this tree.
491 * This does not build the tree.
492 * @param data NxM matrix, M points of dimension N
493 * @param copy if true data is copied, assumes static data otherwise */
494 void setData(const Matrix &data, const bool copy = false)
495 {
496 if(copy)
497 {
498 dataCopy_ = data;
499 data_ = &dataCopy_;
500 }
501 else
502 {
503 data_ = &data;
504 }
505 }
506
507 void setDistance(const Distance &distance)
508 {
509 distance_ = distance;
510 }
511
512 void build()
513 { }
514
515 template<typename Derived>
516 void query(const Eigen::MatrixBase<Derived> &queryPoints,
517 const size_t knn,
518 Matrixi &indices,
519 Matrix &distances) const
520 {
521 if(data_ == nullptr)
522 throw std::runtime_error("cannot query BruteForce: data not set");
523 if(data_->size() == 0)
524 throw std::runtime_error("cannot query BruteForce: data is empty");
525 if(queryPoints.rows() != dimension())
526 throw std::runtime_error("cannot query BruteForce: data and query descriptors do not have same dimension");
527
528 const Matrix &dataPoints = *data_;
529
530 indices.setConstant(knn, queryPoints.cols(), -1);
531 distances.setConstant(knn, queryPoints.cols(), -1);
532
533 #pragma omp parallel for num_threads(threads_)
534 for(Index i = 0; i < queryPoints.cols(); ++i)
535 {
536 Index *idxPoint = &indices.data()[i * knn];
537 Scalar *distPoint = &distances.data()[i * knn];
538
539 QueryHeap<Scalar> heap(idxPoint, distPoint, knn);
540
541 for(Index j = 0; j < dataPoints.cols(); ++j)
542 {
543 Scalar dist = distance_(queryPoints.col(i), dataPoints.col(j));
544
545 // check if point is in range if max distance was set
546 bool isInRange = maxDist_ <= 0 || dist <= maxDist_;
547 // check if this node was an improvement if heap is already full
548 bool isImprovement = !heap.full() ||
549 dist < heap.front();
550 if(isInRange && isImprovement)
551 {
552 if(heap.full())
553 heap.pop();
554 heap.push(j, dist);
555 }
556 }
557
558 if(sorted_)
559 heap.sort();
560
561 if(takeRoot_)
562 {
563 for(size_t j = 0; j < knn; ++j)
564 {
565 if(idxPoint[j] < 0)
566 break;
567 distPoint[j] = distance_(distPoint[j]);
568 }
569 }
570 }
571 }
572
573 /** Returns the amount of data points stored in the search index.
574 * @return number of data points */
575 Index size() const
576 {
577 return data_ == nullptr ? 0 : data_->cols();
578 }
579
580 /** Returns the dimension of the data points in the search index.
581 * @return dimension of data points */
582 Index dimension() const
583 {
584 return data_ == nullptr ? 0 : data_->rows();
585 }
586 };
587
588 // template<typename Scalar>
589 // struct MeanMidpointRule
590 // {
591 // typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
592 // typedef knncpp::Matrixi Matrixi;
593
594 // void operator(const Matrix &data, const Matrixi &indices, Index split)
595 // };
596
597 /** Class for performing k nearest neighbour searches with minkowski distances.
598 * This kdtree only works reliably with the minkowski distance and its
599 * special cases like manhatten or euclidean distance.
600 * @see ManhattenDistance, EuclideanDistance, ChebyshevDistance, MinkowskiDistance*/
601 template<typename _Scalar, int _Dimension, typename _Distance>
602 class KDTreeMinkowski
603 {
604 public:
605 typedef _Scalar Scalar;
606 typedef _Distance Distance;
607 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
608 typedef Eigen::Matrix<Scalar, _Dimension, Eigen::Dynamic> DataMatrix;
609 typedef Eigen::Matrix<Scalar, _Dimension, 1> DataVector;
610 typedef knncpp::Matrixi Matrixi;
611 private:
612 typedef Eigen::Matrix<Scalar, 2, 1> Bounds;
613 typedef Eigen::Matrix<Scalar, 2, _Dimension> BoundingBox;
614
615 /** Struct representing a node in the KDTree.
616 * It can be either a inner node or a leaf node. */
617 struct Node
618 {
619 /** Indices of data points in this leaf node. */
620 Index startIdx = 0;
621 Index length = 0;
622
623 /** Left child of this inner node. */
624 Index left = -1;
625 /** Right child of this inner node. */
626 Index right = -1;
627 /** Axis of the axis aligned splitting hyper plane. */
628 Index splitaxis = -1;
629 /** Translation of the axis aligned splitting hyper plane. */
630 Scalar splitpoint = 0;
631 /** Lower end of the splitpoint range */
632 Scalar splitlower = 0;
633 /** Upper end of the splitpoint range */
634 Scalar splitupper = 0;
635
636
637 Node() = default;
638
639 /** Constructor for leaf nodes */
640 Node(const Index startIdx, const Index length)
641 : startIdx(startIdx), length(length)
642 { }
643
644 /** Constructor for inner nodes */
645 Node(const Index splitaxis, const Scalar splitpoint,
646 const Index left, const Index right)
647 : left(left), right(right),
648 splitaxis(splitaxis), splitpoint(splitpoint)
649 { }
650
651 bool isLeaf() const
652 {
653 return !hasLeft() && !hasRight();
654 }
655
656 bool isInner() const
657 {
658 return hasLeft() && hasRight();
659 }
660
661 bool hasLeft() const
662 {
663 return left >= 0;
664 }
665
666 bool hasRight() const
667 {
668 return right >= 0;
669 }
670 };
671
672 DataMatrix dataCopy_ = DataMatrix();
673 const DataMatrix *data_ = nullptr;
674 std::vector<Index> indices_ = std::vector<Index>();
675 std::vector<Node> nodes_ = std::vector<Node>();
676
677 Index bucketSize_ = 16;
678 bool sorted_ = true;
679 bool compact_ = false;
680 bool balanced_ = false;
681 bool takeRoot_ = true;
682 Index threads_ = 0;
683 Scalar maxDist_ = 0;
684
685 Distance distance_ = Distance();
686
687 BoundingBox bbox_ = BoundingBox();
688
689 Index buildLeafNode(const Index startIdx,
690 const Index length,
691 BoundingBox &bbox)
692 {
693 nodes_.push_back(Node(startIdx, length));
694 calculateBoundingBox(startIdx, length, bbox);
695 return static_cast<Index>(nodes_.size() - 1);
696 }
697
698 /** Finds the minimum and maximum values of each dimension (row) in the
699 * data matrix. Only respects the columns specified by the index
700 * vector.
701 * @param startIdx starting index within indices data structure to search for bounding box
702 * @param length length of the block of indices*/
703 void calculateBoundingBox(const Index startIdx,
704 const Index length,
705 BoundingBox &bbox) const
706 {
707 assert(length > 0);
708 assert(startIdx >= 0);
709 assert(static_cast<size_t>(startIdx + length) <= indices_.size());
710 assert(data_->rows() == bbox.cols());
711
712 const DataMatrix &data = *data_;
713
714 // initialize bounds of the bounding box
715 Index first = indices_[startIdx];
716 for(Index i = 0; i < bbox.cols(); ++i)
717 {
718 bbox(0, i) = data(i, first);
719 bbox(1, i) = data(i, first);
720 }
721
722 // search for min / max values in data
723 for(Index i = 1; i < length; ++i)
724 {
725 // retrieve data index
726 Index col = indices_[startIdx + i];
727 assert(col >= 0 && col < data.cols());
728
729 // check min and max for each dimension individually
730 for(Index j = 0; j < data.rows(); ++j)
731 {
732 bbox(0, j) = std::min(bbox(0, j), data(j, col));
733 bbox(1, j) = std::max(bbox(1, j), data(j, col));
734 }
735 }
736 }
737
738 /** Calculates the bounds (min / max values) for the given dimension and block of data. */
739 void calculateBounds(const Index startIdx,
740 const Index length,
741 const Index dim,
742 Bounds &bounds) const
743 {
744 assert(length > 0);
745 assert(startIdx >= 0);
746 assert(static_cast<size_t>(startIdx + length) <= indices_.size());
747
748 const DataMatrix &data = *data_;
749
750 bounds(0) = data(dim, indices_[startIdx]);
751 bounds(1) = data(dim, indices_[startIdx]);
752
753 for(Index i = 1; i < length; ++i)
754 {
755 Index col = indices_[startIdx + i];
756 assert(col >= 0 && col < data.cols());
757
758 bounds(0) = std::min(bounds(0), data(dim, col));
759 bounds(1) = std::max(bounds(1), data(dim, col));
760 }
761 }
762
763 void calculateSplittingMidpoint(const Index startIdx,
764 const Index length,
765 const BoundingBox &bbox,
766 Index &splitaxis,
767 Scalar &splitpoint,
768 Index &splitoffset)
769 {
770 const DataMatrix &data = *data_;
771
772 // search for axis with longest distance
773 splitaxis = 0;
774 Scalar splitsize = static_cast<Scalar>(0);
775 for(Index i = 0; i < data.rows(); ++i)
776 {
777 Scalar diff = bbox(1, i) - bbox(0, i);
778 if(diff > splitsize)
779 {
780 splitaxis = i;
781 splitsize = diff;
782 }
783 }
784
785 // calculate the bounds in this axis and update our data
786 // accordingly
787 Bounds bounds;
788 calculateBounds(startIdx, length, splitaxis, bounds);
789 splitsize = bounds(1) - bounds(0);
790
791 const Index origSplitaxis = splitaxis;
792 for(Index i = 0; i < data.rows(); ++i)
793 {
794 // skip the dimension of the previously found splitaxis
795 if(i == origSplitaxis)
796 continue;
797 Scalar diff = bbox(1, i) - bbox(0, i);
798 // check if the split for this dimension would be potentially larger
799 if(diff > splitsize)
800 {
801 Bounds newBounds;
802 // update the bounds to their actual current value
803 calculateBounds(startIdx, length, splitaxis, newBounds);
804 diff = newBounds(1) - newBounds(0);
805 if(diff > splitsize)
806 {
807 splitaxis = i;
808 splitsize = diff;
809 bounds = newBounds;
810 }
811 }
812 }
813
814 // use the sliding midpoint rule
815 splitpoint = (bounds(0) + bounds(1)) / static_cast<Scalar>(2);
816
817 Index leftIdx = startIdx;
818 Index rightIdx = startIdx + length - 1;
819
820 // first loop checks left < splitpoint and right >= splitpoint
821 while(leftIdx <= rightIdx)
822 {
823 // increment left as long as left has not reached right and
824 // the value of the left element is less than the splitpoint
825 while(leftIdx <= rightIdx && data(splitaxis, indices_[leftIdx]) < splitpoint)
826 ++leftIdx;
827
828 // decrement right as long as left has not reached right and
829 // the value of the right element is greater than the splitpoint
830 while(leftIdx <= rightIdx && data(splitaxis, indices_[rightIdx]) >= splitpoint)
831 --rightIdx;
832
833 if(leftIdx <= rightIdx)
834 {
835 std::swap(indices_[leftIdx], indices_[rightIdx]);
836 ++leftIdx;
837 --rightIdx;
838 }
839 }
840
841 // remember this offset from starting index
842 const Index offset1 = leftIdx - startIdx;
843
844 rightIdx = startIdx + length - 1;
845 // second loop checks left <= splitpoint and right > splitpoint
846 while(leftIdx <= rightIdx)
847 {
848 // increment left as long as left has not reached right and
849 // the value of the left element is less than the splitpoint
850 while(leftIdx <= rightIdx && data(splitaxis, indices_[leftIdx]) <= splitpoint)
851 ++leftIdx;
852
853 // decrement right as long as left has not reached right and
854 // the value of the right element is greater than the splitpoint
855 while(leftIdx <= rightIdx && data(splitaxis, indices_[rightIdx]) > splitpoint)
856 --rightIdx;
857
858 if(leftIdx <= rightIdx)
859 {
860 std::swap(indices_[leftIdx], indices_[rightIdx]);
861 ++leftIdx;
862 --rightIdx;
863 }
864 }
865
866 // remember this offset from starting index
867 const Index offset2 = leftIdx - startIdx;
868
869 const Index halfLength = length / static_cast<Index>(2);
870
871 // find a separation of points such that is best balanced
872 // offset1 denotes separation where equal points are all on the right
873 // offset2 denots separation where equal points are all on the left
874 if (offset1 > halfLength)
875 splitoffset = offset1;
876 else if (offset2 < halfLength)
877 splitoffset = offset2;
878 // when we get here offset1 < halflength and offset2 > halflength
879 // so simply split the equal elements in the middle
880 else
881 splitoffset = halfLength;
882 }
883
884 Index buildInnerNode(const Index startIdx,
885 const Index length,
886 BoundingBox &bbox)
887 {
888 assert(length > 0);
889 assert(startIdx >= 0);
890 assert(static_cast<size_t>(startIdx + length) <= indices_.size());
891 assert(data_->rows() == bbox.cols());
892
893 // create node
894 const Index nodeIdx = nodes_.size();
895 nodes_.push_back(Node());
896
897 Index splitaxis;
898 Index splitoffset;
899 Scalar splitpoint;
900 calculateSplittingMidpoint(startIdx, length, bbox, splitaxis, splitpoint, splitoffset);
901
902 nodes_[nodeIdx].splitaxis = splitaxis;
903 nodes_[nodeIdx].splitpoint = splitpoint;
904
905 const Index leftStart = startIdx;
906 const Index leftLength = splitoffset;
907 const Index rightStart = startIdx + splitoffset;
908 const Index rightLength = length - splitoffset;
909
910 BoundingBox bboxLeft = bbox;
911 BoundingBox bboxRight = bbox;
912
913 // do left build
914 bboxLeft(1, splitaxis) = splitpoint;
915 Index left = buildR(leftStart, leftLength, bboxLeft);
916 nodes_[nodeIdx].left = left;
917
918 // do right build
919 bboxRight(0, splitaxis) = splitpoint;
920 Index right = buildR(rightStart, rightLength, bboxRight);
921 nodes_[nodeIdx].right = right;
922
923 // extract the range of the splitpoint
924 nodes_[nodeIdx].splitlower = bboxLeft(1, splitaxis);
925 nodes_[nodeIdx].splitupper = bboxRight(0, splitaxis);
926
927 // update the bounding box to the values of the new bounding boxes
928 for(Index i = 0; i < bbox.cols(); ++i)
929 {
930 bbox(0, i) = std::min(bboxLeft(0, i), bboxRight(0, i));
931 bbox(1, i) = std::max(bboxLeft(1, i), bboxRight(1, i));
932 }
933
934 return nodeIdx;
935 }
936
937 Index buildR(const Index startIdx,
938 const Index length,
939 BoundingBox &bbox)
940 {
941 // check for base case
942 if(length <= bucketSize_)
943 return buildLeafNode(startIdx, length, bbox);
944 else
945 return buildInnerNode(startIdx, length, bbox);
946 }
947
948 bool isDistanceInRange(const Scalar dist) const
949 {
950 return maxDist_ <= 0 || dist <= maxDist_;
951 }
952
953 bool isDistanceImprovement(const Scalar dist, const QueryHeap<Scalar> &dataHeap) const
954 {
955 return !dataHeap.full() || dist < dataHeap.front();
956 }
957
958 template<typename Derived>
959 void queryLeafNode(const Node &node,
960 const Eigen::MatrixBase<Derived> &queryPoint,
961 QueryHeap<Scalar> &dataHeap) const
962 {
963 assert(node.isLeaf());
964
965 const DataMatrix &data = *data_;
966
967 // go through all points in this leaf node and do brute force search
968 for(Index i = 0; i < node.length; ++i)
969 {
970 const Index idx = node.startIdx + i;
971 assert(idx >= 0 && idx < static_cast<Index>(indices_.size()));
972
973 // retrieve index of the current data point
974 const Index dataIdx = indices_[idx];
975 const Scalar dist = distance_(queryPoint, data.col(dataIdx));
976
977 // check if point is within max distance and if the value would be
978 // an improvement
979 if(isDistanceInRange(dist) && isDistanceImprovement(dist, dataHeap))
980 {
981 if(dataHeap.full())
982 dataHeap.pop();
983 dataHeap.push(dataIdx, dist);
984 }
985 }
986 }
987
988 template<typename Derived>
989 void queryInnerNode(const Node &node,
990 const Eigen::MatrixBase<Derived> &queryPoint,
991 QueryHeap<Scalar> &dataHeap,
992 DataVector &splitdists,
993 const Scalar mindist) const
994 {
995 assert(node.isInner());
996
997 const Index splitaxis = node.splitaxis;
998 const Scalar splitval = queryPoint(splitaxis, 0);
999 Scalar splitdist;
1000 Index firstNode;
1001 Index secondNode;
1002 // check if right or left child should be visited
1003 const bool visitLeft = (splitval - node.splitlower + splitval - node.splitupper) < 0;
1004 if(visitLeft)
1005 {
1006 firstNode = node.left;
1007 secondNode = node.right;
1008 splitdist = distance_(splitval, node.splitupper);
1009 }
1010 else
1011 {
1012 firstNode = node.right;
1013 secondNode = node.left;
1014 splitdist = distance_(splitval, node.splitlower);
1015 }
1016
1017 queryR(nodes_[firstNode], queryPoint, dataHeap, splitdists, mindist);
1018
1019 const Scalar mindistNew = mindist + splitdist - splitdists(splitaxis);
1020
1021 // check if node is in range if max distance was set
1022 // check if this node was an improvement if heap is already full
1023 if(isDistanceInRange(mindistNew) && isDistanceImprovement(mindistNew, dataHeap))
1024 {
1025 const Scalar splitdistOld = splitdists(splitaxis);
1026 splitdists(splitaxis) = splitdist;
1027 queryR(nodes_[secondNode], queryPoint, dataHeap, splitdists, mindistNew);
1028 splitdists(splitaxis) = splitdistOld;
1029 }
1030 }
1031
1032 template<typename Derived>
1033 void queryR(const Node &node,
1034 const Eigen::MatrixBase<Derived> &queryPoint,
1035 QueryHeap<Scalar> &dataHeap,
1036 DataVector &splitdists,
1037 const Scalar mindist) const
1038 {
1039 if(node.isLeaf())
1040 queryLeafNode(node, queryPoint, dataHeap);
1041 else
1042 queryInnerNode(node, queryPoint, dataHeap, splitdists, mindist);
1043 }
1044
1045 /** Recursively computes the depth for the given node. */
1046 Index depthR(const Node &node) const
1047 {
1048 if(node.isLeaf())
1049 return 1;
1050 else
1051 {
1052 Index left = depthR(nodes_[node.left]);
1053 Index right = depthR(nodes_[node.right]);
1054 return std::max(left, right) + 1;
1055 }
1056 }
1057
1058 public:
1059
1060 /** Constructs an empty KDTree. */
1061 KDTreeMinkowski()
1062 { }
1063
1064 /** Constructs KDTree with the given data. This does not build the
1065 * the index of the tree.
1066 * @param data NxM matrix, M points of dimension N
1067 * @param copy if true copies the data, otherwise assumes static data */
1068 KDTreeMinkowski(const DataMatrix &data, const bool copy=false)
1069 {
1070 setData(data, copy);
1071 }
1072
1073 /** Set the maximum amount of data points per leaf in the tree (aka
1074 * bucket size).
1075 * @param bucketSize amount of points per leaf. */
1076 void setBucketSize(const Index bucketSize)
1077 {
1078 bucketSize_ = bucketSize;
1079 }
1080
1081 /** Set if the points returned by the queries should be sorted
1082 * according to their distance to the query points.
1083 * @param sorted sort query results */
1084 void setSorted(const bool sorted)
1085 {
1086 sorted_ = sorted;
1087 }
1088
1089 /** Set if the tree should be built as balanced as possible.
1090 * This increases build time, but decreases search time.
1091 * @param balanced set true to build a balanced tree */
1092 void setBalanced(const bool balanced)
1093 {
1094 balanced_ = balanced;
1095 }
1096
1097 /** Set if the distances after the query should be rooted or not.
1098 * Taking the root of the distances increases query time, but the
1099 * function will return true distances instead of their powered
1100 * versions.
1101 * @param takeRoot set true if root should be taken else false */
1102 void setTakeRoot(const bool takeRoot)
1103 {
1104 takeRoot_ = takeRoot;
1105 }
1106
1107 /** Set if the tree should be built with compact leaf nodes.
1108 * This increases build time, but makes leaf nodes denser (more)
1109 * points. Thus less visits are necessary.
1110 * @param compact set true ti build a tree with compact leafs */
1111 void setCompact(const bool compact)
1112 {
1113 compact_ = compact;
1114 }
1115
1116 /** Set the amount of threads that should be used for building and
1117 * querying the tree.
1118 * OpenMP has to be enabled for this to work.
1119 * @param threads amount of threads, 0 for optimal choice */
1120 void setThreads(const unsigned int threads)
1121 {
1122 threads_ = threads;
1123 }
1124
1125 /** Set the maximum distance for querying the tree.
1126 * The search will be pruned if the maximum distance is set to any
1127 * positive number.
1128 * @param maxDist maximum distance, <= 0 for no limit */
1129 void setMaxDistance(const Scalar maxDist)
1130 {
1131 maxDist_ = maxDist;
1132 }
1133
1134 /** Set the data points used for this tree.
1135 * This does not build the tree.
1136 * @param data NxM matrix, M points of dimension N
1137 * @param copy if true data is copied, assumes static data otherwise */
1138 void setData(const DataMatrix &data, const bool copy = false)
1139 {
1140 clear();
1141 if(copy)
1142 {
1143 dataCopy_ = data;
1144 data_ = &dataCopy_;
1145 }
1146 else
1147 {
1148 data_ = &data;
1149 }
1150 }
1151
1152 void setDistance(const Distance &distance)
1153 {
1154 distance_ = distance;
1155 }
1156
1157 /** Builds the search index of the tree.
1158 * Data has to be set and must be non-empty. */
1159 void build()
1160 {
1161 if(data_ == nullptr)
1162 throw std::runtime_error("cannot build KDTree; data not set");
1163
1164 if(data_->size() == 0)
1165 throw std::runtime_error("cannot build KDTree; data is empty");
1166
1167 clear();
1168 nodes_.reserve((data_->cols() / bucketSize_) + 1);
1169
1170 // initialize indices in simple sequence
1171 indices_.resize(data_->cols());
1172 for(size_t i = 0; i < indices_.size(); ++i)
1173 indices_[i] = i;
1174
1175 bbox_.resize(2, data_->rows());
1176 Index startIdx = 0;
1177 Index length = data_->cols();
1178
1179 calculateBoundingBox(startIdx, length, bbox_);
1180
1181 buildR(startIdx, length, bbox_);
1182 }
1183
1184 /** Queries the tree for the nearest neighbours of the given query
1185 * points.
1186 *
1187 * The tree has to be built before it can be queried.
1188 *
1189 * The query points have to have the same dimension as the data points
1190 * of the tree.
1191 *
1192 * The result matrices will be resized appropriatley.
1193 * Indices and distances will be set to -1 if less than knn neighbours
1194 * were found.
1195 *
1196 * @param queryPoints NxM matrix, M points of dimension N
1197 * @param knn amount of neighbours to be found
1198 * @param indices KNNxM matrix, indices of neighbours in the data set
1199 * @param distances KNNxM matrix, distance between query points and
1200 * neighbours */
1201 template<typename Derived>
1202 void query(const Eigen::MatrixBase<Derived> &queryPoints,
1203 const size_t knn,
1204 Matrixi &indices,
1205 Matrix &distances) const
1206 {
1207 if(nodes_.size() == 0)
1208 throw std::runtime_error("cannot query KDTree; not built yet");
1209
1210 if(queryPoints.rows() != dimension())
1211 throw std::runtime_error("cannot query KDTree; data and query points do not have same dimension");
1212
1213 distances.setConstant(knn, queryPoints.cols(), -1);
1214 indices.setConstant(knn, queryPoints.cols(), -1);
1215
1216 Index *indicesRaw = indices.data();
1217 Scalar *distsRaw = distances.data();
1218
1219 #pragma omp parallel for num_threads(threads_)
1220 for(Index i = 0; i < queryPoints.cols(); ++i)
1221 {
1222
1223 Scalar *distPoint = &distsRaw[i * knn];
1224 Index *idxPoint = &indicesRaw[i * knn];
1225
1226 // create heap to find nearest neighbours
1227 QueryHeap<Scalar> dataHeap(idxPoint, distPoint, knn);
1228
1229 Scalar mindist = static_cast<Scalar>(0);
1230 DataVector splitdists(queryPoints.rows());
1231
1232 for(Index j = 0; j < splitdists.rows(); ++j)
1233 {
1234 const Scalar value = queryPoints(j, i);
1235 const Scalar lower = bbox_(0, j);
1236 const Scalar upper = bbox_(1, j);
1237 if(value < lower)
1238 {
1239 splitdists(j) = distance_(value, lower);
1240 }
1241 else if(value > upper)
1242 {
1243 splitdists(j) = distance_(value, upper);
1244 }
1245 else
1246 {
1247 splitdists(j) = static_cast<Scalar>(0);
1248 }
1249
1250 mindist += splitdists(j);
1251 }
1252
1253 queryR(nodes_[0], queryPoints.col(i), dataHeap, splitdists, mindist);
1254
1255 if(sorted_)
1256 dataHeap.sort();
1257
1258 if(takeRoot_)
1259 {
1260 for(size_t j = 0; j < knn; ++j)
1261 {
1262 if(distPoint[j] < 0)
1263 break;
1264 distPoint[j] = distance_(distPoint[j]);
1265 }
1266 }
1267 }
1268 }
1269
1270 /** Clears the tree. */
1271 void clear()
1272 {
1273 nodes_.clear();
1274 }
1275
1276 /** Returns the amount of data points stored in the search index.
1277 * @return number of data points */
1278 Index size() const
1279 {
1280 return data_ == nullptr ? 0 : data_->cols();
1281 }
1282
1283 /** Returns the dimension of the data points in the search index.
1284 * @return dimension of data points */
1285 Index dimension() const
1286 {
1287 return data_ == nullptr ? 0 : data_->rows();
1288 }
1289
1290 /** Returns the maxximum depth of the tree.
1291 * @return maximum depth of the tree */
1292 Index depth() const
1293 {
1294 return nodes_.size() == 0 ? 0 : depthR(nodes_.front());
1295 }
1296 };
1297
1298 template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski2 = KDTreeMinkowski<_Scalar, 2, _Distance>;
1299 template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski3 = KDTreeMinkowski<_Scalar, 3, _Distance>;
1300 template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski4 = KDTreeMinkowski<_Scalar, 4, _Distance>;
1301 template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski5 = KDTreeMinkowski<_Scalar, 5, _Distance>;
1302 template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowskiX = KDTreeMinkowski<_Scalar, Eigen::Dynamic, _Distance>;
1303
1304 /** Class for performing KNN search in hamming space by multi-index hashing. */
1305 template<typename Scalar>
1306 class MultiIndexHashing
1307 {
1308 public:
1309 static_assert(std::is_integral<Scalar>::value, "MultiIndexHashing Scalar has to be integral");
1310
1311 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
1312 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector;
1313 typedef knncpp::Matrixi Matrixi;
1314
1315 private:
1316 HammingDistance<Scalar> distance_;
1317
1318 Matrix dataCopy_;
1319 const Matrix *data_;
1320
1321 bool sorted_;
1322 Scalar maxDist_;
1323 Index substrLen_;
1324 Index threads_;
1325 std::vector<std::map<Scalar, std::vector<Index>>> buckets_;
1326
1327 template<typename Derived>
1328 Scalar extractCode(const Eigen::MatrixBase<Derived> &data,
1329 const Index idx,
1330 const Index offset) const
1331 {
1332 Index leftShift = std::max<Index>(0, static_cast<Index>(sizeof(Scalar)) - offset - substrLen_);
1333 Index rightShift = leftShift + offset;
1334
1335 Scalar code = (data(idx, 0) << (leftShift * 8)) >> (rightShift * 8);
1336
1337 if(static_cast<Index>(sizeof(Scalar)) - offset < substrLen_ && idx + 1 < data.rows())
1338 {
1339 Index shift = 2 * static_cast<Index>(sizeof(Scalar)) - substrLen_ - offset;
1340 code |= data(idx+1, 0) << (shift * 8);
1341 }
1342
1343 return code;
1344 }
1345 public:
1346 MultiIndexHashing()
1347 : distance_(), dataCopy_(), data_(nullptr), sorted_(true),
1348 maxDist_(0), substrLen_(1), threads_(1)
1349 { }
1350
1351 /** Constructs an index with the given data.
1352 * This does not build the the index.
1353 * @param data NxM matrix, M points of dimension N
1354 * @param copy if true copies the data, otherwise assumes static data */
1355 MultiIndexHashing(const Matrix &data, const bool copy=false)
1356 : MultiIndexHashing()
1357 {
1358 setData(data, copy);
1359 }
1360
1361 /** Set the maximum distance for querying the index.
1362 * Note that if no maximum distance is used, this algorithm performs
1363 * basically a brute force search.
1364 * @param maxDist maximum distance, <= 0 for no limit */
1365 void setMaxDistance(const Scalar maxDist)
1366 {
1367 maxDist_ = maxDist;
1368 }
1369
1370 /** Set if the points returned by the queries should be sorted
1371 * according to their distance to the query points.
1372 * @param sorted sort query results */
1373 void setSorted(const bool sorted)
1374 {
1375 sorted_ = sorted;
1376 }
1377
1378 /** Set the amount of threads that should be used for building and
1379 * querying the tree.
1380 * OpenMP has to be enabled for this to work.
1381 * @param threads amount of threads, 0 for optimal choice */
1382 void setThreads(const unsigned int threads)
1383 {
1384 threads_ = threads;
1385 }
1386
1387 /** Set the length of substrings (in bytes) used for multi index hashing.
1388 * @param len lentth of bucket substrings in bytes*/
1389 void setSubstringLength(const Index len)
1390 {
1391 substrLen_ = len;
1392 }
1393
1394 /** Set the data points used for the KNN search.
1395 * @param data NxM matrix, M points of dimension N
1396 * @param copy if true data is copied, assumes static data otherwise */
1397 void setData(const Matrix &data, const bool copy = false)
1398 {
1399 clear();
1400 if(copy)
1401 {
1402 dataCopy_ = data;
1403 data_ = &dataCopy_;
1404 }
1405 else
1406 {
1407 data_ = &data;
1408 }
1409 }
1410
1411 void build()
1412 {
1413 if(data_ == nullptr)
1414 throw std::runtime_error("cannot build MultiIndexHashing; data not set");
1415 if(data_->size() == 0)
1416 throw std::runtime_error("cannot build MultiIndexHashing; data is empty");
1417
1418 const Matrix &data = *data_;
1419 const Index bytesPerVec = data.rows() * static_cast<Index>(sizeof(Scalar));
1420 if(bytesPerVec % substrLen_ != 0)
1421 throw std::runtime_error("cannot build MultiIndexHashing; cannot divide byte count per vector by substring length without remainings");
1422
1423 buckets_.clear();
1424 buckets_.resize(bytesPerVec / substrLen_);
1425
1426 for(size_t i = 0; i < buckets_.size(); ++i)
1427 {
1428 Index start = static_cast<Index>(i) * substrLen_;
1429 Index idx = start / static_cast<Index>(sizeof(Scalar));
1430 Index offset = start % static_cast<Index>(sizeof(Scalar));
1431 std::map<Scalar, std::vector<Index>> &map = buckets_[i];
1432
1433 for(Index c = 0; c < data.cols(); ++c)
1434 {
1435 Scalar code = extractCode(data.col(c), idx, offset);
1436 if(map.find(code) == map.end())
1437 map[code] = std::vector<Index>();
1438 map[code].push_back(c);
1439 }
1440 }
1441 }
1442
1443 template<typename Derived>
1444 void query(const Eigen::MatrixBase<Derived> &queryPoints,
1445 const size_t knn,
1446 Matrixi &indices,
1447 Matrix &distances) const
1448 {
1449 if(buckets_.size() == 0)
1450 throw std::runtime_error("cannot query MultiIndexHashing; not built yet");
1451 if(queryPoints.rows() != dimension())
1452 throw std::runtime_error("cannot query MultiIndexHashing; data and query points do not have same dimension");
1453
1454 const Matrix &data = *data_;
1455
1456 indices.setConstant(knn, queryPoints.cols(), -1);
1457 distances.setConstant(knn, queryPoints.cols(), -1);
1458
1459 Index *indicesRaw = indices.data();
1460 Scalar *distsRaw = distances.data();
1461
1462 Scalar maxDistPart = maxDist_ / buckets_.size();
1463
1464 #pragma omp parallel for num_threads(threads_)
1465 for(Index c = 0; c < queryPoints.cols(); ++c)
1466 {
1467 std::set<Index> candidates;
1468 for(size_t i = 0; i < buckets_.size(); ++i)
1469 {
1470 Index start = static_cast<Index>(i) * substrLen_;
1471 Index idx = start / static_cast<Index>(sizeof(Scalar));
1472 Index offset = start % static_cast<Index>(sizeof(Scalar));
1473 const std::map<Scalar, std::vector<Index>> &map = buckets_[i];
1474
1475 Scalar code = extractCode(queryPoints.col(c), idx, offset);
1476 for(const auto &x: map)
1477 {
1478 Scalar dist = distance_(x.first, code);
1479 if(maxDistPart <= 0 || dist <= maxDistPart)
1480 {
1481 for(size_t j = 0; j < x.second.size(); ++j)
1482 candidates.insert(x.second[j]);
1483 }
1484 }
1485 }
1486
1487 Scalar *distPoint = &distsRaw[c * knn];
1488 Index *idxPoint = &indicesRaw[c * knn];
1489 // create heap to find nearest neighbours
1490 QueryHeap<Scalar> dataHeap(idxPoint, distPoint, knn);
1491
1492 for(Index idx: candidates)
1493 {
1494 Scalar dist = distance_(data.col(idx), queryPoints.col(c));
1495
1496 bool isInRange = maxDist_ <= 0 || dist <= maxDist_;
1497 bool isImprovement = !dataHeap.full() ||
1498 dist < dataHeap.front();
1499 if(isInRange && isImprovement)
1500 {
1501 if(dataHeap.full())
1502 dataHeap.pop();
1503 dataHeap.push(idx, dist);
1504 }
1505 }
1506
1507 if(sorted_)
1508 dataHeap.sort();
1509 }
1510 }
1511
1512 /** Returns the amount of data points stored in the search index.
1513 * @return number of data points */
1514 Index size() const
1515 {
1516 return data_ == nullptr ? 0 : data_->cols();
1517 }
1518
1519 /** Returns the dimension of the data points in the search index.
1520 * @return dimension of data points */
1521 Index dimension() const
1522 {
1523 return data_ == nullptr ? 0 : data_->rows();
1524 }
1525
1526 void clear()
1527 {
1528 data_ = nullptr;
1529 dataCopy_.resize(0, 0);
1530 buckets_.clear();
1531 }
1532
1533 };
1534
1535 #ifdef KNNCPP_FLANN
1536
1537 /** Wrapper class of FLANN kdtrees for the use with Eigen3. */
1538 template<typename Scalar,
1539 typename Distance=flann::L2_Simple<Scalar>>
1540 class KDTreeFlann
1541 {
1542 public:
1543 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
1544 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector;
1545 typedef Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic> Matrixi;
1546
1547 private:
1548 typedef flann::Index<Distance> FlannIndex;
1549
1550 Matrix dataCopy_;
1551 Matrix *dataPoints_;
1552
1553 FlannIndex *index_;
1554 flann::SearchParams searchParams_;
1555 flann::IndexParams indexParams_;
1556 Scalar maxDist_;
1557
1558 public:
1559 KDTreeFlann()
1560 : dataCopy_(), dataPoints_(nullptr), index_(nullptr),
1561 searchParams_(32, 0, false),
1562 indexParams_(flann::KDTreeSingleIndexParams(15)),
1563 maxDist_(0)
1564 {
1565 }
1566
1567 KDTreeFlann(Matrix &data, const bool copy = false)
1568 : KDTreeFlann()
1569 {
1570 setData(data, copy);
1571 }
1572
1573 ~KDTreeFlann()
1574 {
1575 clear();
1576 }
1577
1578 void setIndexParams(const flann::IndexParams &params)
1579 {
1580 indexParams_ = params;
1581 }
1582
1583 void setChecks(const int checks)
1584 {
1585 searchParams_.checks = checks;
1586 }
1587
1588 void setSorted(const bool sorted)
1589 {
1590 searchParams_.sorted = sorted;
1591 }
1592
1593 void setThreads(const int threads)
1594 {
1595 searchParams_.cores = threads;
1596 }
1597
1598 void setEpsilon(const float eps)
1599 {
1600 searchParams_.eps = eps;
1601 }
1602
1603 void setMaxDistance(const Scalar dist)
1604 {
1605 maxDist_ = dist;
1606 }
1607
1608 void setData(Matrix &data, const bool copy = false)
1609 {
1610 if(copy)
1611 {
1612 dataCopy_ = data;
1613 dataPoints_ = &dataCopy_;
1614 }
1615 else
1616 {
1617 dataPoints_ = &data;
1618 }
1619
1620 clear();
1621 }
1622
1623 void build()
1624 {
1625 if(dataPoints_ == nullptr)
1626 throw std::runtime_error("cannot build KDTree; data not set");
1627 if(dataPoints_->size() == 0)
1628 throw std::runtime_error("cannot build KDTree; data is empty");
1629
1630 if(index_ != nullptr)
1631 delete index_;
1632
1633 flann::Matrix<Scalar> dataPts(
1634 dataPoints_->data(),
1635 dataPoints_->cols(),
1636 dataPoints_->rows());
1637
1638 index_ = new FlannIndex(dataPts, indexParams_);
1639 index_->buildIndex();
1640 }
1641
1642 void query(Matrix &queryPoints,
1643 const size_t knn,
1644 Matrixi &indices,
1645 Matrix &distances) const
1646 {
1647 if(index_ == nullptr)
1648 throw std::runtime_error("cannot query KDTree; not built yet");
1649 if(dataPoints_->rows() != queryPoints.rows())
1650 throw std::runtime_error("cannot query KDTree; KDTree has different dimension than query data");
1651
1652 // resize result matrices
1653 distances.resize(knn, queryPoints.cols());
1654 indices.resize(knn, queryPoints.cols());
1655
1656 // wrap matrices into flann matrices
1657 flann::Matrix<Scalar> queryPts(
1658 queryPoints.data(),
1659 queryPoints.cols(),
1660 queryPoints.rows());
1661 flann::Matrix<int> indicesF(
1662 indices.data(),
1663 indices.cols(),
1664 indices.rows());
1665 flann::Matrix<Scalar> distancesF(
1666 distances.data(),
1667 distances.cols(),
1668 distances.rows());
1669
1670 // if maximum distance was set then use radius search
1671 if(maxDist_ > 0)
1672 index_->radiusSearch(queryPts, indicesF, distancesF, maxDist_, searchParams_);
1673 else
1674 index_->knnSearch(queryPts, indicesF, distancesF, knn, searchParams_);
1675
1676 // make result matrices compatible to API
1677 #pragma omp parallel for num_threads(searchParams_.cores)
1678 for(Index i = 0; i < indices.cols(); ++i)
1679 {
1680 bool found = false;
1681 for(Index j = 0; j < indices.rows(); ++j)
1682 {
1683 if(indices(j, i) == -1)
1684 found = true;
1685
1686 if(found)
1687 {
1688 indices(j, i) = -1;
1689 distances(j, i) = -1;
1690 }
1691 }
1692 }
1693 }
1694
1695 Index size() const
1696 {
1697 return dataPoints_ == nullptr ? 0 : dataPoints_->cols();
1698 }
1699
1700 Index dimension() const
1701 {
1702 return dataPoints_ == nullptr ? 0 : dataPoints_->rows();
1703 }
1704
1705 void clear()
1706 {
1707 if(index_ != nullptr)
1708 {
1709 delete index_;
1710 index_ = nullptr;
1711 }
1712 }
1713
1714 FlannIndex &flannIndex()
1715 {
1716 return index_;
1717 }
1718 };
1719
1720 typedef KDTreeFlann<double> KDTreeFlannd;
1721 typedef KDTreeFlann<float> KDTreeFlannf;
1722
1723 #endif
1724}
1725
1726#endif