fix memory leaks on EMGLLF, test OK for EMGrank
[valse.git] / src / sources / EMGrank.c
CommitLineData
8e92c49c 1#include <stdlib.h>
1d3c1faa 2#include <gsl/gsl_linalg.h>
8e92c49c 3#include "utils.h"
1d3c1faa
BA
4
5// Compute pseudo-inverse of a square matrix
9ff729fb 6static Real* pinv(const Real* matrix, int dim)
1d3c1faa
BA
7{
8 gsl_matrix* U = gsl_matrix_alloc(dim,dim);
9 gsl_matrix* V = gsl_matrix_alloc(dim,dim);
10 gsl_vector* S = gsl_vector_alloc(dim);
11 gsl_vector* work = gsl_vector_alloc(dim);
9ff729fb 12 Real EPS = 1e-10; //threshold for singular value "== 0"
1d3c1faa
BA
13
14 //copy matrix into U
552b00e2
BA
15 copyArray(matrix, U->data, dim*dim);
16
1d3c1faa
BA
17 //U,S,V = SVD of matrix
18 gsl_linalg_SV_decomp(U, V, S, work);
19 gsl_vector_free(work);
552b00e2 20
1d3c1faa 21 // Obtain pseudo-inverse by V*S^{-1}*t(U)
9ff729fb 22 Real* inverse = (Real*)malloc(dim*dim*sizeof(Real));
552b00e2 23 for (int i=0; i<dim; i++)
1d3c1faa 24 {
552b00e2 25 for (int ii=0; ii<dim; ii++)
1d3c1faa 26 {
9ff729fb 27 Real dotProduct = 0.0;
552b00e2 28 for (int j=0; j<dim; j++)
1d3c1faa
BA
29 dotProduct += V->data[i*dim+j] * (S->data[j] > EPS ? 1.0/S->data[j] : 0.0) * U->data[ii*dim+j];
30 inverse[i*dim+ii] = dotProduct;
31 }
32 }
33
34 gsl_matrix_free(U);
35 gsl_matrix_free(V);
36 gsl_vector_free(S);
37 return inverse;
38}
39
40// TODO: comment EMGrank purpose
09ab3c16 41void EMGrank_core(
1d3c1faa 42 // IN parameters
9ff729fb
BA
43 const Real* Pi, // parametre de proportion
44 const Real* Rho, // parametre initial de variance renormalisé
552b00e2
BA
45 int mini, // nombre minimal d'itérations dans l'algorithme EM
46 int maxi, // nombre maximal d'itérations dans l'algorithme EM
9ff729fb
BA
47 const Real* X, // régresseurs
48 const Real* Y, // réponse
49 Real tau, // seuil pour accepter la convergence
552b00e2 50 const int* rank, // vecteur des rangs possibles
1d3c1faa 51 // OUT parameters
9ff729fb
BA
52 Real* phi, // parametre de moyenne renormalisé, calculé par l'EM
53 Real* LLF, // log vraisemblance associé à cet échantillon, pour les valeurs estimées des paramètres
1d3c1faa 54 // additional size parameters
552b00e2
BA
55 int n, // taille de l'echantillon
56 int p, // nombre de covariables
57 int m, // taille de Y (multivarié)
58 int k) // nombre de composantes
1d3c1faa
BA
59{
60 // Allocations, initializations
9ff729fb
BA
61 Real* Phi = (Real*)calloc(p*m*k,sizeof(Real));
62 Real* hatBetaR = (Real*)malloc(p*m*sizeof(Real));
1d3c1faa 63 int signum;
9ff729fb 64 Real invN = 1.0/n;
1d3c1faa 65 int deltaPhiBufferSize = 20;
9ff729fb 66 Real* deltaPhi = (Real*)malloc(deltaPhiBufferSize*sizeof(Real));
552b00e2 67 int ite = 0;
9ff729fb
BA
68 Real sumDeltaPhi = 0.0;
69 Real* YiRhoR = (Real*)malloc(m*sizeof(Real));
70 Real* XiPhiR = (Real*)malloc(m*sizeof(Real));
71 Real* Xr = (Real*)malloc(n*p*sizeof(Real));
72 Real* Yr = (Real*)malloc(n*m*sizeof(Real));
73 Real* tXrXr = (Real*)malloc(p*p*sizeof(Real));
74 Real* tXrYr = (Real*)malloc(p*m*sizeof(Real));
552b00e2 75 gsl_matrix* matrixM = gsl_matrix_alloc(p, m);
1d3c1faa
BA
76 gsl_matrix* matrixE = gsl_matrix_alloc(m, m);
77 gsl_permutation* permutation = gsl_permutation_alloc(m);
78 gsl_matrix* V = gsl_matrix_alloc(m,m);
79 gsl_vector* S = gsl_vector_alloc(m);
80 gsl_vector* work = gsl_vector_alloc(m);
81
82 //Initialize class memberships (all elements in class 0; TODO: randomize ?)
552b00e2 83 int* Z = (int*)calloc(n, sizeof(int));
1d3c1faa
BA
84
85 //Initialize phi to zero, because some M loops might exit before phi affectation
8e92c49c 86 zeroArray(phi, p*m*k);
1d3c1faa
BA
87
88 while (ite<mini || (ite<maxi && sumDeltaPhi>tau))
552b00e2 89 {
1d3c1faa
BA
90 /////////////
91 // Etape M //
92 /////////////
93
94 //M step: Mise à jour de Beta (et donc phi)
552b00e2 95 for (int r=0; r<k; r++)
1d3c1faa
BA
96 {
97 //Compute Xr = X(Z==r,:) and Yr = Y(Z==r,:)
552b00e2
BA
98 int cardClustR=0;
99 for (int i=0; i<n; i++)
1d3c1faa
BA
100 {
101 if (Z[i] == r)
102 {
552b00e2
BA
103 for (int j=0; j<p; j++)
104 Xr[mi(cardClustR,j,n,p)] = X[mi(i,j,n,p)];
105 for (int j=0; j<m; j++)
106 Yr[mi(cardClustR,j,n,m)] = Y[mi(i,j,n,m)];
1d3c1faa
BA
107 cardClustR++;
108 }
109 }
552b00e2 110 if (cardClustR == 0)
1d3c1faa
BA
111 continue;
112
113 //Compute tXrXr = t(Xr) * Xr
552b00e2 114 for (int j=0; j<p; j++)
1d3c1faa 115 {
552b00e2 116 for (int jj=0; jj<p; jj++)
1d3c1faa 117 {
9ff729fb 118 Real dotProduct = 0.0;
552b00e2
BA
119 for (int u=0; u<cardClustR; u++)
120 dotProduct += Xr[mi(u,j,n,p)] * Xr[mi(u,jj,n,p)];
121 tXrXr[mi(j,jj,p,p)] = dotProduct;
1d3c1faa
BA
122 }
123 }
124
125 //Get pseudo inverse = (t(Xr)*Xr)^{-1}
9ff729fb 126 Real* invTXrXr = pinv(tXrXr, p);
1d3c1faa
BA
127
128 // Compute tXrYr = t(Xr) * Yr
552b00e2 129 for (int j=0; j<p; j++)
1d3c1faa 130 {
552b00e2 131 for (int jj=0; jj<m; jj++)
1d3c1faa 132 {
9ff729fb 133 Real dotProduct = 0.0;
552b00e2 134 for (int u=0; u<cardClustR; u++)
c3bc4705 135 dotProduct += Xr[mi(u,j,n,p)] * Yr[mi(u,jj,n,m)];
552b00e2 136 tXrYr[mi(j,jj,p,m)] = dotProduct;
1d3c1faa
BA
137 }
138 }
139
140 //Fill matrixM with inverse * tXrYr = (t(Xr)*Xr)^{-1} * t(Xr) * Yr
552b00e2 141 for (int j=0; j<p; j++)
1d3c1faa 142 {
552b00e2 143 for (int jj=0; jj<m; jj++)
1d3c1faa 144 {
9ff729fb 145 Real dotProduct = 0.0;
552b00e2
BA
146 for (int u=0; u<p; u++)
147 dotProduct += invTXrXr[mi(j,u,p,p)] * tXrYr[mi(u,jj,p,m)];
148 matrixM->data[j*m+jj] = dotProduct;
1d3c1faa
BA
149 }
150 }
151 free(invTXrXr);
552b00e2 152
1d3c1faa
BA
153 //U,S,V = SVD of (t(Xr)Xr)^{-1} * t(Xr) * Yr
154 gsl_linalg_SV_decomp(matrixM, V, S, work);
552b00e2
BA
155
156 //Set m-rank(r) singular values to zero, and recompose
1d3c1faa 157 //best rank(r) approximation of the initial product
552b00e2 158 for (int j=rank[r]; j<m; j++)
1d3c1faa
BA
159 S->data[j] = 0.0;
160
161 //[intermediate step] Compute hatBetaR = U * S * t(V)
9ff729fb 162 double* U = matrixM->data; //GSL require double precision
552b00e2 163 for (int j=0; j<p; j++)
1d3c1faa 164 {
552b00e2 165 for (int jj=0; jj<m; jj++)
1d3c1faa 166 {
9ff729fb 167 Real dotProduct = 0.0;
552b00e2 168 for (int u=0; u<m; u++)
1d3c1faa 169 dotProduct += U[j*m+u] * S->data[u] * V->data[jj*m+u];
552b00e2 170 hatBetaR[mi(j,jj,p,m)] = dotProduct;
1d3c1faa
BA
171 }
172 }
552b00e2 173
1d3c1faa 174 //Compute phi(:,:,r) = hatBetaR * Rho(:,:,r)
552b00e2 175 for (int j=0; j<p; j++)
1d3c1faa 176 {
552b00e2 177 for (int jj=0; jj<m; jj++)
1d3c1faa 178 {
9ff729fb 179 Real dotProduct=0.0;
552b00e2
BA
180 for (int u=0; u<m; u++)
181 dotProduct += hatBetaR[mi(j,u,p,m)] * Rho[ai(u,jj,r,m,m,k)];
182 phi[ai(j,jj,r,p,m,k)] = dotProduct;
1d3c1faa 183 }
552b00e2 184 }
1d3c1faa
BA
185 }
186
187 /////////////
188 // Etape E //
189 /////////////
190
9ff729fb 191 Real sumLogLLF2 = 0.0;
552b00e2 192 for (int i=0; i<n; i++)
1d3c1faa 193 {
9ff729fb
BA
194 Real sumLLF1 = 0.0;
195 Real maxLogGamIR = -INFINITY;
552b00e2 196 for (int r=0; r<k; r++)
1d3c1faa
BA
197 {
198 //Compute
199 //Gam(i,r) = Pi(r) * det(Rho(:,:,r)) * exp( -1/2 * (Y(i,:)*Rho(:,:,r) - X(i,:)...
552b00e2 200 //*phi(:,:,r)) * transpose( Y(i,:)*Rho(:,:,r) - X(i,:)*phi(:,:,r) ) );
1d3c1faa
BA
201 //split in several sub-steps
202
203 //compute det(Rho(:,:,r)) [TODO: avoid re-computations]
552b00e2 204 for (int j=0; j<m; j++)
1d3c1faa 205 {
552b00e2
BA
206 for (int jj=0; jj<m; jj++)
207 matrixE->data[j*m+jj] = Rho[ai(j,jj,r,m,m,k)];
1d3c1faa
BA
208 }
209 gsl_linalg_LU_decomp(matrixE, permutation, &signum);
9ff729fb 210 Real detRhoR = gsl_linalg_LU_det(matrixE, signum);
552b00e2 211
1d3c1faa 212 //compute Y(i,:)*Rho(:,:,r)
552b00e2 213 for (int j=0; j<m; j++)
1d3c1faa
BA
214 {
215 YiRhoR[j] = 0.0;
552b00e2
BA
216 for (int u=0; u<m; u++)
217 YiRhoR[j] += Y[mi(i,u,n,m)] * Rho[ai(u,j,r,m,m,k)];
1d3c1faa 218 }
552b00e2 219
1d3c1faa 220 //compute X(i,:)*phi(:,:,r)
552b00e2 221 for (int j=0; j<m; j++)
1d3c1faa
BA
222 {
223 XiPhiR[j] = 0.0;
552b00e2
BA
224 for (int u=0; u<p; u++)
225 XiPhiR[j] += X[mi(i,u,n,p)] * phi[ai(u,j,r,p,m,k)];
1d3c1faa 226 }
552b00e2 227
1d3c1faa 228 //compute dotProduct < Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) . Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) >
9ff729fb 229 Real dotProduct = 0.0;
552b00e2 230 for (int u=0; u<m; u++)
1d3c1faa 231 dotProduct += (YiRhoR[u]-XiPhiR[u]) * (YiRhoR[u]-XiPhiR[u]);
9ff729fb 232 Real logGamIR = log(Pi[r]) + log(detRhoR) - 0.5*dotProduct;
552b00e2 233
1d3c1faa
BA
234 //Z(i) = index of max (gam(i,:))
235 if (logGamIR > maxLogGamIR)
236 {
237 Z[i] = r;
238 maxLogGamIR = logGamIR;
239 }
240 sumLLF1 += exp(logGamIR) / pow(2*M_PI,m/2.0);
241 }
552b00e2 242
1d3c1faa
BA
243 sumLogLLF2 += log(sumLLF1);
244 }
552b00e2 245
1d3c1faa
BA
246 // Assign output variable LLF
247 *LLF = -invN * sumLogLLF2;
248
249 //newDeltaPhi = max(max((abs(phi-Phi))./(1+abs(phi))));
9ff729fb 250 Real newDeltaPhi = 0.0;
552b00e2 251 for (int j=0; j<p; j++)
1d3c1faa 252 {
552b00e2 253 for (int jj=0; jj<m; jj++)
1d3c1faa 254 {
552b00e2 255 for (int r=0; r<k; r++)
1d3c1faa 256 {
9ff729fb 257 Real tmpDist = fabs(phi[ai(j,jj,r,p,m,k)]-Phi[ai(j,jj,r,p,m,k)])
552b00e2 258 / (1.0+fabs(phi[ai(j,jj,r,p,m,k)]));
1d3c1faa
BA
259 if (tmpDist > newDeltaPhi)
260 newDeltaPhi = tmpDist;
261 }
262 }
263 }
264
265 //update distance parameter to check algorithm convergence (delta(phi, Phi))
266 //TODO: deltaPhi should be a linked list for perf.
267 if (ite < deltaPhiBufferSize)
268 deltaPhi[ite] = newDeltaPhi;
269 else
270 {
271 sumDeltaPhi -= deltaPhi[0];
272 for (int u=0; u<deltaPhiBufferSize-1; u++)
273 deltaPhi[u] = deltaPhi[u+1];
274 deltaPhi[deltaPhiBufferSize-1] = newDeltaPhi;
275 }
276 sumDeltaPhi += newDeltaPhi;
277
278 // update other local variables
552b00e2 279 for (int j=0; j<m; j++)
1d3c1faa 280 {
552b00e2 281 for (int jj=0; jj<p; jj++)
1d3c1faa 282 {
552b00e2 283 for (int r=0; r<k; r++)
c3bc4705 284 Phi[ai(jj,j,r,p,m,k)] = phi[ai(jj,j,r,p,m,k)];
1d3c1faa
BA
285 }
286 }
552b00e2 287 ite++;
1d3c1faa 288 }
552b00e2 289
1d3c1faa
BA
290 //free memory
291 free(hatBetaR);
292 free(deltaPhi);
293 free(Phi);
294 gsl_matrix_free(matrixE);
295 gsl_matrix_free(matrixM);
296 gsl_permutation_free(permutation);
297 gsl_vector_free(work);
298 gsl_matrix_free(V);
299 gsl_vector_free(S);
300 free(XiPhiR);
301 free(YiRhoR);
302 free(Xr);
303 free(Yr);
304 free(tXrXr);
305 free(tXrYr);
552b00e2 306 free(Z);
1d3c1faa 307}