Commit | Line | Data |
---|---|---|
1d3c1faa BA |
1 | #include "EMGrank.h" |
2 | #include <gsl/gsl_linalg.h> | |
3 | ||
4 | // Compute pseudo-inverse of a square matrix | |
5 | static Real* pinv(const Real* matrix, mwSize dim) | |
6 | { | |
7 | gsl_matrix* U = gsl_matrix_alloc(dim,dim); | |
8 | gsl_matrix* V = gsl_matrix_alloc(dim,dim); | |
9 | gsl_vector* S = gsl_vector_alloc(dim); | |
10 | gsl_vector* work = gsl_vector_alloc(dim); | |
11 | Real EPS = 1e-10; //threshold for singular value "== 0" | |
12 | ||
13 | //copy matrix into U | |
14 | for (mwSize i=0; i<dim*dim; i++) | |
15 | U->data[i] = matrix[i]; | |
16 | ||
17 | //U,S,V = SVD of matrix | |
18 | gsl_linalg_SV_decomp(U, V, S, work); | |
19 | gsl_vector_free(work); | |
20 | ||
21 | // Obtain pseudo-inverse by V*S^{-1}*t(U) | |
22 | Real* inverse = (Real*)malloc(dim*dim*sizeof(Real)); | |
23 | for (mwSize i=0; i<dim; i++) | |
24 | { | |
25 | for (mwSize ii=0; ii<dim; ii++) | |
26 | { | |
27 | Real dotProduct = 0.0; | |
28 | for (mwSize j=0; j<dim; j++) | |
29 | dotProduct += V->data[i*dim+j] * (S->data[j] > EPS ? 1.0/S->data[j] : 0.0) * U->data[ii*dim+j]; | |
30 | inverse[i*dim+ii] = dotProduct; | |
31 | } | |
32 | } | |
33 | ||
34 | gsl_matrix_free(U); | |
35 | gsl_matrix_free(V); | |
36 | gsl_vector_free(S); | |
37 | return inverse; | |
38 | } | |
39 | ||
40 | // TODO: comment EMGrank purpose | |
41 | void EMGrank( | |
42 | // IN parameters | |
43 | const Real* Pi, // parametre de proportion | |
44 | const Real* Rho, // parametre initial de variance renormalisé | |
45 | Int mini, // nombre minimal d'itérations dans l'algorithme EM | |
46 | Int maxi, // nombre maximal d'itérations dans l'algorithme EM | |
47 | const Real* X, // régresseurs | |
48 | const Real* Y, // réponse | |
49 | Real tau, // seuil pour accepter la convergence | |
50 | const Int* rank, // vecteur des rangs possibles | |
51 | // OUT parameters | |
52 | Real* phi, // parametre de moyenne renormalisé, calculé par l'EM | |
53 | Real* LLF, // log vraisemblance associé à cet échantillon, pour les valeurs estimées des paramètres | |
54 | // additional size parameters | |
55 | mwSize n, // taille de l'echantillon | |
56 | mwSize p, // nombre de covariables | |
57 | mwSize m, // taille de Y (multivarié) | |
58 | mwSize k) // nombre de composantes | |
59 | { | |
60 | // Allocations, initializations | |
61 | Real* Phi = (Real*)calloc(p*m*k,sizeof(Real)); | |
62 | Real* hatBetaR = (Real*)malloc(p*m*sizeof(Real)); | |
63 | int signum; | |
64 | Real invN = 1.0/n; | |
65 | int deltaPhiBufferSize = 20; | |
66 | Real* deltaPhi = (Real*)malloc(deltaPhiBufferSize*sizeof(Real)); | |
67 | mwSize ite = 0; | |
68 | Real sumDeltaPhi = 0.0; | |
69 | Real* YiRhoR = (Real*)malloc(m*sizeof(Real)); | |
70 | Real* XiPhiR = (Real*)malloc(m*sizeof(Real)); | |
71 | Real* Xr = (Real*)malloc(n*p*sizeof(Real)); | |
72 | Real* Yr = (Real*)malloc(n*m*sizeof(Real)); | |
73 | Real* tXrXr = (Real*)malloc(p*p*sizeof(Real)); | |
74 | Real* tXrYr = (Real*)malloc(p*m*sizeof(Real)); | |
75 | gsl_matrix* matrixM = gsl_matrix_alloc(p, m); | |
76 | gsl_matrix* matrixE = gsl_matrix_alloc(m, m); | |
77 | gsl_permutation* permutation = gsl_permutation_alloc(m); | |
78 | gsl_matrix* V = gsl_matrix_alloc(m,m); | |
79 | gsl_vector* S = gsl_vector_alloc(m); | |
80 | gsl_vector* work = gsl_vector_alloc(m); | |
81 | ||
82 | //Initialize class memberships (all elements in class 0; TODO: randomize ?) | |
83 | Int* Z = (Int*)calloc(n, sizeof(Int)); | |
84 | ||
85 | //Initialize phi to zero, because some M loops might exit before phi affectation | |
86 | for (mwSize i=0; i<p*m*k; i++) | |
87 | phi[i] = 0.0; | |
88 | ||
89 | while (ite<mini || (ite<maxi && sumDeltaPhi>tau)) | |
90 | { | |
91 | ///////////// | |
92 | // Etape M // | |
93 | ///////////// | |
94 | ||
95 | //M step: Mise à jour de Beta (et donc phi) | |
96 | for (mwSize r=0; r<k; r++) | |
97 | { | |
98 | //Compute Xr = X(Z==r,:) and Yr = Y(Z==r,:) | |
99 | mwSize cardClustR=0; | |
100 | for (mwSize i=0; i<n; i++) | |
101 | { | |
102 | if (Z[i] == r) | |
103 | { | |
104 | for (mwSize j=0; j<p; j++) | |
105 | Xr[cardClustR*p+j] = X[i*p+j]; | |
106 | for (mwSize j=0; j<m; j++) | |
107 | Yr[cardClustR*m+j] = Y[i*m+j]; | |
108 | cardClustR++; | |
109 | } | |
110 | } | |
111 | if (cardClustR == 0) | |
112 | continue; | |
113 | ||
114 | //Compute tXrXr = t(Xr) * Xr | |
115 | for (mwSize j=0; j<p; j++) | |
116 | { | |
117 | for (mwSize jj=0; jj<p; jj++) | |
118 | { | |
119 | Real dotProduct = 0.0; | |
120 | for (mwSize u=0; u<cardClustR; u++) | |
121 | dotProduct += Xr[u*p+j] * Xr[u*p+jj]; | |
122 | tXrXr[j*p+jj] = dotProduct; | |
123 | } | |
124 | } | |
125 | ||
126 | //Get pseudo inverse = (t(Xr)*Xr)^{-1} | |
127 | Real* invTXrXr = pinv(tXrXr, p); | |
128 | ||
129 | // Compute tXrYr = t(Xr) * Yr | |
130 | for (mwSize j=0; j<p; j++) | |
131 | { | |
132 | for (mwSize jj=0; jj<m; jj++) | |
133 | { | |
134 | Real dotProduct = 0.0; | |
135 | for (mwSize u=0; u<cardClustR; u++) | |
136 | dotProduct += Xr[u*p+j] * Yr[u*m+jj]; | |
137 | tXrYr[j*m+jj] = dotProduct; | |
138 | } | |
139 | } | |
140 | ||
141 | //Fill matrixM with inverse * tXrYr = (t(Xr)*Xr)^{-1} * t(Xr) * Yr | |
142 | for (mwSize j=0; j<p; j++) | |
143 | { | |
144 | for (mwSize jj=0; jj<m; jj++) | |
145 | { | |
146 | Real dotProduct = 0.0; | |
147 | for (mwSize u=0; u<p; u++) | |
148 | dotProduct += invTXrXr[j*p+u] * tXrYr[u*m+jj]; | |
149 | matrixM->data[j*m+jj] = dotProduct; | |
150 | } | |
151 | } | |
152 | free(invTXrXr); | |
153 | ||
154 | //U,S,V = SVD of (t(Xr)Xr)^{-1} * t(Xr) * Yr | |
155 | gsl_linalg_SV_decomp(matrixM, V, S, work); | |
156 | ||
157 | //Set m-rank(r) singular values to zero, and recompose | |
158 | //best rank(r) approximation of the initial product | |
159 | for (mwSize j=rank[r]; j<m; j++) | |
160 | S->data[j] = 0.0; | |
161 | ||
162 | //[intermediate step] Compute hatBetaR = U * S * t(V) | |
163 | Real* U = matrixM->data; | |
164 | for (mwSize j=0; j<p; j++) | |
165 | { | |
166 | for (mwSize jj=0; jj<m; jj++) | |
167 | { | |
168 | Real dotProduct = 0.0; | |
169 | for (mwSize u=0; u<m; u++) | |
170 | dotProduct += U[j*m+u] * S->data[u] * V->data[jj*m+u]; | |
171 | hatBetaR[j*m+jj] = dotProduct; | |
172 | } | |
173 | } | |
174 | ||
175 | //Compute phi(:,:,r) = hatBetaR * Rho(:,:,r) | |
176 | for (mwSize j=0; j<p; j++) | |
177 | { | |
178 | for (mwSize jj=0; jj<m; jj++) | |
179 | { | |
180 | Real dotProduct=0.0; | |
181 | for (mwSize u=0; u<m; u++) | |
182 | dotProduct += hatBetaR[j*m+u] * Rho[u*m*k+jj*k+r]; | |
183 | phi[j*m*k+jj*k+r] = dotProduct; | |
184 | } | |
185 | } | |
186 | } | |
187 | ||
188 | ///////////// | |
189 | // Etape E // | |
190 | ///////////// | |
191 | ||
192 | Real sumLogLLF2 = 0.0; | |
193 | for (mwSize i=0; i<n; i++) | |
194 | { | |
195 | Real sumLLF1 = 0.0; | |
196 | Real maxLogGamIR = -INFINITY; | |
197 | for (mwSize r=0; r<k; r++) | |
198 | { | |
199 | //Compute | |
200 | //Gam(i,r) = Pi(r) * det(Rho(:,:,r)) * exp( -1/2 * (Y(i,:)*Rho(:,:,r) - X(i,:)... | |
201 | // *phi(:,:,r)) * transpose( Y(i,:)*Rho(:,:,r) - X(i,:)*phi(:,:,r) ) ); | |
202 | //split in several sub-steps | |
203 | ||
204 | //compute det(Rho(:,:,r)) [TODO: avoid re-computations] | |
205 | for (mwSize j=0; j<m; j++) | |
206 | { | |
207 | for (mwSize jj=0; jj<m; jj++) | |
208 | matrixE->data[j*m+jj] = Rho[j*m*k+jj*k+r]; | |
209 | } | |
210 | gsl_linalg_LU_decomp(matrixE, permutation, &signum); | |
211 | Real detRhoR = gsl_linalg_LU_det(matrixE, signum); | |
212 | ||
213 | //compute Y(i,:)*Rho(:,:,r) | |
214 | for (mwSize j=0; j<m; j++) | |
215 | { | |
216 | YiRhoR[j] = 0.0; | |
217 | for (mwSize u=0; u<m; u++) | |
218 | YiRhoR[j] += Y[i*m+u] * Rho[u*m*k+j*k+r]; | |
219 | } | |
220 | ||
221 | //compute X(i,:)*phi(:,:,r) | |
222 | for (mwSize j=0; j<m; j++) | |
223 | { | |
224 | XiPhiR[j] = 0.0; | |
225 | for (mwSize u=0; u<p; u++) | |
226 | XiPhiR[j] += X[i*p+u] * phi[u*m*k+j*k+r]; | |
227 | } | |
228 | ||
229 | //compute dotProduct < Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) . Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) > | |
230 | Real dotProduct = 0.0; | |
231 | for (mwSize u=0; u<m; u++) | |
232 | dotProduct += (YiRhoR[u]-XiPhiR[u]) * (YiRhoR[u]-XiPhiR[u]); | |
233 | Real logGamIR = log(Pi[r]) + log(detRhoR) - 0.5*dotProduct; | |
234 | ||
235 | //Z(i) = index of max (gam(i,:)) | |
236 | if (logGamIR > maxLogGamIR) | |
237 | { | |
238 | Z[i] = r; | |
239 | maxLogGamIR = logGamIR; | |
240 | } | |
241 | sumLLF1 += exp(logGamIR) / pow(2*M_PI,m/2.0); | |
242 | } | |
243 | ||
244 | sumLogLLF2 += log(sumLLF1); | |
245 | } | |
246 | ||
247 | // Assign output variable LLF | |
248 | *LLF = -invN * sumLogLLF2; | |
249 | ||
250 | //newDeltaPhi = max(max((abs(phi-Phi))./(1+abs(phi)))); | |
251 | Real newDeltaPhi = 0.0; | |
252 | for (mwSize j=0; j<p; j++) | |
253 | { | |
254 | for (mwSize jj=0; jj<m; jj++) | |
255 | { | |
256 | for (mwSize r=0; r<k; r++) | |
257 | { | |
258 | Real tmpDist = fabs(phi[j*m*k+jj*k+r]-Phi[j*m*k+jj*k+r]) | |
259 | / (1.0+fabs(phi[j*m*k+jj*k+r])); | |
260 | if (tmpDist > newDeltaPhi) | |
261 | newDeltaPhi = tmpDist; | |
262 | } | |
263 | } | |
264 | } | |
265 | ||
266 | //update distance parameter to check algorithm convergence (delta(phi, Phi)) | |
267 | //TODO: deltaPhi should be a linked list for perf. | |
268 | if (ite < deltaPhiBufferSize) | |
269 | deltaPhi[ite] = newDeltaPhi; | |
270 | else | |
271 | { | |
272 | sumDeltaPhi -= deltaPhi[0]; | |
273 | for (int u=0; u<deltaPhiBufferSize-1; u++) | |
274 | deltaPhi[u] = deltaPhi[u+1]; | |
275 | deltaPhi[deltaPhiBufferSize-1] = newDeltaPhi; | |
276 | } | |
277 | sumDeltaPhi += newDeltaPhi; | |
278 | ||
279 | // update other local variables | |
280 | for (mwSize j=0; j<m; j++) | |
281 | { | |
282 | for (mwSize jj=0; jj<p; jj++) | |
283 | { | |
284 | for (mwSize r=0; r<k; r++) | |
285 | Phi[j*m*k+jj*k+r] = phi[j*m*k+jj*k+r]; | |
286 | } | |
287 | } | |
288 | ite++; | |
289 | } | |
290 | ||
291 | //free memory | |
292 | free(hatBetaR); | |
293 | free(deltaPhi); | |
294 | free(Phi); | |
295 | gsl_matrix_free(matrixE); | |
296 | gsl_matrix_free(matrixM); | |
297 | gsl_permutation_free(permutation); | |
298 | gsl_vector_free(work); | |
299 | gsl_matrix_free(V); | |
300 | gsl_vector_free(S); | |
301 | free(XiPhiR); | |
302 | free(YiRhoR); | |
303 | free(Xr); | |
304 | free(Yr); | |
305 | free(tXrXr); | |
306 | free(tXrYr); | |
307 | free(Z); | |
308 | } |