/* knncpp.h
 *
 * Author: Fabian Meyer
 * Created On: 22 Aug 2021
 * License: MIT
 */

#ifndef KNNCPP_H_
#define KNNCPP_H_

#include <Eigen/Geometry>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <set>
#include <stdexcept>
#include <type_traits>
#include <vector>

#ifdef KNNCPP_FLANN

#include <flann/flann.hpp>

#endif

namespace knncpp
{
    /********************************************************
     * Matrix Definitions
     *******************************************************/

    typedef typename Eigen::MatrixXd::Index Index;

    typedef Eigen::Matrix<Index, Eigen::Dynamic, 1> Vectori;
    typedef Eigen::Matrix<Index, 2, 1> Vector2i;
    typedef Eigen::Matrix<Index, 3, 1> Vector3i;
    typedef Eigen::Matrix<Index, 4, 1> Vector4i;
    typedef Eigen::Matrix<Index, 5, 1> Vector5i;

    typedef Eigen::Matrix<Index, Eigen::Dynamic, Eigen::Dynamic> Matrixi;
    typedef Eigen::Matrix<Index, 2, 2> Matrix2i;
    typedef Eigen::Matrix<Index, 3, 3> Matrix3i;
    typedef Eigen::Matrix<Index, 4, 4> Matrix4i;
    typedef Eigen::Matrix<Index, 5, 5> Matrix5i;

    typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> Matrixf;
    typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> Matrixd;

    /********************************************************
     * Distance Functors
     *******************************************************/

    /** Manhattan distance functor.
     * This is the same as the L1 Minkowski distance, but more efficient.
     * @see EuclideanDistance, ChebyshevDistance, MinkowskiDistance */
    template <typename Scalar>
    struct ManhattenDistance
    {
        /** Compute the unrooted distance between two vectors.
         * @param lhs vector on left hand side
         * @param rhs vector on right hand side */
        template<typename DerivedA, typename DerivedB>
        Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs,
            const Eigen::MatrixBase<DerivedB> &rhs) const
        {
            static_assert(
                std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar, Scalar>::value,
                "distance scalar and input matrix A must have same type");
            static_assert(
                std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value,
                "distance scalar and input matrix B must have same type");

            return (lhs - rhs).cwiseAbs().sum();
        }

        /** Compute the unrooted distance between two scalars.
         * @param lhs scalar on left hand side
         * @param rhs scalar on right hand side */
        Scalar operator()(const Scalar lhs, const Scalar rhs) const
        {
            return std::abs(lhs - rhs);
        }

        /** Compute the root of an unrooted distance value.
         * @param val unrooted distance value */
        Scalar operator()(const Scalar val) const
        {
            return val;
        }
    };
88 | ||
89 | /** Euclidean distance functor. | |
90 | * This the same as the L2 minkowski distance but more efficient. | |
91 | * @see ManhattenDistance, ChebyshevDistance, MinkowskiDistance */ | |
92 | template <typename Scalar> | |
93 | struct EuclideanDistance | |
94 | { | |
95 | /** Compute the unrooted distance between two vectors. | |
96 | * @param lhs vector on left hand side | |
97 | * @param rhs vector on right hand side */ | |
98 | template<typename DerivedA, typename DerivedB> | |
99 | Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs, | |
100 | const Eigen::MatrixBase<DerivedB> &rhs) const | |
101 | { | |
102 | static_assert( | |
103 | std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value, | |
104 | "distance scalar and input matrix A must have same type"); | |
105 | static_assert( | |
106 | std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value, | |
107 | "distance scalar and input matrix B must have same type"); | |
108 | ||
109 | return (lhs - rhs).cwiseAbs2().sum(); | |
110 | } | |
111 | ||
112 | /** Compute the unrooted distance between two scalars. | |
113 | * @param lhs scalar on left hand side | |
114 | * @param rhs scalar on right hand side */ | |
115 | Scalar operator()(const Scalar lhs, | |
116 | const Scalar rhs) const | |
117 | { | |
118 | Scalar diff = lhs - rhs; | |
119 | return diff * diff; | |
120 | } | |
121 | ||
122 | /** Compute the root of a unrooted distance value. | |
123 | * @param value unrooted distance value */ | |
124 | Scalar operator()(const Scalar val) const | |
125 | { | |
126 | return std::sqrt(val); | |
127 | } | |
128 | }; | |
129 | ||
130 | /** General minkowski distance functor. | |
131 | * The infinite version is only available through the chebyshev distance. | |
132 | * @see ManhattenDistance, EuclideanDistance, ChebyshevDistance */ | |
133 | template <typename Scalar, int P> | |
134 | struct MinkowskiDistance | |
135 | { | |
136 | struct Pow | |
137 | { | |
138 | Scalar operator()(const Scalar val) const | |
139 | { | |
140 | Scalar result = 1; | |
141 | for(int i = 0; i < P; ++i) | |
142 | result *= val; | |
143 | return result; | |
144 | } | |
145 | }; | |
146 | ||
147 | /** Compute the unrooted distance between two vectors. | |
148 | * @param lhs vector on left hand side | |
149 | * @param rhs vector on right hand side */ | |
150 | template<typename DerivedA, typename DerivedB> | |
151 | Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs, | |
152 | const Eigen::MatrixBase<DerivedB> &rhs) const | |
153 | { | |
154 | static_assert( | |
155 | std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value, | |
156 | "distance scalar and input matrix A must have same type"); | |
157 | static_assert( | |
158 | std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value, | |
159 | "distance scalar and input matrix B must have same type"); | |
160 | ||
161 | return (lhs - rhs).cwiseAbs().unaryExpr(MinkowskiDistance::Pow()).sum(); | |
162 | } | |
163 | ||
164 | /** Compute the unrooted distance between two scalars. | |
165 | * @param lhs scalar on left hand side | |
166 | * @param rhs scalar on right hand side */ | |
167 | Scalar operator()(const Scalar lhs, | |
168 | const Scalar rhs) const | |
169 | { | |
170 | return std::pow(std::abs(lhs - rhs), P);; | |
171 | } | |
172 | ||
173 | /** Compute the root of a unrooted distance value. | |
174 | * @param value unrooted distance value */ | |
175 | Scalar operator()(const Scalar val) const | |
176 | { | |
177 | return std::pow(val, 1 / static_cast<Scalar>(P)); | |
178 | } | |
179 | }; | |
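
    /* Usage sketch (editor's illustration, not part of the library): all
     * distance functors share the same three-operator protocol, so the
     * search indices below can compare cheap "unrooted" distances and only
     * take the root once at the very end. A minimal example:
     *
     *   knncpp::EuclideanDistance<double> dist;
     *   Eigen::Vector3d a(0.0, 0.0, 0.0);
     *   Eigen::Vector3d b(1.0, 2.0, 2.0);
     *   double unrooted = dist(a, b);   // squared distance: 9.0
     *   double rooted = dist(unrooted); // true distance: 3.0
     *
     * MinkowskiDistance<double, 3> follows the same pattern with cubed
     * absolute differences and a cube root. */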
180 | ||
181 | /** Chebyshev distance functor. | |
182 | * This distance is the same as infinity minkowski distance. | |
183 | * @see ManhattenDistance, EuclideanDistance, MinkowskiDistance */ | |
184 | template<typename Scalar> | |
185 | struct ChebyshevDistance | |
186 | { | |
187 | /** Compute the unrooted distance between two vectors. | |
188 | * @param lhs vector on left hand side | |
189 | * @param rhs vector on right hand side */ | |
190 | template<typename DerivedA, typename DerivedB> | |
191 | Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs, | |
192 | const Eigen::MatrixBase<DerivedB> &rhs) const | |
193 | { | |
194 | static_assert( | |
195 | std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value, | |
196 | "distance scalar and input matrix A must have same type"); | |
197 | static_assert( | |
198 | std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value, | |
199 | "distance scalar and input matrix B must have same type"); | |
200 | ||
201 | return (lhs - rhs).cwiseAbs().maxCoeff(); | |
202 | } | |
203 | ||
204 | /** Compute the unrooted distance between two scalars. | |
205 | * @param lhs scalar on left hand side | |
206 | * @param rhs scalar on right hand side */ | |
207 | Scalar operator()(const Scalar lhs, | |
208 | const Scalar rhs) const | |
209 | { | |
210 | return std::abs(lhs - rhs); | |
211 | } | |
212 | ||
213 | /** Compute the root of a unrooted distance value. | |
214 | * @param value unrooted distance value */ | |
215 | Scalar operator()(const Scalar val) const | |
216 | { | |
217 | return val; | |
218 | } | |
219 | }; | |
220 | ||
221 | /** Hamming distance functor. | |
222 | * The distance vectors have to be of integral type and should hold the | |
223 | * information vectors as bitmasks. | |
224 | * Performs a XOR operation on the vectors and counts the number of set | |
225 | * ones. */ | |
226 | template<typename Scalar> | |
227 | struct HammingDistance | |
228 | { | |
229 | static_assert(std::is_integral<Scalar>::value, | |
230 | "HammingDistance requires integral Scalar type"); | |
231 | ||
232 | struct XOR | |
233 | { | |
234 | Scalar operator()(const Scalar lhs, const Scalar rhs) const | |
235 | { | |
236 | return lhs ^ rhs; | |
237 | } | |
238 | }; | |
239 | ||
240 | struct BitCount | |
241 | { | |
242 | Scalar operator()(const Scalar lhs) const | |
243 | { | |
244 | Scalar copy = lhs; | |
245 | Scalar result = 0; | |
246 | while(copy != static_cast<Scalar>(0)) | |
247 | { | |
248 | ++result; | |
249 | copy &= (copy - 1); | |
250 | } | |
251 | ||
252 | return result; | |
253 | } | |
254 | }; | |
255 | ||
256 | /** Compute the unrooted distance between two vectors. | |
257 | * @param lhs vector on left hand side | |
258 | * @param rhs vector on right hand side */ | |
259 | template<typename DerivedA, typename DerivedB> | |
260 | Scalar operator()(const Eigen::MatrixBase<DerivedA> &lhs, | |
261 | const Eigen::MatrixBase<DerivedB> &rhs) const | |
262 | { | |
263 | static_assert( | |
264 | std::is_same<typename Eigen::MatrixBase<DerivedA>::Scalar,Scalar>::value, | |
265 | "distance scalar and input matrix A must have same type"); | |
266 | static_assert( | |
267 | std::is_same<typename Eigen::MatrixBase<DerivedB>::Scalar, Scalar>::value, | |
268 | "distance scalar and input matrix B must have same type"); | |
269 | ||
270 | return lhs. | |
271 | binaryExpr(rhs, XOR()). | |
272 | unaryExpr(BitCount()). | |
273 | sum(); | |
274 | } | |
275 | ||
276 | /** Compute the unrooted distance between two scalars. | |
277 | * @param lhs scalar on left hand side | |
278 | * @param rhs scalar on right hand side */ | |
279 | Scalar operator()(const Scalar lhs, | |
280 | const Scalar rhs) const | |
281 | { | |
282 | BitCount cnt; | |
283 | XOR xOr; | |
284 | return cnt(xOr(lhs, rhs)); | |
285 | } | |
286 | ||
287 | /** Compute the root of a unrooted distance value. | |
288 | * @param value unrooted distance value */ | |
289 | Scalar operator()(const Scalar value) const | |
290 | { | |
291 | return value; | |
292 | } | |
293 | }; | |
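
    /* Usage sketch (editor's illustration, not part of the library): the
     * Hamming distance operates on integral bitmasks, e.g. binary
     * descriptors packed into unsigned integers:
     *
     *   knncpp::HammingDistance<std::uint32_t> dist;
     *   std::uint32_t a = 0b1011;
     *   std::uint32_t b = 0b0010;
     *   std::uint32_t d = dist(a, b); // XOR = 0b1001, popcount = 2
     */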
294 | ||
295 | /** Efficient heap structure to query nearest neighbours. */ | |
296 | template<typename Scalar> | |
297 | class QueryHeap | |
298 | { | |
299 | private: | |
300 | Index *indices_ = nullptr; | |
301 | Scalar *distances_ = nullptr; | |
302 | size_t maxSize_ = 0; | |
303 | size_t size_ = 0; | |
304 | public: | |
305 | /** Creates a query heap with the given index and distance memory regions. */ | |
306 | QueryHeap(Index *indices, Scalar *distances, const size_t maxSize) | |
307 | : indices_(indices), distances_(distances), maxSize_(maxSize) | |
308 | { } | |
309 | ||
310 | /** Pushes a new query data set into the heap with the given | |
311 | * index and distance. | |
312 | * The index identifies the point for which the given distance | |
313 | * was computed. | |
314 | * @param idx index / ID of the query point | |
315 | * @param dist distance that was computed for the query point*/ | |
316 | void push(const Index idx, const Scalar dist) | |
317 | { | |
318 | assert(!full()); | |
319 | ||
320 | // add new value at the end | |
321 | indices_[size_] = idx; | |
322 | distances_[size_] = dist; | |
323 | ++size_; | |
324 | ||
325 | // upheap | |
326 | size_t k = size_ - 1; | |
327 | size_t tmp = (k - 1) / 2; | |
328 | while(k > 0 && distances_[tmp] < dist) | |
329 | { | |
330 | distances_[k] = distances_[tmp]; | |
331 | indices_[k] = indices_[tmp]; | |
332 | k = tmp; | |
333 | tmp = (k - 1) / 2; | |
334 | } | |
335 | distances_[k] = dist; | |
336 | indices_[k] = idx; | |
337 | } | |
338 | ||
339 | /** Removes the element at the front of the heap and restores | |
340 | * the heap order. */ | |
341 | void pop() | |
342 | { | |
343 | assert(!empty()); | |
344 | ||
345 | // replace first element with last | |
346 | --size_; | |
347 | distances_[0] = distances_[size_]; | |
348 | indices_[0] = indices_[size_]; | |
349 | ||
350 | // downheap | |
351 | size_t k = 0; | |
352 | size_t j; | |
353 | Scalar dist = distances_[0]; | |
354 | Index idx = indices_[0]; | |
355 | while(2 * k + 1 < size_) | |
356 | { | |
357 | j = 2 * k + 1; | |
358 | if(j + 1 < size_ && distances_[j+1] > distances_[j]) | |
359 | ++j; | |
360 | // j references now greatest child | |
361 | if(dist >= distances_[j]) | |
362 | break; | |
363 | distances_[k] = distances_[j]; | |
364 | indices_[k] = indices_[j]; | |
365 | k = j; | |
366 | } | |
367 | distances_[k] = dist; | |
368 | indices_[k] = idx; | |
369 | } | |
370 | ||
371 | /** Returns the distance of the element in front of the heap. */ | |
372 | Scalar front() const | |
373 | { | |
374 | assert(!empty()); | |
375 | return distances_[0]; | |
376 | } | |
377 | ||
378 | /** Determines if this query heap is full. | |
379 | * The heap is considered full if its number of elements | |
380 | * has reached its max size. | |
381 | * @return true if the heap is full, else false */ | |
382 | bool full() const | |
383 | { | |
384 | return size_ >= maxSize_; | |
385 | } | |
386 | ||
387 | /** Determines if this query heap is empty. | |
388 | * @return true if the heap contains no elements, else false */ | |
389 | bool empty() const | |
390 | { | |
391 | return size_ == 0; | |
392 | } | |
393 | ||
394 | /** Returns the number of elements within the query heap. | |
395 | * @return number of elements in the heap */ | |
396 | size_t size() const | |
397 | { | |
398 | return size_; | |
399 | } | |
400 | ||
401 | /** Clears the query heap. */ | |
402 | void clear() | |
403 | { | |
404 | size_ = 0; | |
405 | } | |
406 | ||
407 | /** Sorts the elements within the heap according to | |
408 | * their distance. */ | |
409 | void sort() | |
410 | { | |
411 | size_t cnt = size_; | |
412 | for(size_t i = 0; i < cnt; ++i) | |
413 | { | |
414 | Index idx = indices_[0]; | |
415 | Scalar dist = distances_[0]; | |
416 | pop(); | |
417 | indices_[cnt - i - 1] = idx; | |
418 | distances_[cnt - i - 1] = dist; | |
419 | } | |
420 | } | |
421 | }; | |
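
    /* Usage sketch (editor's illustration, not part of the library): the heap
     * works directly on caller-owned buffers, which is how the search classes
     * below write results into their output matrices without extra copies:
     *
     *   knncpp::Index indices[3];
     *   double distances[3];
     *   knncpp::QueryHeap<double> heap(indices, distances, 3);
     *   heap.push(7, 2.5);
     *   heap.push(1, 0.5);
     *   heap.push(4, 1.0);
     *   // heap.front() == 2.5, the current worst candidate
     *   heap.sort(); // indices == {1, 4, 7}, distances == {0.5, 1.0, 2.5}
     */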
422 | ||
423 | /** Class for performing brute force knn search. */ | |
424 | template<typename Scalar, | |
425 | typename Distance=EuclideanDistance<Scalar>> | |
426 | class BruteForce | |
427 | { | |
428 | public: | |
429 | typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix; | |
430 | typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector; | |
431 | typedef knncpp::Matrixi Matrixi; | |
432 | private: | |
433 | Distance distance_ = Distance(); | |
434 | Matrix dataCopy_ = Matrix(); | |
435 | const Matrix *data_ = nullptr; | |
436 | ||
437 | bool sorted_ = true; | |
438 | bool takeRoot_ = true; | |
439 | Index threads_ = 1; | |
440 | Scalar maxDist_ = 0; | |
441 | ||
442 | public: | |
443 | ||
444 | BruteForce() = default; | |
445 | ||
446 | /** Constructs a brute force instance with the given data. | |
447 | * @param data NxM matrix, M points of dimension N | |
448 | * @param copy if true copies the data, otherwise assumes static data */ | |
449 | BruteForce(const Matrix &data, const bool copy = false) | |
450 | : BruteForce() | |
451 | { | |
452 | setData(data, copy); | |
453 | } | |
454 | ||
455 | /** Set if the points returned by the queries should be sorted | |
456 | * according to their distance to the query points. | |
457 | * @param sorted sort query results */ | |
458 | void setSorted(const bool sorted) | |
459 | { | |
460 | sorted_ = sorted; | |
461 | } | |
462 | ||
463 | /** Set if the distances after the query should be rooted or not. | |
464 | * Taking the root of the distances increases query time, but the | |
465 | * function will return true distances instead of their powered | |
466 | * versions. | |
467 | * @param takeRoot set true if root should be taken else false */ | |
468 | void setTakeRoot(const bool takeRoot) | |
469 | { | |
470 | takeRoot_ = takeRoot; | |
471 | } | |
472 | ||
473 | /** Set the amount of threads that should be used for querying. | |
474 | * OpenMP has to be enabled for this to work. | |
475 | * @param threads amount of threads, 0 for optimal choice */ | |
476 | void setThreads(const unsigned int threads) | |
477 | { | |
478 | threads_ = threads; | |
479 | } | |
480 | ||
481 | /** Set the maximum distance for querying the tree. | |
482 | * The search will be pruned if the maximum distance is set to any | |
483 | * positive number. | |
484 | * @param maxDist maximum distance, <= 0 for no limit */ | |
485 | void setMaxDistance(const Scalar maxDist) | |
486 | { | |
487 | maxDist_ = maxDist; | |
488 | } | |
489 | ||
490 | /** Set the data points used for this tree. | |
491 | * This does not build the tree. | |
492 | * @param data NxM matrix, M points of dimension N | |
493 | * @param copy if true data is copied, assumes static data otherwise */ | |
494 | void setData(const Matrix &data, const bool copy = false) | |
495 | { | |
496 | if(copy) | |
497 | { | |
498 | dataCopy_ = data; | |
499 | data_ = &dataCopy_; | |
500 | } | |
501 | else | |
502 | { | |
503 | data_ = &data; | |
504 | } | |
505 | } | |
506 | ||
507 | void setDistance(const Distance &distance) | |
508 | { | |
509 | distance_ = distance; | |
510 | } | |
511 | ||
        void build()
        { }

        template<typename Derived>
        void query(const Eigen::MatrixBase<Derived> &queryPoints,
            const size_t knn,
            Matrixi &indices,
            Matrix &distances) const
        {
            if(data_ == nullptr)
                throw std::runtime_error("cannot query BruteForce: data not set");
            if(data_->size() == 0)
                throw std::runtime_error("cannot query BruteForce: data is empty");
            if(queryPoints.rows() != dimension())
                throw std::runtime_error("cannot query BruteForce: data and query points do not have same dimension");

            const Matrix &dataPoints = *data_;

            indices.setConstant(knn, queryPoints.cols(), -1);
            distances.setConstant(knn, queryPoints.cols(), -1);

            #pragma omp parallel for num_threads(threads_)
            for(Index i = 0; i < queryPoints.cols(); ++i)
            {
                Index *idxPoint = &indices.data()[i * knn];
                Scalar *distPoint = &distances.data()[i * knn];

                QueryHeap<Scalar> heap(idxPoint, distPoint, knn);

                for(Index j = 0; j < dataPoints.cols(); ++j)
                {
                    Scalar dist = distance_(queryPoints.col(i), dataPoints.col(j));

                    // check if the point is in range if a max distance was set
                    bool isInRange = maxDist_ <= 0 || dist <= maxDist_;
                    // check if this point is an improvement if the heap is already full
                    bool isImprovement = !heap.full() ||
                        dist < heap.front();
                    if(isInRange && isImprovement)
                    {
                        if(heap.full())
                            heap.pop();
                        heap.push(j, dist);
                    }
                }

                if(sorted_)
                    heap.sort();

                if(takeRoot_)
                {
                    for(size_t j = 0; j < knn; ++j)
                    {
                        if(idxPoint[j] < 0)
                            break;
                        distPoint[j] = distance_(distPoint[j]);
                    }
                }
            }
        }

        /** Returns the number of data points stored in the search index.
         * @return number of data points */
        Index size() const
        {
            return data_ == nullptr ? 0 : data_->cols();
        }

        /** Returns the dimension of the data points in the search index.
         * @return dimension of data points */
        Index dimension() const
        {
            return data_ == nullptr ? 0 : data_->rows();
        }
    };
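
    /* Usage sketch (editor's illustration, not part of the library): points
     * are stored column-wise, and query results come back as knn x M
     * matrices, padded with -1 where fewer neighbours were found:
     *
     *   knncpp::Matrixd points(3, 100); // 100 points in 3D
     *   points.setRandom();
     *
     *   knncpp::BruteForce<double> bf(points);
     *   bf.setSorted(true);
     *
     *   knncpp::Matrixi indices;
     *   knncpp::Matrixd distances;
     *   bf.query(points.leftCols(5), 4, indices, distances);
     *   // indices(k, i) is the index of the k-th neighbour of query point i
     */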
587 | ||
588 | // template<typename Scalar> | |
589 | // struct MeanMidpointRule | |
590 | // { | |
591 | // typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix; | |
592 | // typedef knncpp::Matrixi Matrixi; | |
593 | ||
594 | // void operator(const Matrix &data, const Matrixi &indices, Index split) | |
595 | // }; | |
596 | ||
597 | /** Class for performing k nearest neighbour searches with minkowski distances. | |
598 | * This kdtree only works reliably with the minkowski distance and its | |
599 | * special cases like manhatten or euclidean distance. | |
600 | * @see ManhattenDistance, EuclideanDistance, ChebyshevDistance, MinkowskiDistance*/ | |
601 | template<typename _Scalar, int _Dimension, typename _Distance> | |
602 | class KDTreeMinkowski | |
603 | { | |
604 | public: | |
605 | typedef _Scalar Scalar; | |
606 | typedef _Distance Distance; | |
607 | typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix; | |
608 | typedef Eigen::Matrix<Scalar, _Dimension, Eigen::Dynamic> DataMatrix; | |
609 | typedef Eigen::Matrix<Scalar, _Dimension, 1> DataVector; | |
610 | typedef knncpp::Matrixi Matrixi; | |
611 | private: | |
612 | typedef Eigen::Matrix<Scalar, 2, 1> Bounds; | |
613 | typedef Eigen::Matrix<Scalar, 2, _Dimension> BoundingBox; | |
614 | ||
615 | /** Struct representing a node in the KDTree. | |
616 | * It can be either a inner node or a leaf node. */ | |
617 | struct Node | |
618 | { | |
619 | /** Indices of data points in this leaf node. */ | |
620 | Index startIdx = 0; | |
621 | Index length = 0; | |
622 | ||
623 | /** Left child of this inner node. */ | |
624 | Index left = -1; | |
625 | /** Right child of this inner node. */ | |
626 | Index right = -1; | |
627 | /** Axis of the axis aligned splitting hyper plane. */ | |
628 | Index splitaxis = -1; | |
629 | /** Translation of the axis aligned splitting hyper plane. */ | |
630 | Scalar splitpoint = 0; | |
631 | /** Lower end of the splitpoint range */ | |
632 | Scalar splitlower = 0; | |
633 | /** Upper end of the splitpoint range */ | |
634 | Scalar splitupper = 0; | |
635 | ||
636 | ||
637 | Node() = default; | |
638 | ||
639 | /** Constructor for leaf nodes */ | |
640 | Node(const Index startIdx, const Index length) | |
641 | : startIdx(startIdx), length(length) | |
642 | { } | |
643 | ||
644 | /** Constructor for inner nodes */ | |
645 | Node(const Index splitaxis, const Scalar splitpoint, | |
646 | const Index left, const Index right) | |
647 | : left(left), right(right), | |
648 | splitaxis(splitaxis), splitpoint(splitpoint) | |
649 | { } | |
650 | ||
651 | bool isLeaf() const | |
652 | { | |
653 | return !hasLeft() && !hasRight(); | |
654 | } | |
655 | ||
656 | bool isInner() const | |
657 | { | |
658 | return hasLeft() && hasRight(); | |
659 | } | |
660 | ||
661 | bool hasLeft() const | |
662 | { | |
663 | return left >= 0; | |
664 | } | |
665 | ||
666 | bool hasRight() const | |
667 | { | |
668 | return right >= 0; | |
669 | } | |
670 | }; | |
671 | ||
672 | DataMatrix dataCopy_ = DataMatrix(); | |
673 | const DataMatrix *data_ = nullptr; | |
674 | std::vector<Index> indices_ = std::vector<Index>(); | |
675 | std::vector<Node> nodes_ = std::vector<Node>(); | |
676 | ||
677 | Index bucketSize_ = 16; | |
678 | bool sorted_ = true; | |
679 | bool compact_ = false; | |
680 | bool balanced_ = false; | |
681 | bool takeRoot_ = true; | |
682 | Index threads_ = 0; | |
683 | Scalar maxDist_ = 0; | |
684 | ||
685 | Distance distance_ = Distance(); | |
686 | ||
687 | BoundingBox bbox_ = BoundingBox(); | |
688 | ||
689 | Index buildLeafNode(const Index startIdx, | |
690 | const Index length, | |
691 | BoundingBox &bbox) | |
692 | { | |
693 | nodes_.push_back(Node(startIdx, length)); | |
694 | calculateBoundingBox(startIdx, length, bbox); | |
695 | return static_cast<Index>(nodes_.size() - 1); | |
696 | } | |
697 | ||
698 | /** Finds the minimum and maximum values of each dimension (row) in the | |
699 | * data matrix. Only respects the columns specified by the index | |
700 | * vector. | |
701 | * @param startIdx starting index within indices data structure to search for bounding box | |
702 | * @param length length of the block of indices*/ | |
703 | void calculateBoundingBox(const Index startIdx, | |
704 | const Index length, | |
705 | BoundingBox &bbox) const | |
706 | { | |
707 | assert(length > 0); | |
708 | assert(startIdx >= 0); | |
709 | assert(static_cast<size_t>(startIdx + length) <= indices_.size()); | |
710 | assert(data_->rows() == bbox.cols()); | |
711 | ||
712 | const DataMatrix &data = *data_; | |
713 | ||
714 | // initialize bounds of the bounding box | |
715 | Index first = indices_[startIdx]; | |
716 | for(Index i = 0; i < bbox.cols(); ++i) | |
717 | { | |
718 | bbox(0, i) = data(i, first); | |
719 | bbox(1, i) = data(i, first); | |
720 | } | |
721 | ||
722 | // search for min / max values in data | |
723 | for(Index i = 1; i < length; ++i) | |
724 | { | |
725 | // retrieve data index | |
726 | Index col = indices_[startIdx + i]; | |
727 | assert(col >= 0 && col < data.cols()); | |
728 | ||
729 | // check min and max for each dimension individually | |
730 | for(Index j = 0; j < data.rows(); ++j) | |
731 | { | |
732 | bbox(0, j) = std::min(bbox(0, j), data(j, col)); | |
733 | bbox(1, j) = std::max(bbox(1, j), data(j, col)); | |
734 | } | |
735 | } | |
736 | } | |
737 | ||
738 | /** Calculates the bounds (min / max values) for the given dimension and block of data. */ | |
739 | void calculateBounds(const Index startIdx, | |
740 | const Index length, | |
741 | const Index dim, | |
742 | Bounds &bounds) const | |
743 | { | |
744 | assert(length > 0); | |
745 | assert(startIdx >= 0); | |
746 | assert(static_cast<size_t>(startIdx + length) <= indices_.size()); | |
747 | ||
748 | const DataMatrix &data = *data_; | |
749 | ||
750 | bounds(0) = data(dim, indices_[startIdx]); | |
751 | bounds(1) = data(dim, indices_[startIdx]); | |
752 | ||
753 | for(Index i = 1; i < length; ++i) | |
754 | { | |
755 | Index col = indices_[startIdx + i]; | |
756 | assert(col >= 0 && col < data.cols()); | |
757 | ||
758 | bounds(0) = std::min(bounds(0), data(dim, col)); | |
759 | bounds(1) = std::max(bounds(1), data(dim, col)); | |
760 | } | |
761 | } | |
762 | ||
763 | void calculateSplittingMidpoint(const Index startIdx, | |
764 | const Index length, | |
765 | const BoundingBox &bbox, | |
766 | Index &splitaxis, | |
767 | Scalar &splitpoint, | |
768 | Index &splitoffset) | |
769 | { | |
770 | const DataMatrix &data = *data_; | |
771 | ||
772 | // search for axis with longest distance | |
773 | splitaxis = 0; | |
774 | Scalar splitsize = static_cast<Scalar>(0); | |
775 | for(Index i = 0; i < data.rows(); ++i) | |
776 | { | |
777 | Scalar diff = bbox(1, i) - bbox(0, i); | |
778 | if(diff > splitsize) | |
779 | { | |
780 | splitaxis = i; | |
781 | splitsize = diff; | |
782 | } | |
783 | } | |
784 | ||
785 | // calculate the bounds in this axis and update our data | |
786 | // accordingly | |
787 | Bounds bounds; | |
788 | calculateBounds(startIdx, length, splitaxis, bounds); | |
789 | splitsize = bounds(1) - bounds(0); | |
790 | ||
791 | const Index origSplitaxis = splitaxis; | |
792 | for(Index i = 0; i < data.rows(); ++i) | |
793 | { | |
794 | // skip the dimension of the previously found splitaxis | |
795 | if(i == origSplitaxis) | |
796 | continue; | |
797 | Scalar diff = bbox(1, i) - bbox(0, i); | |
798 | // check if the split for this dimension would be potentially larger | |
799 | if(diff > splitsize) | |
800 | { | |
801 | Bounds newBounds; | |
802 | // update the bounds to their actual current value | |
803 | calculateBounds(startIdx, length, splitaxis, newBounds); | |
804 | diff = newBounds(1) - newBounds(0); | |
805 | if(diff > splitsize) | |
806 | { | |
807 | splitaxis = i; | |
808 | splitsize = diff; | |
809 | bounds = newBounds; | |
810 | } | |
811 | } | |
812 | } | |
813 | ||
814 | // use the sliding midpoint rule | |
815 | splitpoint = (bounds(0) + bounds(1)) / static_cast<Scalar>(2); | |
816 | ||
817 | Index leftIdx = startIdx; | |
818 | Index rightIdx = startIdx + length - 1; | |
819 | ||
820 | // first loop checks left < splitpoint and right >= splitpoint | |
821 | while(leftIdx <= rightIdx) | |
822 | { | |
823 | // increment left as long as left has not reached right and | |
824 | // the value of the left element is less than the splitpoint | |
825 | while(leftIdx <= rightIdx && data(splitaxis, indices_[leftIdx]) < splitpoint) | |
826 | ++leftIdx; | |
827 | ||
828 | // decrement right as long as left has not reached right and | |
829 | // the value of the right element is greater than the splitpoint | |
830 | while(leftIdx <= rightIdx && data(splitaxis, indices_[rightIdx]) >= splitpoint) | |
831 | --rightIdx; | |
832 | ||
833 | if(leftIdx <= rightIdx) | |
834 | { | |
835 | std::swap(indices_[leftIdx], indices_[rightIdx]); | |
836 | ++leftIdx; | |
837 | --rightIdx; | |
838 | } | |
839 | } | |
840 | ||
841 | // remember this offset from starting index | |
842 | const Index offset1 = leftIdx - startIdx; | |
843 | ||
844 | rightIdx = startIdx + length - 1; | |
845 | // second loop checks left <= splitpoint and right > splitpoint | |
846 | while(leftIdx <= rightIdx) | |
847 | { | |
848 | // increment left as long as left has not reached right and | |
849 | // the value of the left element is less than the splitpoint | |
850 | while(leftIdx <= rightIdx && data(splitaxis, indices_[leftIdx]) <= splitpoint) | |
851 | ++leftIdx; | |
852 | ||
853 | // decrement right as long as left has not reached right and | |
854 | // the value of the right element is greater than the splitpoint | |
855 | while(leftIdx <= rightIdx && data(splitaxis, indices_[rightIdx]) > splitpoint) | |
856 | --rightIdx; | |
857 | ||
858 | if(leftIdx <= rightIdx) | |
859 | { | |
860 | std::swap(indices_[leftIdx], indices_[rightIdx]); | |
861 | ++leftIdx; | |
862 | --rightIdx; | |
863 | } | |
864 | } | |
865 | ||
866 | // remember this offset from starting index | |
867 | const Index offset2 = leftIdx - startIdx; | |
868 | ||
869 | const Index halfLength = length / static_cast<Index>(2); | |
870 | ||
871 | // find a separation of points such that is best balanced | |
872 | // offset1 denotes separation where equal points are all on the right | |
873 | // offset2 denots separation where equal points are all on the left | |
874 | if (offset1 > halfLength) | |
875 | splitoffset = offset1; | |
876 | else if (offset2 < halfLength) | |
877 | splitoffset = offset2; | |
878 | // when we get here offset1 < halflength and offset2 > halflength | |
879 | // so simply split the equal elements in the middle | |
880 | else | |
881 | splitoffset = halfLength; | |
882 | } | |
883 | ||
884 | Index buildInnerNode(const Index startIdx, | |
885 | const Index length, | |
886 | BoundingBox &bbox) | |
887 | { | |
888 | assert(length > 0); | |
889 | assert(startIdx >= 0); | |
890 | assert(static_cast<size_t>(startIdx + length) <= indices_.size()); | |
891 | assert(data_->rows() == bbox.cols()); | |
892 | ||
893 | // create node | |
894 | const Index nodeIdx = nodes_.size(); | |
895 | nodes_.push_back(Node()); | |
896 | ||
897 | Index splitaxis; | |
898 | Index splitoffset; | |
899 | Scalar splitpoint; | |
900 | calculateSplittingMidpoint(startIdx, length, bbox, splitaxis, splitpoint, splitoffset); | |
901 | ||
902 | nodes_[nodeIdx].splitaxis = splitaxis; | |
903 | nodes_[nodeIdx].splitpoint = splitpoint; | |
904 | ||
905 | const Index leftStart = startIdx; | |
906 | const Index leftLength = splitoffset; | |
907 | const Index rightStart = startIdx + splitoffset; | |
908 | const Index rightLength = length - splitoffset; | |
909 | ||
910 | BoundingBox bboxLeft = bbox; | |
911 | BoundingBox bboxRight = bbox; | |
912 | ||
913 | // do left build | |
914 | bboxLeft(1, splitaxis) = splitpoint; | |
915 | Index left = buildR(leftStart, leftLength, bboxLeft); | |
916 | nodes_[nodeIdx].left = left; | |
917 | ||
918 | // do right build | |
919 | bboxRight(0, splitaxis) = splitpoint; | |
920 | Index right = buildR(rightStart, rightLength, bboxRight); | |
921 | nodes_[nodeIdx].right = right; | |
922 | ||
923 | // extract the range of the splitpoint | |
924 | nodes_[nodeIdx].splitlower = bboxLeft(1, splitaxis); | |
925 | nodes_[nodeIdx].splitupper = bboxRight(0, splitaxis); | |
926 | ||
927 | // update the bounding box to the values of the new bounding boxes | |
928 | for(Index i = 0; i < bbox.cols(); ++i) | |
929 | { | |
930 | bbox(0, i) = std::min(bboxLeft(0, i), bboxRight(0, i)); | |
931 | bbox(1, i) = std::max(bboxLeft(1, i), bboxRight(1, i)); | |
932 | } | |
933 | ||
934 | return nodeIdx; | |
935 | } | |
936 | ||
937 | Index buildR(const Index startIdx, | |
938 | const Index length, | |
939 | BoundingBox &bbox) | |
940 | { | |
941 | // check for base case | |
942 | if(length <= bucketSize_) | |
943 | return buildLeafNode(startIdx, length, bbox); | |
944 | else | |
945 | return buildInnerNode(startIdx, length, bbox); | |
946 | } | |
947 | ||
948 | bool isDistanceInRange(const Scalar dist) const | |
949 | { | |
950 | return maxDist_ <= 0 || dist <= maxDist_; | |
951 | } | |
952 | ||
953 | bool isDistanceImprovement(const Scalar dist, const QueryHeap<Scalar> &dataHeap) const | |
954 | { | |
955 | return !dataHeap.full() || dist < dataHeap.front(); | |
956 | } | |
957 | ||
958 | template<typename Derived> | |
959 | void queryLeafNode(const Node &node, | |
960 | const Eigen::MatrixBase<Derived> &queryPoint, | |
961 | QueryHeap<Scalar> &dataHeap) const | |
962 | { | |
963 | assert(node.isLeaf()); | |
964 | ||
965 | const DataMatrix &data = *data_; | |
966 | ||
967 | // go through all points in this leaf node and do brute force search | |
968 | for(Index i = 0; i < node.length; ++i) | |
969 | { | |
970 | const Index idx = node.startIdx + i; | |
971 | assert(idx >= 0 && idx < static_cast<Index>(indices_.size())); | |
972 | ||
973 | // retrieve index of the current data point | |
974 | const Index dataIdx = indices_[idx]; | |
975 | const Scalar dist = distance_(queryPoint, data.col(dataIdx)); | |
976 | ||
977 | // check if point is within max distance and if the value would be | |
978 | // an improvement | |
979 | if(isDistanceInRange(dist) && isDistanceImprovement(dist, dataHeap)) | |
980 | { | |
981 | if(dataHeap.full()) | |
982 | dataHeap.pop(); | |
983 | dataHeap.push(dataIdx, dist); | |
984 | } | |
985 | } | |
986 | } | |
987 | ||
988 | template<typename Derived> | |
989 | void queryInnerNode(const Node &node, | |
990 | const Eigen::MatrixBase<Derived> &queryPoint, | |
991 | QueryHeap<Scalar> &dataHeap, | |
992 | DataVector &splitdists, | |
993 | const Scalar mindist) const | |
994 | { | |
995 | assert(node.isInner()); | |
996 | ||
997 | const Index splitaxis = node.splitaxis; | |
998 | const Scalar splitval = queryPoint(splitaxis, 0); | |
999 | Scalar splitdist; | |
1000 | Index firstNode; | |
1001 | Index secondNode; | |
1002 | // check if right or left child should be visited | |
1003 | const bool visitLeft = (splitval - node.splitlower + splitval - node.splitupper) < 0; | |
1004 | if(visitLeft) | |
1005 | { | |
1006 | firstNode = node.left; | |
1007 | secondNode = node.right; | |
1008 | splitdist = distance_(splitval, node.splitupper); | |
1009 | } | |
1010 | else | |
1011 | { | |
1012 | firstNode = node.right; | |
1013 | secondNode = node.left; | |
1014 | splitdist = distance_(splitval, node.splitlower); | |
1015 | } | |
1016 | ||
1017 | queryR(nodes_[firstNode], queryPoint, dataHeap, splitdists, mindist); | |
1018 | ||
1019 | const Scalar mindistNew = mindist + splitdist - splitdists(splitaxis); | |
1020 | ||
1021 | // check if node is in range if max distance was set | |
1022 | // check if this node was an improvement if heap is already full | |
1023 | if(isDistanceInRange(mindistNew) && isDistanceImprovement(mindistNew, dataHeap)) | |
1024 | { | |
1025 | const Scalar splitdistOld = splitdists(splitaxis); | |
1026 | splitdists(splitaxis) = splitdist; | |
1027 | queryR(nodes_[secondNode], queryPoint, dataHeap, splitdists, mindistNew); | |
1028 | splitdists(splitaxis) = splitdistOld; | |
1029 | } | |
1030 | } | |
1031 | ||
1032 | template<typename Derived> | |
1033 | void queryR(const Node &node, | |
1034 | const Eigen::MatrixBase<Derived> &queryPoint, | |
1035 | QueryHeap<Scalar> &dataHeap, | |
1036 | DataVector &splitdists, | |
1037 | const Scalar mindist) const | |
1038 | { | |
1039 | if(node.isLeaf()) | |
1040 | queryLeafNode(node, queryPoint, dataHeap); | |
1041 | else | |
1042 | queryInnerNode(node, queryPoint, dataHeap, splitdists, mindist); | |
1043 | } | |
1044 | ||
1045 | /** Recursively computes the depth for the given node. */ | |
1046 | Index depthR(const Node &node) const | |
1047 | { | |
1048 | if(node.isLeaf()) | |
1049 | return 1; | |
1050 | else | |
1051 | { | |
1052 | Index left = depthR(nodes_[node.left]); | |
1053 | Index right = depthR(nodes_[node.right]); | |
1054 | return std::max(left, right) + 1; | |
1055 | } | |
1056 | } | |
1057 | ||
    public:

        /** Constructs an empty KDTree. */
        KDTreeMinkowski()
        { }

        /** Constructs a KDTree with the given data. This does not build the
         * index of the tree.
         * @param data NxM matrix, M points of dimension N
         * @param copy if true copies the data, otherwise assumes static data */
        KDTreeMinkowski(const DataMatrix &data, const bool copy=false)
        {
            setData(data, copy);
        }

        /** Set the maximum number of data points per leaf in the tree (aka
         * bucket size).
         * @param bucketSize number of points per leaf. */
        void setBucketSize(const Index bucketSize)
        {
            bucketSize_ = bucketSize;
        }

        /** Set if the points returned by the queries should be sorted
         * according to their distance to the query points.
         * @param sorted sort query results */
        void setSorted(const bool sorted)
        {
            sorted_ = sorted;
        }

        /** Set if the tree should be built as balanced as possible.
         * This increases build time, but decreases search time.
         * @param balanced set true to build a balanced tree */
        void setBalanced(const bool balanced)
        {
            balanced_ = balanced;
        }

        /** Set if the distances after the query should be rooted or not.
         * Taking the root of the distances increases query time, but the
         * function will return true distances instead of their powered
         * versions.
         * @param takeRoot set true if the root should be taken, else false */
        void setTakeRoot(const bool takeRoot)
        {
            takeRoot_ = takeRoot;
        }

        /** Set if the tree should be built with compact leaf nodes.
         * This increases build time, but makes leaf nodes denser (they
         * contain more points), so fewer node visits are necessary.
         * @param compact set true to build a tree with compact leaves */
        void setCompact(const bool compact)
        {
            compact_ = compact;
        }

        /** Set the number of threads that should be used for building and
         * querying the tree.
         * OpenMP has to be enabled for this to work.
         * @param threads number of threads, 0 for optimal choice */
        void setThreads(const unsigned int threads)
        {
            threads_ = threads;
        }

        /** Set the maximum distance for querying the tree.
         * The search will be pruned if the maximum distance is set to any
         * positive number.
         * @param maxDist maximum distance, <= 0 for no limit */
        void setMaxDistance(const Scalar maxDist)
        {
            maxDist_ = maxDist;
        }

        /** Set the data points used for this tree.
         * This does not build the tree.
         * @param data NxM matrix, M points of dimension N
         * @param copy if true data is copied, assumes static data otherwise */
        void setData(const DataMatrix &data, const bool copy = false)
        {
            clear();
            if(copy)
            {
                dataCopy_ = data;
                data_ = &dataCopy_;
            }
            else
            {
                data_ = &data;
            }
        }

        void setDistance(const Distance &distance)
        {
            distance_ = distance;
        }

        /** Builds the search index of the tree.
         * Data has to be set and must be non-empty. */
        void build()
        {
            if(data_ == nullptr)
                throw std::runtime_error("cannot build KDTree; data not set");

            if(data_->size() == 0)
                throw std::runtime_error("cannot build KDTree; data is empty");

            clear();
            nodes_.reserve((data_->cols() / bucketSize_) + 1);

            // initialize indices in simple sequence
            indices_.resize(data_->cols());
            for(size_t i = 0; i < indices_.size(); ++i)
                indices_[i] = i;

            bbox_.resize(2, data_->rows());
            Index startIdx = 0;
            Index length = data_->cols();

            calculateBoundingBox(startIdx, length, bbox_);

            buildR(startIdx, length, bbox_);
        }
1183 | ||
1184 | /** Queries the tree for the nearest neighbours of the given query | |
1185 | * points. | |
1186 | * | |
1187 | * The tree has to be built before it can be queried. | |
1188 | * | |
1189 | * The query points have to have the same dimension as the data points | |
1190 | * of the tree. | |
1191 | * | |
1192 | * The result matrices will be resized appropriatley. | |
1193 | * Indices and distances will be set to -1 if less than knn neighbours | |
1194 | * were found. | |
1195 | * | |
1196 | * @param queryPoints NxM matrix, M points of dimension N | |
1197 | * @param knn amount of neighbours to be found | |
1198 | * @param indices KNNxM matrix, indices of neighbours in the data set | |
1199 | * @param distances KNNxM matrix, distance between query points and | |
1200 | * neighbours */ | |
1201 | template<typename Derived> | |
1202 | void query(const Eigen::MatrixBase<Derived> &queryPoints, | |
1203 | const size_t knn, | |
1204 | Matrixi &indices, | |
1205 | Matrix &distances) const | |
1206 | { | |
1207 | if(nodes_.size() == 0) | |
1208 | throw std::runtime_error("cannot query KDTree; not built yet"); | |
1209 | ||
1210 | if(queryPoints.rows() != dimension()) | |
1211 | throw std::runtime_error("cannot query KDTree; data and query points do not have same dimension"); | |
1212 | ||
1213 | distances.setConstant(knn, queryPoints.cols(), -1); | |
1214 | indices.setConstant(knn, queryPoints.cols(), -1); | |
1215 | ||
1216 | Index *indicesRaw = indices.data(); | |
1217 | Scalar *distsRaw = distances.data(); | |
1218 | ||
1219 | #pragma omp parallel for num_threads(threads_) | |
1220 | for(Index i = 0; i < queryPoints.cols(); ++i) | |
1221 | { | |
1222 | ||
1223 | Scalar *distPoint = &distsRaw[i * knn]; | |
1224 | Index *idxPoint = &indicesRaw[i * knn]; | |
1225 | ||
1226 | // create heap to find nearest neighbours | |
1227 | QueryHeap<Scalar> dataHeap(idxPoint, distPoint, knn); | |
1228 | ||
1229 | Scalar mindist = static_cast<Scalar>(0); | |
1230 | DataVector splitdists(queryPoints.rows()); | |
1231 | ||
1232 | for(Index j = 0; j < splitdists.rows(); ++j) | |
1233 | { | |
1234 | const Scalar value = queryPoints(j, i); | |
1235 | const Scalar lower = bbox_(0, j); | |
1236 | const Scalar upper = bbox_(1, j); | |
1237 | if(value < lower) | |
1238 | { | |
1239 | splitdists(j) = distance_(value, lower); | |
1240 | } | |
1241 | else if(value > upper) | |
1242 | { | |
1243 | splitdists(j) = distance_(value, upper); | |
1244 | } | |
1245 | else | |
1246 | { | |
1247 | splitdists(j) = static_cast<Scalar>(0); | |
1248 | } | |
1249 | ||
1250 | mindist += splitdists(j); | |
1251 | } | |
1252 | ||
1253 | queryR(nodes_[0], queryPoints.col(i), dataHeap, splitdists, mindist); | |
1254 | ||
1255 | if(sorted_) | |
1256 | dataHeap.sort(); | |
1257 | ||
1258 | if(takeRoot_) | |
1259 | { | |
1260 | for(size_t j = 0; j < knn; ++j) | |
1261 | { | |
1262 | if(distPoint[j] < 0) | |
1263 | break; | |
1264 | distPoint[j] = distance_(distPoint[j]); | |
1265 | } | |
1266 | } | |
1267 | } | |
1268 | } | |
1269 | ||
1270 | /** Clears the tree. */ | |
1271 | void clear() | |
1272 | { | |
1273 | nodes_.clear(); | |
1274 | } | |
1275 | ||
1276 | /** Returns the amount of data points stored in the search index. | |
1277 | * @return number of data points */ | |
1278 | Index size() const | |
1279 | { | |
1280 | return data_ == nullptr ? 0 : data_->cols(); | |
1281 | } | |
1282 | ||
1283 | /** Returns the dimension of the data points in the search index. | |
1284 | * @return dimension of data points */ | |
1285 | Index dimension() const | |
1286 | { | |
1287 | return data_ == nullptr ? 0 : data_->rows(); | |
1288 | } | |
1289 | ||
1290 | /** Returns the maxximum depth of the tree. | |
1291 | * @return maximum depth of the tree */ | |
1292 | Index depth() const | |
1293 | { | |
1294 | return nodes_.size() == 0 ? 0 : depthR(nodes_.front()); | |
1295 | } | |
1296 | }; | |
1297 | ||
1298 | template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski2 = KDTreeMinkowski<_Scalar, 2, _Distance>; | |
1299 | template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski3 = KDTreeMinkowski<_Scalar, 3, _Distance>; | |
1300 | template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski4 = KDTreeMinkowski<_Scalar, 4, _Distance>; | |
1301 | template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowski5 = KDTreeMinkowski<_Scalar, 5, _Distance>; | |
1302 | template<typename _Scalar, typename _Distance = EuclideanDistance<_Scalar>> using KDTreeMinkowskiX = KDTreeMinkowski<_Scalar, Eigen::Dynamic, _Distance>; | |
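
    /* Usage sketch (editor's illustration, not part of the library): the
     * fixed-dimension aliases let Eigen use fixed-size vectors internally,
     * while KDTreeMinkowskiX handles arbitrary dimensions:
     *
     *   Eigen::Matrix<double, 3, Eigen::Dynamic> points(3, 1000);
     *   points.setRandom();
     *
     *   knncpp::KDTreeMinkowski3<double> tree(points);
     *   tree.setBucketSize(16);
     *   tree.setSorted(true);
     *   tree.build();
     *
     *   knncpp::Matrixi indices;
     *   knncpp::Matrixd distances;
     *   tree.query(points.leftCols(10), 5, indices, distances);
     */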
1303 | ||
1304 | /** Class for performing KNN search in hamming space by multi-index hashing. */ | |
1305 | template<typename Scalar> | |
1306 | class MultiIndexHashing | |
1307 | { | |
1308 | public: | |
1309 | static_assert(std::is_integral<Scalar>::value, "MultiIndexHashing Scalar has to be integral"); | |
1310 | ||
1311 | typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix; | |
1312 | typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector; | |
1313 | typedef knncpp::Matrixi Matrixi; | |
1314 | ||
1315 | private: | |
1316 | HammingDistance<Scalar> distance_; | |
1317 | ||
1318 | Matrix dataCopy_; | |
1319 | const Matrix *data_; | |
1320 | ||
1321 | bool sorted_; | |
1322 | Scalar maxDist_; | |
1323 | Index substrLen_; | |
1324 | Index threads_; | |
1325 | std::vector<std::map<Scalar, std::vector<Index>>> buckets_; | |
1326 | ||
1327 | template<typename Derived> | |
1328 | Scalar extractCode(const Eigen::MatrixBase<Derived> &data, | |
1329 | const Index idx, | |
1330 | const Index offset) const | |
1331 | { | |
1332 | Index leftShift = std::max<Index>(0, static_cast<Index>(sizeof(Scalar)) - offset - substrLen_); | |
1333 | Index rightShift = leftShift + offset; | |
1334 | ||
1335 | Scalar code = (data(idx, 0) << (leftShift * 8)) >> (rightShift * 8); | |
1336 | ||
1337 | if(static_cast<Index>(sizeof(Scalar)) - offset < substrLen_ && idx + 1 < data.rows()) | |
1338 | { | |
1339 | Index shift = 2 * static_cast<Index>(sizeof(Scalar)) - substrLen_ - offset; | |
1340 | code |= data(idx+1, 0) << (shift * 8); | |
1341 | } | |
1342 | ||
1343 | return code; | |
1344 | } | |
1345 | public: | |
1346 | MultiIndexHashing() | |
1347 | : distance_(), dataCopy_(), data_(nullptr), sorted_(true), | |
1348 | maxDist_(0), substrLen_(1), threads_(1) | |
1349 | { } | |
1350 | ||
1351 | /** Constructs an index with the given data. | |
1352 | * This does not build the the index. | |
1353 | * @param data NxM matrix, M points of dimension N | |
1354 | * @param copy if true copies the data, otherwise assumes static data */ | |
1355 | MultiIndexHashing(const Matrix &data, const bool copy=false) | |
1356 | : MultiIndexHashing() | |
1357 | { | |
1358 | setData(data, copy); | |
1359 | } | |
1360 | ||
1361 | /** Set the maximum distance for querying the index. | |
1362 | * Note that if no maximum distance is used, this algorithm performs | |
1363 | * basically a brute force search. | |
1364 | * @param maxDist maximum distance, <= 0 for no limit */ | |
1365 | void setMaxDistance(const Scalar maxDist) | |
1366 | { | |
1367 | maxDist_ = maxDist; | |
1368 | } | |
1369 | ||
1370 | /** Set if the points returned by the queries should be sorted | |
1371 | * according to their distance to the query points. | |
1372 | * @param sorted sort query results */ | |
1373 | void setSorted(const bool sorted) | |
1374 | { | |
1375 | sorted_ = sorted; | |
1376 | } | |
1377 | ||
1378 | /** Set the amount of threads that should be used for building and | |
1379 | * querying the tree. | |
1380 | * OpenMP has to be enabled for this to work. | |
1381 | * @param threads amount of threads, 0 for optimal choice */ | |
1382 | void setThreads(const unsigned int threads) | |
1383 | { | |
1384 | threads_ = threads; | |
1385 | } | |
1386 | ||
1387 | /** Set the length of substrings (in bytes) used for multi index hashing. | |
1388 | * @param len lentth of bucket substrings in bytes*/ | |
1389 | void setSubstringLength(const Index len) | |
1390 | { | |
1391 | substrLen_ = len; | |
1392 | } | |
1393 | ||
1394 | /** Set the data points used for the KNN search. | |
1395 | * @param data NxM matrix, M points of dimension N | |
1396 | * @param copy if true data is copied, assumes static data otherwise */ | |
1397 | void setData(const Matrix &data, const bool copy = false) | |
1398 | { | |
1399 | clear(); | |
1400 | if(copy) | |
1401 | { | |
1402 | dataCopy_ = data; | |
1403 | data_ = &dataCopy_; | |
1404 | } | |
1405 | else | |
1406 | { | |
1407 | data_ = &data; | |
1408 | } | |
1409 | } | |
1410 | ||
1411 | void build() | |
1412 | { | |
1413 | if(data_ == nullptr) | |
1414 | throw std::runtime_error("cannot build MultiIndexHashing; data not set"); | |
1415 | if(data_->size() == 0) | |
1416 | throw std::runtime_error("cannot build MultiIndexHashing; data is empty"); | |
1417 | ||
1418 | const Matrix &data = *data_; | |
1419 | const Index bytesPerVec = data.rows() * static_cast<Index>(sizeof(Scalar)); | |
1420 | if(bytesPerVec % substrLen_ != 0) | |
1421 | throw std::runtime_error("cannot build MultiIndexHashing; cannot divide byte count per vector by substring length without remainings"); | |
1422 | ||
1423 | buckets_.clear(); | |
1424 | buckets_.resize(bytesPerVec / substrLen_); | |
1425 | ||
1426 | for(size_t i = 0; i < buckets_.size(); ++i) | |
1427 | { | |
1428 | Index start = static_cast<Index>(i) * substrLen_; | |
1429 | Index idx = start / static_cast<Index>(sizeof(Scalar)); | |
1430 | Index offset = start % static_cast<Index>(sizeof(Scalar)); | |
1431 | std::map<Scalar, std::vector<Index>> &map = buckets_[i]; | |
1432 | ||
1433 | for(Index c = 0; c < data.cols(); ++c) | |
1434 | { | |
1435 | Scalar code = extractCode(data.col(c), idx, offset); | |
1436 | if(map.find(code) == map.end()) | |
1437 | map[code] = std::vector<Index>(); | |
1438 | map[code].push_back(c); | |
1439 | } | |
1440 | } | |
1441 | } | |
1442 | ||
1443 | template<typename Derived> | |
1444 | void query(const Eigen::MatrixBase<Derived> &queryPoints, | |
1445 | const size_t knn, | |
1446 | Matrixi &indices, | |
1447 | Matrix &distances) const | |
1448 | { | |
1449 | if(buckets_.size() == 0) | |
1450 | throw std::runtime_error("cannot query MultiIndexHashing; not built yet"); | |
1451 | if(queryPoints.rows() != dimension()) | |
1452 | throw std::runtime_error("cannot query MultiIndexHashing; data and query points do not have same dimension"); | |
1453 | ||
1454 | const Matrix &data = *data_; | |
1455 | ||
1456 | indices.setConstant(knn, queryPoints.cols(), -1); | |
1457 | distances.setConstant(knn, queryPoints.cols(), -1); | |
1458 | ||
1459 | Index *indicesRaw = indices.data(); | |
1460 | Scalar *distsRaw = distances.data(); | |
1461 | ||
1462 | Scalar maxDistPart = maxDist_ / buckets_.size(); | |
1463 | ||
1464 | #pragma omp parallel for num_threads(threads_) | |
1465 | for(Index c = 0; c < queryPoints.cols(); ++c) | |
1466 | { | |
1467 | std::set<Index> candidates; | |
1468 | for(size_t i = 0; i < buckets_.size(); ++i) | |
1469 | { | |
1470 | Index start = static_cast<Index>(i) * substrLen_; | |
1471 | Index idx = start / static_cast<Index>(sizeof(Scalar)); | |
1472 | Index offset = start % static_cast<Index>(sizeof(Scalar)); | |
1473 | const std::map<Scalar, std::vector<Index>> &map = buckets_[i]; | |
1474 | ||
1475 | Scalar code = extractCode(queryPoints.col(c), idx, offset); | |
1476 | for(const auto &x: map) | |
1477 | { | |
1478 | Scalar dist = distance_(x.first, code); | |
1479 | if(maxDistPart <= 0 || dist <= maxDistPart) | |
1480 | { | |
1481 | for(size_t j = 0; j < x.second.size(); ++j) | |
1482 | candidates.insert(x.second[j]); | |
1483 | } | |
1484 | } | |
1485 | } | |
1486 | ||
1487 | Scalar *distPoint = &distsRaw[c * knn]; | |
1488 | Index *idxPoint = &indicesRaw[c * knn]; | |
1489 | // create heap to find nearest neighbours | |
1490 | QueryHeap<Scalar> dataHeap(idxPoint, distPoint, knn); | |
1491 | ||
1492 | for(Index idx: candidates) | |
1493 | { | |
1494 | Scalar dist = distance_(data.col(idx), queryPoints.col(c)); | |
1495 | ||
1496 | bool isInRange = maxDist_ <= 0 || dist <= maxDist_; | |
1497 | bool isImprovement = !dataHeap.full() || | |
1498 | dist < dataHeap.front(); | |
1499 | if(isInRange && isImprovement) | |
1500 | { | |
1501 | if(dataHeap.full()) | |
1502 | dataHeap.pop(); | |
1503 | dataHeap.push(idx, dist); | |
1504 | } | |
1505 | } | |
1506 | ||
1507 | if(sorted_) | |
1508 | dataHeap.sort(); | |
1509 | } | |
1510 | } | |
1511 | ||
1512 | /** Returns the amount of data points stored in the search index. | |
1513 | * @return number of data points */ | |
1514 | Index size() const | |
1515 | { | |
1516 | return data_ == nullptr ? 0 : data_->cols(); | |
1517 | } | |
1518 | ||
1519 | /** Returns the dimension of the data points in the search index. | |
1520 | * @return dimension of data points */ | |
1521 | Index dimension() const | |
1522 | { | |
1523 | return data_ == nullptr ? 0 : data_->rows(); | |
1524 | } | |
1525 | ||
1526 | void clear() | |
1527 | { | |
1528 | data_ = nullptr; | |
1529 | dataCopy_.resize(0, 0); | |
1530 | buckets_.clear(); | |
1531 | } | |
1532 | ||
1533 | }; | |
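
    /* Usage sketch (editor's illustration, not part of the library): binary
     * descriptors are packed column-wise into an integral matrix; setting a
     * maximum distance lets the substring buckets prune candidates:
     *
     *   knncpp::MultiIndexHashing<std::uint32_t>::Matrix codes(2, 500); // 64-bit codes
     *   // ... fill codes with descriptor bitmasks ...
     *
     *   knncpp::MultiIndexHashing<std::uint32_t> mih(codes);
     *   mih.setSubstringLength(2); // 2-byte buckets, 8 bytes per vector
     *   mih.setMaxDistance(12);    // prune candidates beyond 12 bits
     *   mih.build();
     *
     *   knncpp::Matrixi indices;
     *   knncpp::MultiIndexHashing<std::uint32_t>::Matrix distances;
     *   mih.query(codes.leftCols(3), 4, indices, distances);
     */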
1534 | ||
1535 | #ifdef KNNCPP_FLANN | |
1536 | ||
1537 | /** Wrapper class of FLANN kdtrees for the use with Eigen3. */ | |
1538 | template<typename Scalar, | |
1539 | typename Distance=flann::L2_Simple<Scalar>> | |
1540 | class KDTreeFlann | |
1541 | { | |
1542 | public: | |
1543 | typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix; | |
1544 | typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector; | |
1545 | typedef Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic> Matrixi; | |
1546 | ||
1547 | private: | |
1548 | typedef flann::Index<Distance> FlannIndex; | |
1549 | ||
1550 | Matrix dataCopy_; | |
1551 | Matrix *dataPoints_; | |
1552 | ||
1553 | FlannIndex *index_; | |
1554 | flann::SearchParams searchParams_; | |
1555 | flann::IndexParams indexParams_; | |
1556 | Scalar maxDist_; | |
1557 | ||
1558 | public: | |
1559 | KDTreeFlann() | |
1560 | : dataCopy_(), dataPoints_(nullptr), index_(nullptr), | |
1561 | searchParams_(32, 0, false), | |
1562 | indexParams_(flann::KDTreeSingleIndexParams(15)), | |
1563 | maxDist_(0) | |
1564 | { | |
1565 | } | |
1566 | ||
1567 | KDTreeFlann(Matrix &data, const bool copy = false) | |
1568 | : KDTreeFlann() | |
1569 | { | |
1570 | setData(data, copy); | |
1571 | } | |
1572 | ||
1573 | ~KDTreeFlann() | |
1574 | { | |
1575 | clear(); | |
1576 | } | |
1577 | ||
1578 | void setIndexParams(const flann::IndexParams ¶ms) | |
1579 | { | |
1580 | indexParams_ = params; | |
1581 | } | |
1582 | ||
1583 | void setChecks(const int checks) | |
1584 | { | |
1585 | searchParams_.checks = checks; | |
1586 | } | |
1587 | ||
1588 | void setSorted(const bool sorted) | |
1589 | { | |
1590 | searchParams_.sorted = sorted; | |
1591 | } | |
1592 | ||
1593 | void setThreads(const int threads) | |
1594 | { | |
1595 | searchParams_.cores = threads; | |
1596 | } | |
1597 | ||
1598 | void setEpsilon(const float eps) | |
1599 | { | |
1600 | searchParams_.eps = eps; | |
1601 | } | |
1602 | ||
1603 | void setMaxDistance(const Scalar dist) | |
1604 | { | |
1605 | maxDist_ = dist; | |
1606 | } | |
1607 | ||
1608 | void setData(Matrix &data, const bool copy = false) | |
1609 | { | |
1610 | if(copy) | |
1611 | { | |
1612 | dataCopy_ = data; | |
1613 | dataPoints_ = &dataCopy_; | |
1614 | } | |
1615 | else | |
1616 | { | |
1617 | dataPoints_ = &data; | |
1618 | } | |
1619 | ||
1620 | clear(); | |
1621 | } | |
1622 | ||
1623 | void build() | |
1624 | { | |
1625 | if(dataPoints_ == nullptr) | |
1626 | throw std::runtime_error("cannot build KDTree; data not set"); | |
1627 | if(dataPoints_->size() == 0) | |
1628 | throw std::runtime_error("cannot build KDTree; data is empty"); | |
1629 | ||
1630 | if(index_ != nullptr) | |
1631 | delete index_; | |
1632 | ||
1633 | flann::Matrix<Scalar> dataPts( | |
1634 | dataPoints_->data(), | |
1635 | dataPoints_->cols(), | |
1636 | dataPoints_->rows()); | |
1637 | ||
1638 | index_ = new FlannIndex(dataPts, indexParams_); | |
1639 | index_->buildIndex(); | |
1640 | } | |
1641 | ||
1642 | void query(Matrix &queryPoints, | |
1643 | const size_t knn, | |
1644 | Matrixi &indices, | |
1645 | Matrix &distances) const | |
1646 | { | |
1647 | if(index_ == nullptr) | |
1648 | throw std::runtime_error("cannot query KDTree; not built yet"); | |
1649 | if(dataPoints_->rows() != queryPoints.rows()) | |
1650 | throw std::runtime_error("cannot query KDTree; KDTree has different dimension than query data"); | |
1651 | ||
1652 | // resize result matrices | |
1653 | distances.resize(knn, queryPoints.cols()); | |
1654 | indices.resize(knn, queryPoints.cols()); | |
1655 | ||
1656 | // wrap matrices into flann matrices | |
1657 | flann::Matrix<Scalar> queryPts( | |
1658 | queryPoints.data(), | |
1659 | queryPoints.cols(), | |
1660 | queryPoints.rows()); | |
1661 | flann::Matrix<int> indicesF( | |
1662 | indices.data(), | |
1663 | indices.cols(), | |
1664 | indices.rows()); | |
1665 | flann::Matrix<Scalar> distancesF( | |
1666 | distances.data(), | |
1667 | distances.cols(), | |
1668 | distances.rows()); | |
1669 | ||
1670 | // if maximum distance was set then use radius search | |
1671 | if(maxDist_ > 0) | |
1672 | index_->radiusSearch(queryPts, indicesF, distancesF, maxDist_, searchParams_); | |
1673 | else | |
1674 | index_->knnSearch(queryPts, indicesF, distancesF, knn, searchParams_); | |
1675 | ||
1676 | // make result matrices compatible to API | |
1677 | #pragma omp parallel for num_threads(searchParams_.cores) | |
1678 | for(Index i = 0; i < indices.cols(); ++i) | |
1679 | { | |
1680 | bool found = false; | |
1681 | for(Index j = 0; j < indices.rows(); ++j) | |
1682 | { | |
1683 | if(indices(j, i) == -1) | |
1684 | found = true; | |
1685 | ||
1686 | if(found) | |
1687 | { | |
1688 | indices(j, i) = -1; | |
1689 | distances(j, i) = -1; | |
1690 | } | |
1691 | } | |
1692 | } | |
1693 | } | |
1694 | ||
1695 | Index size() const | |
1696 | { | |
1697 | return dataPoints_ == nullptr ? 0 : dataPoints_->cols(); | |
1698 | } | |
1699 | ||
1700 | Index dimension() const | |
1701 | { | |
1702 | return dataPoints_ == nullptr ? 0 : dataPoints_->rows(); | |
1703 | } | |
1704 | ||
1705 | void clear() | |
1706 | { | |
1707 | if(index_ != nullptr) | |
1708 | { | |
1709 | delete index_; | |
1710 | index_ = nullptr; | |
1711 | } | |
1712 | } | |
1713 | ||
1714 | FlannIndex &flannIndex() | |
1715 | { | |
1716 | return index_; | |
1717 | } | |
1718 | }; | |
1719 | ||
1720 | typedef KDTreeFlann<double> KDTreeFlannd; | |
1721 | typedef KDTreeFlann<float> KDTreeFlannf; | |
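
    /* Usage sketch (editor's illustration, not part of the library): the
     * wrapper follows the same column-major API as the other classes here,
     * but needs non-const data since FLANN works on mutable buffers:
     *
     *   knncpp::Matrixd points(3, 1000);
     *   points.setRandom();
     *
     *   knncpp::KDTreeFlannd tree(points);
     *   tree.build();
     *
     *   knncpp::KDTreeFlannd::Matrixi indices; // int indices, unlike knncpp::Matrixi
     *   knncpp::Matrixd distances;
     *   tree.query(points, 5, indices, distances);
     */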
1722 | ||
1723 | #endif | |
1724 | } | |
1725 | ||
1726 | #endif |