// Commit: 762721a5 BA
#include <cstddef>
#include <iostream>
#include <vector>
#include <Rcpp.h>
#include <RcppEigen.h>
#include "knncpp.h"

using namespace Rcpp;
8 | // [[Rcpp::depends(RcppEigen)]] | |
9 | // [[Rcpp::export]] | |
10 | List findNeighbors(NumericMatrix data, int k, bool mutual) { | |
11 | int n = data.nrow(), | |
12 | d = data.ncol(); | |
13 | ||
14 | Eigen::MatrixXd dataPoints(d, n); | |
15 | for (int row = 0; row < n; ++row) { | |
16 | for (int col = 0; col < d; ++col) { | |
17 | dataPoints(col,row) = data(row,col); //dataPoints: by columns | |
18 | } | |
19 | } | |
20 | ||
21 | knncpp::KDTreeMinkowskiX<double, knncpp::EuclideanDistance<double>> kdtree(dataPoints); | |
22 | kdtree.setBucketSize(16); | |
23 | kdtree.setSorted(false); | |
24 | kdtree.setTakeRoot(false); | |
25 | kdtree.setMaxDistance(0); | |
26 | kdtree.setThreads(0); | |
27 | kdtree.build(); | |
28 | ||
29 | knncpp::Matrixi indices; | |
30 | Eigen::MatrixXd distances; | |
31 | // k+1 because i is always a neighbor of i (to discard) | |
32 | kdtree.query(dataPoints, k+1, indices, distances); | |
33 | ||
34 | NumericVector res_edges(0); | |
35 | NumericVector res_dists(0); | |
36 | for (int i = 0; i <= k; ++i) { | |
37 | for (int j = 0; j < n; ++j) { | |
38 | if (indices(i,j) == j) | |
39 | continue; | |
40 | bool addRow = false; | |
41 | if (!mutual) | |
42 | addRow = true; | |
43 | else if (mutual && j < indices(i,j)) { | |
44 | int l = 0; | |
45 | for (; l <= k; ++l) { | |
46 | if (indices(l,indices(i,j)) == j) | |
47 | break; | |
48 | } | |
49 | if (l <= k) | |
50 | addRow = true; | |
51 | } | |
52 | if (addRow) { | |
53 | // R indices from 1 to n: | |
54 | res_edges.push_back(j+1); | |
55 | res_edges.push_back(indices(i,j)+1); | |
56 | res_dists.push_back(distances(i,j)); | |
57 | } | |
58 | } | |
59 | } | |
60 | ||
61 | res_edges.attr("dim") = Dimension(2, res_edges.length() / 2); | |
62 | List L = List::create(Named("edges") = res_edges, | |
63 | Named("euc_dists") = res_dists); | |
64 | return L; | |
65 | } |