Fix typo
[nngd.git] / src / nng.cpp
... / ...
CommitLineData
#include <algorithm>
#include <iostream>
#include <vector>
#include <Rcpp.h>
#include <RcppEigen.h>
#include "knncpp.h"
5
6using namespace Rcpp;
7
8// [[Rcpp::depends(RcppEigen)]]
9// [[Rcpp::export]]
10List findNeighbors(NumericMatrix data, int k, bool mutual) {
11 int n = data.nrow(),
12 d = data.ncol();
13
14 Eigen::MatrixXd dataPoints(d, n);
15 for (int row = 0; row < n; ++row) {
16 for (int col = 0; col < d; ++col) {
17 dataPoints(col,row) = data(row,col); //dataPoints: by columns
18 }
19 }
20
21 knncpp::KDTreeMinkowskiX<double, knncpp::EuclideanDistance<double>> kdtree(dataPoints);
22 kdtree.setBucketSize(16);
23 kdtree.setSorted(false);
24 kdtree.setTakeRoot(false);
25 kdtree.setMaxDistance(0);
26 kdtree.setThreads(0);
27 kdtree.build();
28
29 knncpp::Matrixi indices;
30 Eigen::MatrixXd distances;
31 // k+1 because i is always a neighbor of i (to discard)
32 kdtree.query(dataPoints, k+1, indices, distances);
33
34 NumericVector res_edges(0);
35 NumericVector res_dists(0);
36 for (int i = 0; i <= k; ++i) {
37 for (int j = 0; j < n; ++j) {
38 if (indices(i,j) == j)
39 continue;
40 bool addRow = false;
41 if (!mutual)
42 addRow = true;
43 else if (mutual && j < indices(i,j)) {
44 int l = 0;
45 for (; l <= k; ++l) {
46 if (indices(l,indices(i,j)) == j)
47 break;
48 }
49 if (l <= k)
50 addRow = true;
51 }
52 if (addRow) {
53 // R indices from 1 to n:
54 res_edges.push_back(j+1);
55 res_edges.push_back(indices(i,j)+1);
56 res_dists.push_back(distances(i,j));
57 }
58 }
59 }
60
61 res_edges.attr("dim") = Dimension(2, res_edges.length() / 2);
62 List L = List::create(Named("edges") = res_edges,
63 Named("euc_dists") = res_dists);
64 return L;
65}