3 #include <gsl/gsl_linalg.h>
5 // TODO: don't recompute indexes every time.
8 const Real
* phiInit
, // parametre initial de moyenne renormalisé
9 const Real
* rhoInit
, // parametre initial de variance renormalisé
10 const Real
* piInit
, // parametre initial des proportions
11 const Real
* gamInit
, // paramètre initial des probabilités a posteriori de chaque échantillon
12 int mini
, // nombre minimal d'itérations dans l'algorithme EM
13 int maxi
, // nombre maximal d'itérations dans l'algorithme EM
14 Real gamma
, // puissance des proportions dans la pénalisation pour un Lasso adaptatif
15 Real lambda
, // valeur du paramètre de régularisation du Lasso
16 const Real
* X
, // régresseurs
17 const Real
* Y
, // réponse
18 Real tau
, // seuil pour accepter la convergence
19 // OUT parameters (all pointers, to be modified)
20 Real
* phi
, // parametre de moyenne renormalisé, calculé par l'EM
21 Real
* rho
, // parametre de variance renormalisé, calculé par l'EM
22 Real
* pi
, // parametre des proportions renormalisé, calculé par l'EM
23 Real
* LLF
, // log vraisemblance associée à cet échantillon, pour les valeurs estimées des paramètres
25 // additional size parameters
26 int n
, // nombre d'echantillons
27 int p
, // nombre de covariables
28 int m
, // taille de Y (multivarié)
29 int k
) // nombre de composantes dans le mélange
32 copyArray(phiInit
, phi
, p
*m
*k
);
33 copyArray(rhoInit
, rho
, m
*m
*k
);
34 copyArray(piInit
, pi
, k
);
36 //S is already allocated, and doesn't need to be 'zeroed'
38 //Other local variables
39 Real
* gam
= (Real
*)malloc(n
*k
*sizeof(Real
));
40 copyArray(gamInit
, gam
, n
*k
);
41 Real
* b
= (Real
*)malloc(k
*sizeof(Real
));
42 Real
* Phi
= (Real
*)malloc(p
*m
*k
*sizeof(Real
));
43 Real
* Rho
= (Real
*)malloc(m
*m
*k
*sizeof(Real
));
44 Real
* Pi
= (Real
*)malloc(k
*sizeof(Real
));
45 Real
* gam2
= (Real
*)malloc(k
*sizeof(Real
));
46 Real
* pi2
= (Real
*)malloc(k
*sizeof(Real
));
47 Real
* Gram2
= (Real
*)malloc(p
*p
*k
*sizeof(Real
));
48 Real
* ps
= (Real
*)malloc(m
*k
*sizeof(Real
));
49 Real
* nY2
= (Real
*)malloc(m
*k
*sizeof(Real
));
50 Real
* ps1
= (Real
*)malloc(n
*m
*k
*sizeof(Real
));
51 Real
* ps2
= (Real
*)malloc(p
*m
*k
*sizeof(Real
));
52 Real
* nY21
= (Real
*)malloc(n
*m
*k
*sizeof(Real
));
53 Real
* Gam
= (Real
*)malloc(n
*k
*sizeof(Real
));
54 Real
* X2
= (Real
*)malloc(n
*p
*k
*sizeof(Real
));
55 Real
* Y2
= (Real
*)malloc(n
*m
*k
*sizeof(Real
));
56 Real
* sqNorm2
= (Real
*)malloc(k
*sizeof(Real
));
57 gsl_matrix
* matrix
= gsl_matrix_alloc(m
, m
);
58 gsl_permutation
* permutation
= gsl_permutation_alloc(m
);
59 Real
* YiRhoR
= (Real
*)malloc(m
*sizeof(Real
));
60 Real
* XiPhiR
= (Real
*)malloc(m
*sizeof(Real
));
64 const Real EPS
= 1e-15;
65 const Real gaussConstM
= pow(2.*M_PI
,m
/2.);
67 while (ite
< mini
|| (ite
< maxi
&& (dist
>= tau
|| dist2
>= sqrt(tau
))))
69 copyArray(phi
, Phi
, p
*m
*k
);
70 copyArray(rho
, Rho
, m
*m
*k
);
73 // Calculs associés a Y et X
74 for (int r
=0; r
<k
; r
++)
76 for (int mm
=0; mm
<m
; mm
++)
78 //Y2[,mm,r] = sqrt(gam[,r]) * Y[,mm]
79 for (int u
=0; u
<n
; u
++)
80 Y2
[ai(u
,mm
,r
,n
,m
,k
)] = sqrt(gam
[mi(u
,r
,n
,k
)]) * Y
[mi(u
,mm
,m
,n
)];
82 for (int i
=0; i
<n
; i
++)
84 //X2[i,,r] = sqrt(gam[i,r]) * X[i,]
85 for (int u
=0; u
<p
; u
++)
86 X2
[ai(i
,u
,r
,n
,p
,k
)] = sqrt(gam
[mi(i
,r
,n
,k
)]) * X
[mi(i
,u
,n
,p
)];
88 for (int mm
=0; mm
<m
; mm
++)
90 //ps2[,mm,r] = crossprod(X2[,,r],Y2[,mm,r])
91 for (int u
=0; u
<p
; u
++)
94 for (int v
=0; v
<n
; v
++)
95 dotProduct
+= X2
[ai(v
,u
,r
,n
,p
,k
)] * Y2
[ai(v
,mm
,r
,n
,m
,k
)];
96 ps2
[ai(u
,mm
,r
,p
,m
,k
)] = dotProduct
;
99 for (int j
=0; j
<p
; j
++)
101 for (int s
=0; s
<p
; s
++)
103 //Gram2[j,s,r] = crossprod(X2[,j,r], X2[,s,r])
104 Real dotProduct
= 0.;
105 for (int u
=0; u
<n
; u
++)
106 dotProduct
+= X2
[ai(u
,j
,r
,n
,p
,k
)] * X2
[ai(u
,s
,r
,n
,p
,k
)];
107 Gram2
[ai(j
,s
,r
,p
,p
,k
)] = dotProduct
;
117 for (int r
=0; r
<k
; r
++)
119 //b[r] = sum(abs(phi[,,r]))
121 for (int u
=0; u
<p
; u
++)
122 for (int v
=0; v
<m
; v
++)
123 sumAbsPhi
+= fabs(phi
[ai(u
,v
,r
,p
,m
,k
)]);
126 //gam2 = colSums(gam)
127 for (int u
=0; u
<k
; u
++)
129 Real sumOnColumn
= 0.;
130 for (int v
=0; v
<n
; v
++)
131 sumOnColumn
+= gam
[mi(v
,u
,n
,k
)];
132 gam2
[u
] = sumOnColumn
;
134 //a = sum(gam %*% log(pi))
136 for (int u
=0; u
<n
; u
++)
138 Real dotProduct
= 0.;
139 for (int v
=0; v
<k
; v
++)
140 dotProduct
+= gam
[mi(u
,v
,n
,k
)] * log(pi
[v
]);
144 //tant que les proportions sont negatives
146 int pi2AllPositive
= 0;
148 while (!pi2AllPositive
)
150 //pi2 = pi + 0.1^kk * ((1/n)*gam2 - pi)
151 Real pow_01_kk
= pow(0.1,kk
);
152 for (int r
=0; r
<k
; r
++)
153 pi2
[r
] = pi
[r
] + pow_01_kk
* (invN
*gam2
[r
] - pi
[r
]);
154 //pi2AllPositive = all(pi2 >= 0)
156 for (int r
=0; r
<k
; r
++)
168 Real piPowGammaDotB
= 0.;
169 for (int v
=0; v
<k
; v
++)
170 piPowGammaDotB
+= pow(pi
[v
],gamma
) * b
[v
];
172 Real pi2PowGammaDotB
= 0.;
173 for (int v
=0; v
<k
; v
++)
174 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
175 //transpose(gam2)*log(pi2)
176 Real prodGam2logPi2
= 0.;
177 for (int v
=0; v
<k
; v
++)
178 prodGam2logPi2
+= gam2
[v
] * log(pi2
[v
]);
179 //t(m): largest step 0.1^kk on the grid for which the objective is decreasing or constant
180 while (-invN
*a
+ lambda
*piPowGammaDotB
< -invN
*prodGam2logPi2
+ lambda
*pi2PowGammaDotB
183 Real pow_01_kk
= pow(0.1,kk
);
184 //pi2 = pi + 0.1^kk * (1/n*gam2 - pi)
185 for (int v
=0; v
<k
; v
++)
186 pi2
[v
] = pi
[v
] + pow_01_kk
* (invN
*gam2
[v
] - pi
[v
]);
187 //pi2 was updated, so we recompute pi2PowGammaDotB and prodGam2logPi2
188 pi2PowGammaDotB
= 0.;
189 for (int v
=0; v
<k
; v
++)
190 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
192 for (int v
=0; v
<k
; v
++)
193 prodGam2logPi2
+= gam2
[v
] * log(pi2
[v
]);
196 Real t
= pow(0.1,kk
);
197 //sum(pi + t*(pi2-pi))
198 Real sumPiPlusTbyDiff
= 0.;
199 for (int v
=0; v
<k
; v
++)
200 sumPiPlusTbyDiff
+= (pi
[v
] + t
*(pi2
[v
] - pi
[v
]));
201 //pi = (pi + t*(pi2-pi)) / sum(pi + t*(pi2-pi))
202 for (int v
=0; v
<k
; v
++)
203 pi
[v
] = (pi
[v
] + t
*(pi2
[v
] - pi
[v
])) / sumPiPlusTbyDiff
;
206 for (int r
=0; r
<k
; r
++)
208 for (int mm
=0; mm
<m
; mm
++)
210 for (int i
=0; i
<n
; i
++)
212 //< X2(i,:,r) , phi(:,mm,r) >
213 Real dotProduct
= 0.;
214 for (int u
=0; u
<p
; u
++)
215 dotProduct
+= X2
[ai(i
,u
,r
,n
,p
,k
)] * phi
[ai(u
,mm
,r
,p
,m
,k
)];
216 //ps1[i,mm,r] = Y2[i,mm,r] * sum(X2[i,,r] * phi[,mm,r])
217 ps1
[ai(i
,mm
,r
,n
,m
,k
)] = Y2
[ai(i
,mm
,r
,n
,m
,k
)] * dotProduct
;
218 nY21
[ai(i
,mm
,r
,n
,m
,k
)] = Y2
[ai(i
,mm
,r
,n
,m
,k
)] * Y2
[ai(i
,mm
,r
,n
,m
,k
)];
220 //ps[mm,r] = sum(ps1[,mm,r])
222 for (int u
=0; u
<n
; u
++)
223 sumPs1
+= ps1
[ai(u
,mm
,r
,n
,m
,k
)];
224 ps
[mi(mm
,r
,m
,k
)] = sumPs1
;
225 //nY2[mm,r] = sum(nY21[,mm,r])
227 for (int u
=0; u
<n
; u
++)
228 sumNy21
+= nY21
[ai(u
,mm
,r
,n
,m
,k
)];
229 nY2
[mi(mm
,r
,m
,k
)] = sumNy21
;
230 //rho[mm,mm,r] = (ps[mm,r]+sqrt(ps[mm,r]^2+4*nY2[mm,r]*(gam2[r]))) / (2*nY2[mm,r])
231 rho
[ai(mm
,mm
,r
,m
,m
,k
)] = ( ps
[mi(mm
,r
,m
,k
)] + sqrt( ps
[mi(mm
,r
,m
,k
)]*ps
[mi(mm
,r
,m
,k
)]
232 + 4*nY2
[mi(mm
,r
,m
,k
)] * gam2
[r
] ) ) / (2*nY2
[mi(mm
,r
,m
,k
)]);
235 for (int r
=0; r
<k
; r
++)
237 for (int j
=0; j
<p
; j
++)
239 for (int mm
=0; mm
<m
; mm
++)
241 //sum(phi[1:(j-1),mm,r] * Gram2[j,1:(j-1),r])
242 Real dotPhiGram2
= 0.0;
243 for (int u
=0; u
<j
; u
++)
244 dotPhiGram2
+= phi
[ai(u
,mm
,r
,p
,m
,k
)] * Gram2
[ai(j
,u
,r
,p
,p
,k
)];
245 //sum(phi[(j+1):p,mm,r] * Gram2[j,(j+1):p,r])
246 for (int u
=j
+1; u
<p
; u
++)
247 dotPhiGram2
+= phi
[ai(u
,mm
,r
,p
,m
,k
)] * Gram2
[ai(j
,u
,r
,p
,p
,k
)];
248 //S[j,mm,r] = -rho[mm,mm,r]*ps2[j,mm,r] +
249 // (if(j>1) sum(phi[1:(j-1),mm,r] * Gram2[j,1:(j-1),r]) else 0) +
250 // (if(j<p) sum(phi[(j+1):p,mm,r] * Gram2[j,(j+1):p,r]) else 0)
251 S
[ai(j
,mm
,r
,p
,m
,k
)] = -rho
[ai(mm
,mm
,r
,m
,m
,k
)] * ps2
[ai(j
,mm
,r
,p
,m
,k
)] + dotPhiGram2
;
252 Real pow_pir_gamma
= pow(pi
[r
],gamma
);
253 if (fabs(S
[ai(j
,mm
,r
,p
,m
,k
)]) <= n
*lambda
*pow_pir_gamma
)
254 phi
[ai(j
,mm
,r
,p
,m
,k
)] = 0;
255 else if (S
[ai(j
,mm
,r
,p
,m
,k
)] > n
*lambda
*pow_pir_gamma
)
257 phi
[ai(j
,mm
,r
,p
,m
,k
)] = (n
*lambda
*pow_pir_gamma
- S
[ai(j
,mm
,r
,p
,m
,k
)])
258 / Gram2
[ai(j
,j
,r
,p
,p
,k
)];
262 phi
[ai(j
,mm
,r
,p
,m
,k
)] = -(n
*lambda
*pow_pir_gamma
+ S
[ai(j
,mm
,r
,p
,m
,k
)])
263 / Gram2
[ai(j
,j
,r
,p
,p
,k
)];
274 Real sumLogLLF2
= 0.0;
275 for (int i
=0; i
<n
; i
++)
277 Real minSqNorm2
= INFINITY
;
279 for (int r
=0; r
<k
; r
++)
281 //compute Y[i,]%*%rho[,,r]
282 for (int u
=0; u
<m
; u
++)
285 for (int v
=0; v
<m
; v
++)
286 YiRhoR
[u
] += Y
[mi(i
,v
,n
,m
)] * rho
[ai(v
,u
,r
,m
,m
,k
)];
289 //compute X(i,:)*phi(:,:,r)
290 for (int u
=0; u
<m
; u
++)
293 for (int v
=0; v
<p
; v
++)
294 XiPhiR
[u
] += X
[mi(i
,v
,n
,p
)] * phi
[ai(v
,u
,r
,p
,m
,k
)];
297 //compute sq norm || Y(i,:)*rho(:,:,r)-X(i,:)*phi(:,:,r) ||_2^2
299 for (int u
=0; u
<m
; u
++)
300 sqNorm2
[r
] += (YiRhoR
[u
]-XiPhiR
[u
]) * (YiRhoR
[u
]-XiPhiR
[u
]);
301 if (sqNorm2
[r
] < minSqNorm2
)
302 minSqNorm2
= sqNorm2
[r
];
304 Real shift
= 0.5*minSqNorm2
;
308 for (int r
=0; r
<k
; r
++)
310 //compute det(rho[,,r]) [TODO: avoid re-computations]
311 for (int u
=0; u
<m
; u
++)
313 for (int v
=0; v
<m
; v
++)
314 matrix
->data
[u
*m
+v
] = rho
[ai(u
,v
,r
,m
,m
,k
)];
316 gsl_linalg_LU_decomp(matrix
, permutation
, &signum
);
317 Real detRhoR
= gsl_linalg_LU_det(matrix
, signum
);
319 //FIXME: det(rho[,,r]) too small(?!). See EMGLLF.R
320 Gam
[mi(i
,r
,n
,k
)] = pi
[r
] * exp(-0.5*sqNorm2
[r
] + shift
) ; //* detRhoR;
321 sumLLF1
+= Gam
[mi(i
,r
,n
,k
)] / gaussConstM
;
322 sumGamI
+= Gam
[mi(i
,r
,n
,k
)];
324 sumLogLLF2
+= log(sumLLF1
);
325 for (int r
=0; r
<k
; r
++)
327 //gam[i,] = Gam[i,] / sumGamI
328 gam
[mi(i
,r
,n
,k
)] = sumGamI
> EPS
? Gam
[mi(i
,r
,n
,k
)] / sumGamI
: 0.;
332 //sumPen = sum(pi^gamma * b)
334 for (int r
=0; r
<k
; r
++)
335 sumPen
+= pow(pi
[r
],gamma
) * b
[r
];
336 //LLF[ite] = -sumLogLLF2/n + lambda*sumPen
337 LLF
[ite
] = -invN
* sumLogLLF2
+ lambda
* sumPen
;
338 dist
= ite
==0 ? LLF
[ite
] : (LLF
[ite
] - LLF
[ite
-1]) / (1.0 + fabs(LLF
[ite
]));
340 //Dist1 = max( abs(phi-Phi) / (1+abs(phi)) )
342 for (int u
=0; u
<p
; u
++)
344 for (int v
=0; v
<m
; v
++)
346 for (int w
=0; w
<k
; w
++)
348 Real tmpDist
= fabs(phi
[ai(u
,v
,w
,p
,m
,k
)]-Phi
[ai(u
,v
,w
,p
,m
,k
)])
349 / (1.0+fabs(phi
[ai(u
,v
,w
,p
,m
,k
)]));
355 //Dist2 = max( (abs(rho-Rho)) / (1+abs(rho)) )
357 for (int u
=0; u
<m
; u
++)
359 for (int v
=0; v
<m
; v
++)
361 for (int w
=0; w
<k
; w
++)
363 Real tmpDist
= fabs(rho
[ai(u
,v
,w
,m
,m
,k
)]-Rho
[ai(u
,v
,w
,m
,m
,k
)])
364 / (1.0+fabs(rho
[ai(u
,v
,w
,m
,m
,k
)]));
370 //Dist3 = max( (abs(pi-Pi)) / (1+abs(pi)))
372 for (int u
=0; u
<n
; u
++)
374 for (int v
=0; v
<k
; v
++)
376 Real tmpDist
= fabs(pi
[v
]-Pi
[v
]) / (1.0+fabs(pi
[v
]));
381 //dist2=max([max(Dist1),max(Dist2),max(Dist3)]);
404 gsl_matrix_free(matrix
);
405 gsl_permutation_free(permutation
);