// functions.c (morpheus.git: pkg/src/functions.c)
// Last change: remove OpenMP dependency (and unused variable).
#include <stdlib.h>
//#include <omp.h>

// Linear index of element (i, j) in a d1 x d2 matrix stored by columns.
// d2 is not needed for the computation; it is kept so call sites state the
// full shape. Every argument is parenthesized so that expressions such as
// mi(i, a+b, n, d) expand correctly.
#define mi(i, j, d1, d2) ((j)*(d1) + (i))

// Linear index of element (i, j, k) in a d1 x d2 x d3 tensor stored by columns,
// with the d1 x d2 matrices laid out one after another along the last
// dimension. d3 is unused, kept for the same reason as d2 above.
#define ti(i, j, k, d1, d2, d3) ((k)*(d1)*(d2) + (j)*(d1) + (i))

// Empirical cross-moment of order 2 between X (size n x d) and Y (size n):
//   M2 = E[Y*X^*2] - E[Y*e^*2] = E[Y (X^*2 - I)]
//
// X  : n x d matrix stored by columns, i.e. X[i,j] is X[j*n + i]
// Y  : vector of length n
// pn : pointer to n (number of rows of X)
// pd : pointer to d (number of columns of X)
// M2 : d x d output matrix stored by columns. The result is ACCUMULATED into
//      M2 (nothing is reset here), so the caller must pass a zero-filled
//      buffer to obtain the moment itself.
void Moments_M2(double* X, double* Y, int* pn, int* pd, double* M2)
{
	int n=*pn, d=*pd;
	if (n <= 0 || d <= 0)
		return; // no observation or no column: nothing to accumulate

	// Index arithmetic is done in size_t so products like k*dim cannot
	// overflow int.
	const size_t rows = (size_t)n, dim = (size_t)d;

	for (size_t j=0; j<dim; j++)
	{
		const double* Xj = X + j*rows; // column j of X
		for (size_t i=0; i<rows; i++)
		{
			// Identity correction: only the diagonal entry (j,j) is affected
			M2[j*dim + j] -= Y[i] / n;

			// Y[i]*X[i,j] does not depend on k: compute it once
			const double yxj = Y[i] * Xj[i];
			for (size_t k=0; k<dim; k++)
				M2[k*dim + j] += yxj * X[k*rows + i] / n;
		}
	}
}

// Empirical cross-moment of order 3 between X (size n x d) and Y (size n):
//   M3 = E[Y*X^*3] - E[Y*e*X*e] - E[Y*e*e*X] - E[Y*X*e*e]
//
// X  : n x d matrix stored by columns, i.e. X[i,j] is X[j*n + i]
// Y  : vector of length n
// pn : pointer to n (number of rows of X)
// pd : pointer to d (number of columns of X)
// M3 : d x d x d output tensor stored by columns, element (a,b,c) being
//      M3[c*d*d + b*d + a]. The result is ACCUMULATED into M3 (nothing is
//      reset here), so the caller must pass a zero-filled buffer.
void Moments_M3(double* X, double* Y, int* pn, int* pd, double* M3)
{
	int n=*pn, d=*pd;
	if (n <= 0 || d <= 0)
		return; // no observation or no column: nothing to accumulate

	// Index arithmetic is done in size_t so d*d*d cannot overflow int.
	const size_t rows = (size_t)n, dim = (size_t)d;
	const size_t slice = dim*dim; // size of one d x d matrix of the tensor

	for (size_t j=0; j<dim; j++)
	{
		for (size_t k=0; k<dim; k++)
		{
			for (size_t i=0; i<rows; i++)
			{
				const double xij = X[j*rows + i];
				const double xik = X[k*rows + i];

				// The three correction terms all subtract Y[i]*X[i,k]/n, at
				// the positions where two of the three indices equal j.
				const double tensor_elt = Y[i] * xik / n;
				M3[j*slice + k*dim + j] -= tensor_elt; // (j,k,j)
				M3[k*slice + j*dim + j] -= tensor_elt; // (j,j,k)
				M3[j*slice + j*dim + k] -= tensor_elt; // (k,j,j)

				// Y[i]*X[i,j]*X[i,k] does not depend on o: compute it once
				const double yxjxk = Y[i] * xij * xik;
				for (size_t o=0; o<dim; o++)
					M3[o*slice + k*dim + j] += yxjxk * X[o*rows + i] / n;
			}
		}
	}
}

