// Source: valse.git, src/sources/EMGrank.c
// (blame view at commit 1d3c1faa, "prepare structure for R package", author BA)
#include "EMGrank.h"
#include <string.h>
#include <gsl/gsl_linalg.h>

4// Compute pseudo-inverse of a square matrix
5static Real* pinv(const Real* matrix, mwSize dim)
6{
7 gsl_matrix* U = gsl_matrix_alloc(dim,dim);
8 gsl_matrix* V = gsl_matrix_alloc(dim,dim);
9 gsl_vector* S = gsl_vector_alloc(dim);
10 gsl_vector* work = gsl_vector_alloc(dim);
11 Real EPS = 1e-10; //threshold for singular value "== 0"
12
13 //copy matrix into U
14 for (mwSize i=0; i<dim*dim; i++)
15 U->data[i] = matrix[i];
16
17 //U,S,V = SVD of matrix
18 gsl_linalg_SV_decomp(U, V, S, work);
19 gsl_vector_free(work);
20
21 // Obtain pseudo-inverse by V*S^{-1}*t(U)
22 Real* inverse = (Real*)malloc(dim*dim*sizeof(Real));
23 for (mwSize i=0; i<dim; i++)
24 {
25 for (mwSize ii=0; ii<dim; ii++)
26 {
27 Real dotProduct = 0.0;
28 for (mwSize j=0; j<dim; j++)
29 dotProduct += V->data[i*dim+j] * (S->data[j] > EPS ? 1.0/S->data[j] : 0.0) * U->data[ii*dim+j];
30 inverse[i*dim+ii] = dotProduct;
31 }
32 }
33
34 gsl_matrix_free(U);
35 gsl_matrix_free(V);
36 gsl_vector_free(S);
37 return inverse;
38}
39
40// TODO: comment EMGrank purpose
41void EMGrank(
42 // IN parameters
43 const Real* Pi, // parametre de proportion
44 const Real* Rho, // parametre initial de variance renormalisé
45 Int mini, // nombre minimal d'itérations dans l'algorithme EM
46 Int maxi, // nombre maximal d'itérations dans l'algorithme EM
47 const Real* X, // régresseurs
48 const Real* Y, // réponse
49 Real tau, // seuil pour accepter la convergence
50 const Int* rank, // vecteur des rangs possibles
51 // OUT parameters
52 Real* phi, // parametre de moyenne renormalisé, calculé par l'EM
53 Real* LLF, // log vraisemblance associé à cet échantillon, pour les valeurs estimées des paramètres
54 // additional size parameters
55 mwSize n, // taille de l'echantillon
56 mwSize p, // nombre de covariables
57 mwSize m, // taille de Y (multivarié)
58 mwSize k) // nombre de composantes
59{
60 // Allocations, initializations
61 Real* Phi = (Real*)calloc(p*m*k,sizeof(Real));
62 Real* hatBetaR = (Real*)malloc(p*m*sizeof(Real));
63 int signum;
64 Real invN = 1.0/n;
65 int deltaPhiBufferSize = 20;
66 Real* deltaPhi = (Real*)malloc(deltaPhiBufferSize*sizeof(Real));
67 mwSize ite = 0;
68 Real sumDeltaPhi = 0.0;
69 Real* YiRhoR = (Real*)malloc(m*sizeof(Real));
70 Real* XiPhiR = (Real*)malloc(m*sizeof(Real));
71 Real* Xr = (Real*)malloc(n*p*sizeof(Real));
72 Real* Yr = (Real*)malloc(n*m*sizeof(Real));
73 Real* tXrXr = (Real*)malloc(p*p*sizeof(Real));
74 Real* tXrYr = (Real*)malloc(p*m*sizeof(Real));
75 gsl_matrix* matrixM = gsl_matrix_alloc(p, m);
76 gsl_matrix* matrixE = gsl_matrix_alloc(m, m);
77 gsl_permutation* permutation = gsl_permutation_alloc(m);
78 gsl_matrix* V = gsl_matrix_alloc(m,m);
79 gsl_vector* S = gsl_vector_alloc(m);
80 gsl_vector* work = gsl_vector_alloc(m);
81
82 //Initialize class memberships (all elements in class 0; TODO: randomize ?)
83 Int* Z = (Int*)calloc(n, sizeof(Int));
84
85 //Initialize phi to zero, because some M loops might exit before phi affectation
86 for (mwSize i=0; i<p*m*k; i++)
87 phi[i] = 0.0;
88
89 while (ite<mini || (ite<maxi && sumDeltaPhi>tau))
90 {
91 /////////////
92 // Etape M //
93 /////////////
94
95 //M step: Mise à jour de Beta (et donc phi)
96 for (mwSize r=0; r<k; r++)
97 {
98 //Compute Xr = X(Z==r,:) and Yr = Y(Z==r,:)
99 mwSize cardClustR=0;
100 for (mwSize i=0; i<n; i++)
101 {
102 if (Z[i] == r)
103 {
104 for (mwSize j=0; j<p; j++)
105 Xr[cardClustR*p+j] = X[i*p+j];
106 for (mwSize j=0; j<m; j++)
107 Yr[cardClustR*m+j] = Y[i*m+j];
108 cardClustR++;
109 }
110 }
111 if (cardClustR == 0)
112 continue;
113
114 //Compute tXrXr = t(Xr) * Xr
115 for (mwSize j=0; j<p; j++)
116 {
117 for (mwSize jj=0; jj<p; jj++)
118 {
119 Real dotProduct = 0.0;
120 for (mwSize u=0; u<cardClustR; u++)
121 dotProduct += Xr[u*p+j] * Xr[u*p+jj];
122 tXrXr[j*p+jj] = dotProduct;
123 }
124 }
125
126 //Get pseudo inverse = (t(Xr)*Xr)^{-1}
127 Real* invTXrXr = pinv(tXrXr, p);
128
129 // Compute tXrYr = t(Xr) * Yr
130 for (mwSize j=0; j<p; j++)
131 {
132 for (mwSize jj=0; jj<m; jj++)
133 {
134 Real dotProduct = 0.0;
135 for (mwSize u=0; u<cardClustR; u++)
136 dotProduct += Xr[u*p+j] * Yr[u*m+jj];
137 tXrYr[j*m+jj] = dotProduct;
138 }
139 }
140
141 //Fill matrixM with inverse * tXrYr = (t(Xr)*Xr)^{-1} * t(Xr) * Yr
142 for (mwSize j=0; j<p; j++)
143 {
144 for (mwSize jj=0; jj<m; jj++)
145 {
146 Real dotProduct = 0.0;
147 for (mwSize u=0; u<p; u++)
148 dotProduct += invTXrXr[j*p+u] * tXrYr[u*m+jj];
149 matrixM->data[j*m+jj] = dotProduct;
150 }
151 }
152 free(invTXrXr);
153
154 //U,S,V = SVD of (t(Xr)Xr)^{-1} * t(Xr) * Yr
155 gsl_linalg_SV_decomp(matrixM, V, S, work);
156
157 //Set m-rank(r) singular values to zero, and recompose
158 //best rank(r) approximation of the initial product
159 for (mwSize j=rank[r]; j<m; j++)
160 S->data[j] = 0.0;
161
162 //[intermediate step] Compute hatBetaR = U * S * t(V)
163 Real* U = matrixM->data;
164 for (mwSize j=0; j<p; j++)
165 {
166 for (mwSize jj=0; jj<m; jj++)
167 {
168 Real dotProduct = 0.0;
169 for (mwSize u=0; u<m; u++)
170 dotProduct += U[j*m+u] * S->data[u] * V->data[jj*m+u];
171 hatBetaR[j*m+jj] = dotProduct;
172 }
173 }
174
175 //Compute phi(:,:,r) = hatBetaR * Rho(:,:,r)
176 for (mwSize j=0; j<p; j++)
177 {
178 for (mwSize jj=0; jj<m; jj++)
179 {
180 Real dotProduct=0.0;
181 for (mwSize u=0; u<m; u++)
182 dotProduct += hatBetaR[j*m+u] * Rho[u*m*k+jj*k+r];
183 phi[j*m*k+jj*k+r] = dotProduct;
184 }
185 }
186 }
187
188 /////////////
189 // Etape E //
190 /////////////
191
192 Real sumLogLLF2 = 0.0;
193 for (mwSize i=0; i<n; i++)
194 {
195 Real sumLLF1 = 0.0;
196 Real maxLogGamIR = -INFINITY;
197 for (mwSize r=0; r<k; r++)
198 {
199 //Compute
200 //Gam(i,r) = Pi(r) * det(Rho(:,:,r)) * exp( -1/2 * (Y(i,:)*Rho(:,:,r) - X(i,:)...
201 // *phi(:,:,r)) * transpose( Y(i,:)*Rho(:,:,r) - X(i,:)*phi(:,:,r) ) );
202 //split in several sub-steps
203
204 //compute det(Rho(:,:,r)) [TODO: avoid re-computations]
205 for (mwSize j=0; j<m; j++)
206 {
207 for (mwSize jj=0; jj<m; jj++)
208 matrixE->data[j*m+jj] = Rho[j*m*k+jj*k+r];
209 }
210 gsl_linalg_LU_decomp(matrixE, permutation, &signum);
211 Real detRhoR = gsl_linalg_LU_det(matrixE, signum);
212
213 //compute Y(i,:)*Rho(:,:,r)
214 for (mwSize j=0; j<m; j++)
215 {
216 YiRhoR[j] = 0.0;
217 for (mwSize u=0; u<m; u++)
218 YiRhoR[j] += Y[i*m+u] * Rho[u*m*k+j*k+r];
219 }
220
221 //compute X(i,:)*phi(:,:,r)
222 for (mwSize j=0; j<m; j++)
223 {
224 XiPhiR[j] = 0.0;
225 for (mwSize u=0; u<p; u++)
226 XiPhiR[j] += X[i*p+u] * phi[u*m*k+j*k+r];
227 }
228
229 //compute dotProduct < Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) . Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) >
230 Real dotProduct = 0.0;
231 for (mwSize u=0; u<m; u++)
232 dotProduct += (YiRhoR[u]-XiPhiR[u]) * (YiRhoR[u]-XiPhiR[u]);
233 Real logGamIR = log(Pi[r]) + log(detRhoR) - 0.5*dotProduct;
234
235 //Z(i) = index of max (gam(i,:))
236 if (logGamIR > maxLogGamIR)
237 {
238 Z[i] = r;
239 maxLogGamIR = logGamIR;
240 }
241 sumLLF1 += exp(logGamIR) / pow(2*M_PI,m/2.0);
242 }
243
244 sumLogLLF2 += log(sumLLF1);
245 }
246
247 // Assign output variable LLF
248 *LLF = -invN * sumLogLLF2;
249
250 //newDeltaPhi = max(max((abs(phi-Phi))./(1+abs(phi))));
251 Real newDeltaPhi = 0.0;
252 for (mwSize j=0; j<p; j++)
253 {
254 for (mwSize jj=0; jj<m; jj++)
255 {
256 for (mwSize r=0; r<k; r++)
257 {
258 Real tmpDist = fabs(phi[j*m*k+jj*k+r]-Phi[j*m*k+jj*k+r])
259 / (1.0+fabs(phi[j*m*k+jj*k+r]));
260 if (tmpDist > newDeltaPhi)
261 newDeltaPhi = tmpDist;
262 }
263 }
264 }
265
266 //update distance parameter to check algorithm convergence (delta(phi, Phi))
267 //TODO: deltaPhi should be a linked list for perf.
268 if (ite < deltaPhiBufferSize)
269 deltaPhi[ite] = newDeltaPhi;
270 else
271 {
272 sumDeltaPhi -= deltaPhi[0];
273 for (int u=0; u<deltaPhiBufferSize-1; u++)
274 deltaPhi[u] = deltaPhi[u+1];
275 deltaPhi[deltaPhiBufferSize-1] = newDeltaPhi;
276 }
277 sumDeltaPhi += newDeltaPhi;
278
279 // update other local variables
280 for (mwSize j=0; j<m; j++)
281 {
282 for (mwSize jj=0; jj<p; jj++)
283 {
284 for (mwSize r=0; r<k; r++)
285 Phi[j*m*k+jj*k+r] = phi[j*m*k+jj*k+r];
286 }
287 }
288 ite++;
289 }
290
291 //free memory
292 free(hatBetaR);
293 free(deltaPhi);
294 free(Phi);
295 gsl_matrix_free(matrixE);
296 gsl_matrix_free(matrixM);
297 gsl_permutation_free(permutation);
298 gsl_vector_free(work);
299 gsl_matrix_free(V);
300 gsl_vector_free(S);
301 free(XiPhiR);
302 free(YiRhoR);
303 free(Xr);
304 free(Yr);
305 free(tXrXr);
306 free(tXrYr);
307 free(Z);
308}