// W = 1/N sum( t(g(Zi,theta)) g(Zi,theta) )
// with g(Zi, theta) = i-th contribution to all moments (size dim) - real moments
//
// X   : n x d matrix stored by columns, i.e. X[i,j] is X[j*n + i]
// Y   : vector of n integers
// M   : vector of length dim = d + d^2 + d^3 holding the moments of order 1,
//       2 and 3 one after another (orders 2 and 3 flattened by columns)
// pnc : not used (it was only read by the disabled OpenMP code)
// pn  : pointer to n
// pd  : pointer to d
// W   : dim x dim output matrix; fully overwritten, and symmetric on return.
//
// If the temporary buffer of size dim cannot be allocated, the function
// returns with W filled with zeros. If d <= 0, W is left untouched.
void Compute_Omega(double* X, int* Y, double* M, int* pnc, int* pn, int* pd, double* W)
{
	(void)pnc;
	int n=*pn, d=*pd;
	if (d <= 0)
		return; // dim would be <= 0: there is no entry of W to compute

	// All sizes and indices are size_t: dim*dim overflows int for moderate d.
	const size_t dd = (size_t)d;
	const size_t d2 = dd*dd;
	const size_t dim = dd + d2 + d2*dd;
	const size_t nobs = (n > 0 ? (size_t)n : 0);

	// (Re)Initialize W:
	for (size_t j=0; j<dim*dim; j++)
		W[j] = 0.0;

	double* g = malloc(dim * sizeof *g);
	if (g == NULL)
		return; // out of memory: W stays at zero

	// NOTE: an OpenMP "parallel for" on this loop is disabled; the original
	// comment flagged its results as random. The loop shares g and W between
	// iterations, so it must stay sequential as written.
	for (size_t i=0; i<nobs; i++)
	{
		const double y = Y[i];

		// g == gi, order 1 part: entries [0, d)
		for (size_t a=0; a<dd; a++)
			g[a] = y * X[a*nobs + i] - M[a];

		// Order 2 part: entries [d, d+d^2), (row a, column b) by columns
		for (size_t b=0; b<dd; b++)
		{
			const double xb = X[b*nobs + i];
			for (size_t a=0; a<dd; a++)
			{
				const double xa = X[a*nobs + i];
				const size_t j = dd + b*dd + a;
				double gj = 0.0;
				if (a == b)
					gj -= y;
				gj += y * xa*xb - M[j];
				g[j] = gj;
			}
		}

		// Order 3 part: entries [d+d^2, dim), (row a, column b, depth c)
		for (size_t c=0; c<dd; c++)
		{
			const double xc = X[c*nobs + i];
			for (size_t b=0; b<dd; b++)
			{
				const double xb = X[b*nobs + i];
				for (size_t a=0; a<dd; a++)
				{
					const double xa = X[a*nobs + i];
					const size_t j = dd + d2 + c*d2 + b*dd + a;
					double gj = 0.0;
					if (a == b)
						gj -= y * xc;
					if (a == c)
						gj -= y * xb;
					if (b == c)
						gj -= y * xa;
					gj += y * xa*xb*xc - M[j];
					g[j] = gj;
				}
			}
		}

		// Add t(gi) %*% gi to W. This nested loop is the costly part, so
		// only the entries W[j*dim + k] with k <= j are accumulated; the
		// other half is filled by symmetry at the end.
		for (size_t j=0; j<dim; j++)
		{
			const double gj = g[j];
			double* Wj = W + j*dim;
			for (size_t k=0; k<=j; k++)
				Wj[k] += gj * g[k];
		}
	}

	// Normalize the accumulated half (x 1/n) and mirror it: W[k,j] = W[j,k]
	for (size_t j=0; j<dim; j++)
	{
		for (size_t k=0; k<=j; k++)
		{
			W[j*dim + k] /= n;
			W[k*dim + j] = W[j*dim + k];
		}
	}
	free(g);
}