42419acd24ab2a8c3e524e0c5b635298c01629a0
3 #include <gsl/gsl_linalg.h>
5 // TODO: don't recompute indexes every time......
8 const double* phiInit
, // parametre initial de moyenne renormalisé
9 const double* rhoInit
, // parametre initial de variance renormalisé
10 const double* piInit
, // parametre initial des proportions
11 const double* gamInit
, // paramètre initial des probabilités a posteriori de chaque échantillon
12 int mini
, // nombre minimal d'itérations dans l'algorithme EM
13 int maxi
, // nombre maximal d'itérations dans l'algorithme EM
14 double gamma
, // puissance des proportions dans la pénalisation pour un Lasso adaptatif
15 double lambda
, // valeur du paramètre de régularisation du Lasso
16 const double* X
, // régresseurs
17 const double* Y
, // réponse
18 double tau
, // seuil pour accepter la convergence
19 // OUT parameters (all pointers, to be modified)
20 double* phi
, // parametre de moyenne renormalisé, calculé par l'EM
21 double* rho
, // parametre de variance renormalisé, calculé par l'EM
22 double* pi
, // parametre des proportions renormalisé, calculé par l'EM
23 double* LLF
, // log vraisemblance associée à cet échantillon, pour les valeurs estimées des paramètres
25 // additional size parameters
26 int n
, // nombre d'echantillons
27 int p
, // nombre de covariables
28 int m
, // taille de Y (multivarié)
29 int k
) // nombre de composantes dans le mélange
32 copyArray(phiInit
, phi
, p
*m
*k
);
33 copyArray(rhoInit
, rho
, m
*m
*k
);
34 copyArray(piInit
, pi
, k
);
36 //S is already allocated, and doesn't need to be 'zeroed'
38 //Other local variables
39 //NOTE: variables order is always [maxi],n,p,m,k
40 double* gam
= (double*)malloc(n
*k
*sizeof(double));
41 copyArray(gamInit
, gam
, n
*k
);
42 double* b
= (double*)malloc(k
*sizeof(double));
43 double* Phi
= (double*)malloc(p
*m
*k
*sizeof(double));
44 double* Rho
= (double*)malloc(m
*m
*k
*sizeof(double));
45 double* Pi
= (double*)malloc(k
*sizeof(double));
46 double* gam2
= (double*)malloc(k
*sizeof(double));
47 double* pi2
= (double*)malloc(k
*sizeof(double));
48 double* Gram2
= (double*)malloc(p
*p
*k
*sizeof(double));
49 double* ps
= (double*)malloc(m
*k
*sizeof(double));
50 double* nY2
= (double*)malloc(m
*k
*sizeof(double));
51 double* ps1
= (double*)malloc(n
*m
*k
*sizeof(double));
52 double* ps2
= (double*)malloc(p
*m
*k
*sizeof(double));
53 double* nY21
= (double*)malloc(n
*m
*k
*sizeof(double));
54 double* Gam
= (double*)malloc(n
*k
*sizeof(double));
55 double* X2
= (double*)malloc(n
*p
*k
*sizeof(double));
56 double* Y2
= (double*)malloc(n
*m
*k
*sizeof(double));
57 gsl_matrix
* matrix
= gsl_matrix_alloc(m
, m
);
58 gsl_permutation
* permutation
= gsl_permutation_alloc(m
);
59 double* YiRhoR
= (double*)malloc(m
*sizeof(double));
60 double* XiPhiR
= (double*)malloc(m
*sizeof(double));
65 double* dotProducts
= (double*)malloc(k
*sizeof(double));
67 while (ite
< mini
|| (ite
< maxi
&& (dist
>= tau
|| dist2
>= sqrt(tau
))))
69 copyArray(phi
, Phi
, p
*m
*k
);
70 copyArray(rho
, Rho
, m
*m
*k
);
73 // Calculs associés a Y et X
74 for (int r
=0; r
<k
; r
++)
76 for (int mm
=0; mm
<m
; mm
++)
78 //Y2(:,mm,r)=sqrt(gam(:,r)).*transpose(Y(mm,:));
79 for (int u
=0; u
<n
; u
++)
80 Y2
[ai(u
,mm
,r
,n
,m
,k
)] = sqrt(gam
[mi(u
,r
,n
,k
)]) * Y
[mi(u
,mm
,m
,n
)];
82 for (int i
=0; i
<n
; i
++)
84 //X2(i,:,r)=X(i,:).*sqrt(gam(i,r));
85 for (int u
=0; u
<p
; u
++)
86 X2
[ai(i
,u
,r
,n
,m
,k
)] = sqrt(gam
[mi(i
,r
,n
,k
)]) * X
[mi(i
,u
,n
,p
)];
88 for (int mm
=0; mm
<m
; mm
++)
90 //ps2(:,mm,r)=transpose(X2(:,:,r))*Y2(:,mm,r);
91 for (int u
=0; u
<p
; u
++)
93 double dotProduct
= 0.;
94 for (int v
=0; v
<n
; v
++)
95 dotProduct
+= X2
[ai(v
,u
,r
,n
,m
,k
)] * Y2
[ai(v
,mm
,r
,n
,m
,k
)];
96 ps2
[ai(u
,mm
,r
,n
,m
,k
)] = dotProduct
;
99 for (int j
=0; j
<p
; j
++)
101 for (int s
=0; s
<p
; s
++)
103 //Gram2(j,s,r)=transpose(X2(:,j,r))*(X2(:,s,r));
104 double dotProduct
= 0.;
105 for (int u
=0; u
<n
; u
++)
106 dotProduct
+= X2
[ai(u
,j
,r
,n
,p
,k
)] * X2
[ai(u
,s
,r
,n
,p
,k
)];
107 Gram2
[ai(j
,s
,r
,p
,p
,k
)] = dotProduct
;
117 for (int r
=0; r
<k
; r
++)
119 //b(r) = sum(sum(abs(phi(:,:,r))));
120 double sumAbsPhi
= 0.;
121 for (int u
=0; u
<p
; u
++)
122 for (int v
=0; v
<m
; v
++)
123 sumAbsPhi
+= fabs(phi
[ai(u
,v
,r
,p
,m
,k
)]);
127 for (int u
=0; u
<k
; u
++)
129 double sumOnColumn
= 0.;
130 for (int v
=0; v
<n
; v
++)
131 sumOnColumn
+= gam
[mi(v
,u
,n
,k
)];
132 gam2
[u
] = sumOnColumn
;
134 //a=sum(gam*transpose(log(pi)));
136 for (int u
=0; u
<n
; u
++)
138 double dotProduct
= 0.;
139 for (int v
=0; v
<k
; v
++)
140 dotProduct
+= gam
[mi(u
,v
,n
,k
)] * log(pi
[v
]);
144 //tant que les proportions sont negatives
146 int pi2AllPositive
= 0;
148 while (!pi2AllPositive
)
150 //pi2(:)=pi(:)+0.1^kk*(1/n*gam2(:)-pi(:));
151 for (int r
=0; r
<k
; r
++)
152 pi2
[r
] = pi
[r
] + pow(0.1,kk
) * (invN
*gam2
[r
] - pi
[r
]);
154 for (int r
=0; r
<k
; r
++)
165 //t(m): the largest value in the grid 0.1^k such that it is decreasing or constant
167 double piPowGammaDotB
= 0.;
168 for (int v
=0; v
<k
; v
++)
169 piPowGammaDotB
+= pow(pi
[v
],gamma
) * b
[v
];
171 double pi2PowGammaDotB
= 0.;
172 for (int v
=0; v
<k
; v
++)
173 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
174 //transpose(gam2)*log(pi2)
175 double prodGam2logPi2
= 0.;
176 for (int v
=0; v
<k
; v
++)
177 prodGam2logPi2
+= gam2
[v
] * log(pi2
[v
]);
178 while (-invN
*a
+ lambda
*piPowGammaDotB
< -invN
*prodGam2logPi2
+ lambda
*pi2PowGammaDotB
181 //pi2=pi+0.1^kk*(1/n*gam2-pi);
182 for (int v
=0; v
<k
; v
++)
183 pi2
[v
] = pi
[v
] + pow(0.1,kk
) * (invN
*gam2
[v
] - pi
[v
]);
184 //pi2 was updated, so we recompute pi2PowGammaDotB and prodGam2logPi2
185 pi2PowGammaDotB
= 0.;
186 for (int v
=0; v
<k
; v
++)
187 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
189 for (int v
=0; v
<k
; v
++)
190 prodGam2logPi2
+= gam2
[v
] * log(pi2
[v
]);
193 double t
= pow(0.1,kk
);
195 double sumPiPlusTbyDiff
= 0.;
196 for (int v
=0; v
<k
; v
++)
197 sumPiPlusTbyDiff
+= (pi
[v
] + t
*(pi2
[v
] - pi
[v
]));
198 //pi=(pi+t*(pi2-pi))/sum(pi+t*(pi2-pi));
199 for (int v
=0; v
<k
; v
++)
200 pi
[v
] = (pi
[v
] + t
*(pi2
[v
] - pi
[v
])) / sumPiPlusTbyDiff
;
203 for (int r
=0; r
<k
; r
++)
205 for (int mm
=0; mm
<m
; mm
++)
207 for (int i
=0; i
<n
; i
++)
209 //< X2(i,:,r) , phi(:,mm,r) >
210 double dotProduct
= 0.0;
211 for (int u
=0; u
<p
; u
++)
212 dotProduct
+= X2
[ai(i
,u
,r
,n
,p
,k
)] * phi
[ai(u
,mm
,r
,n
,m
,k
)];
213 //ps1(i,mm,r)=Y2(i,mm,r)*dot(X2(i,:,r),phi(:,mm,r));
214 ps1
[ai(i
,mm
,r
,n
,m
,k
)] = Y2
[ai(i
,mm
,r
,n
,m
,k
)] * dotProduct
;
215 nY21
[ai(i
,mm
,r
,n
,m
,k
)] = Y2
[ai(i
,mm
,r
,n
,m
,k
)] * Y2
[ai(i
,mm
,r
,n
,m
,k
)];
217 //ps(mm,r)=sum(ps1(:,mm,r));
219 for (int u
=0; u
<n
; u
++)
220 sumPs1
+= ps1
[ai(u
,mm
,r
,n
,m
,k
)];
221 ps
[mi(mm
,r
,m
,k
)] = sumPs1
;
222 //nY2(mm,r)=sum(nY21(:,mm,r));
223 double sumNy21
= 0.0;
224 for (int u
=0; u
<n
; u
++)
225 sumNy21
+= nY21
[ai(u
,mm
,r
,n
,m
,k
)];
226 nY2
[mi(mm
,r
,m
,k
)] = sumNy21
;
227 //rho(mm,mm,r)=((ps(mm,r)+sqrt(ps(mm,r)^2+4*nY2(mm,r)*(gam2(r))))/(2*nY2(mm,r)));
228 rho
[ai(mm
,mm
,k
,m
,m
,k
)] = ( ps
[mi(mm
,r
,m
,k
)] + sqrt( ps
[mi(mm
,r
,m
,k
)]*ps
[mi(mm
,r
,m
,k
)]
229 + 4*nY2
[mi(mm
,r
,m
,k
)] * (gam2
[r
]) ) ) / (2*nY2
[mi(mm
,r
,m
,k
)]);
232 for (int r
=0; r
<k
; r
++)
234 for (int j
=0; j
<p
; j
++)
236 for (int mm
=0; mm
<m
; mm
++)
238 //sum(phi(1:j-1,mm,r).*transpose(Gram2(j,1:j-1,r)))+sum(phi(j+1:p,mm,r)
239 // .*transpose(Gram2(j,j+1:p,r)))
240 double dotPhiGram2
= 0.0;
241 for (int u
=0; u
<j
; u
++)
242 dotPhiGram2
+= phi
[ai(u
,mm
,r
,p
,m
,k
)] * Gram2
[ai(j
,u
,r
,p
,p
,k
)];
243 for (int u
=j
+1; u
<p
; u
++)
244 dotPhiGram2
+= phi
[ai(u
,mm
,r
,p
,m
,k
)] * Gram2
[ai(j
,u
,r
,p
,p
,k
)];
245 //S(j,r,mm)=-rho(mm,mm,r)*ps2(j,mm,r)+sum(phi(1:j-1,mm,r).*transpose(Gram2(j,1:j-1,r)))
246 // +sum(phi(j+1:p,mm,r).*transpose(Gram2(j,j+1:p,r)));
247 S
[ai(j
,mm
,r
,p
,m
,k
)] = -rho
[ai(mm
,mm
,r
,m
,m
,k
)] * ps2
[ai(j
,mm
,r
,p
,m
,k
)] + dotPhiGram2
;
248 if (fabs(S
[ai(j
,mm
,r
,p
,m
,k
)]) <= n
*lambda
*pow(pi
[r
],gamma
))
249 phi
[ai(j
,mm
,r
,p
,m
,k
)] = 0;
250 else if (S
[ai(j
,mm
,r
,p
,m
,k
)] > n
*lambda
*pow(pi
[r
],gamma
))
251 phi
[ai(j
,mm
,r
,p
,m
,k
)] = (n
*lambda
*pow(pi
[r
],gamma
) - S
[ai(j
,mm
,r
,p
,m
,k
)])
252 / Gram2
[ai(j
,j
,r
,p
,p
,k
)];
254 phi
[ai(j
,mm
,r
,p
,m
,k
)] = -(n
*lambda
*pow(pi
[r
],gamma
) + S
[ai(j
,mm
,r
,p
,m
,k
)])
255 / Gram2
[ai(j
,j
,r
,p
,p
,k
)];
265 double sumLogLLF2
= 0.0;
266 for (int i
=0; i
<n
; i
++)
268 double sumLLF1
= 0.0;
269 double sumGamI
= 0.0;
270 double minDotProduct
= INFINITY
;
272 for (int r
=0; r
<k
; r
++)
275 //Gam(i,r) = Pi(r) * det(Rho(:,:,r)) * exp( -1/2 * (Y(i,:)*Rho(:,:,r) - X(i,:)...
276 // *phi(:,:,r)) * transpose( Y(i,:)*Rho(:,:,r) - X(i,:)*phi(:,:,r) ) );
277 //split in several sub-steps
279 //compute Y(i,:)*rho(:,:,r)
280 for (int u
=0; u
<m
; u
++)
283 for (int v
=0; v
<m
; v
++)
284 YiRhoR
[u
] += Y
[mi(i
,v
,n
,m
)] * rho
[ai(v
,u
,r
,m
,m
,k
)];
287 //compute X(i,:)*phi(:,:,r)
288 for (int u
=0; u
<m
; u
++)
291 for (int v
=0; v
<p
; v
++)
292 XiPhiR
[u
] += X
[mi(i
,v
,n
,p
)] * phi
[ai(v
,u
,r
,p
,m
,k
)];
296 // < Y(i,:)*rho(:,:,r)-X(i,:)*phi(:,:,r) . Y(i,:)*rho(:,:,r)-X(i,:)*phi(:,:,r) >
297 dotProducts
[r
] = 0.0;
298 for (int u
=0; u
<m
; u
++)
299 dotProducts
[r
] += (YiRhoR
[u
]-XiPhiR
[u
]) * (YiRhoR
[u
]-XiPhiR
[u
]);
300 if (dotProducts
[r
] < minDotProduct
)
301 minDotProduct
= dotProducts
[r
];
303 double shift
= 0.5*minDotProduct
;
304 for (int r
=0; r
<k
; r
++)
306 //compute det(rho(:,:,r)) [TODO: avoid re-computations]
307 for (int u
=0; u
<m
; u
++)
309 for (int v
=0; v
<m
; v
++)
310 matrix
->data
[u
*m
+v
] = rho
[ai(u
,v
,r
,m
,m
,k
)];
312 gsl_linalg_LU_decomp(matrix
, permutation
, &signum
);
313 double detRhoR
= gsl_linalg_LU_det(matrix
, signum
);
315 Gam
[mi(i
,r
,n
,k
)] = pi
[r
] * detRhoR
* exp(-0.5*dotProducts
[r
] + shift
);
316 sumLLF1
+= Gam
[mi(i
,r
,n
,k
)] / pow(2*M_PI
,m
/2.0);
317 sumGamI
+= Gam
[mi(i
,r
,n
,k
)];
319 sumLogLLF2
+= log(sumLLF1
);
320 for (int r
=0; r
<k
; r
++)
322 //gam(i,r)=Gam(i,r)/sum(Gam(i,:));
323 gam
[mi(i
,r
,n
,k
)] = sumGamI
> EPS
324 ? Gam
[mi(i
,r
,n
,k
)] / sumGamI
331 for (int r
=0; r
<k
; r
++)
332 sumPen
+= pow(pi
[r
],gamma
) * b
[r
];
333 //LLF(ite)=-1/n*sum(log(LLF2(ite,:)))+lambda*sum(pen(ite,:));
334 LLF
[ite
] = -invN
* sumLogLLF2
+ lambda
* sumPen
;
338 dist
= (LLF
[ite
] - LLF
[ite
-1]) / (1.0 + fabs(LLF
[ite
]));
340 //Dist1=max(max((abs(phi-Phi))./(1+abs(phi))));
342 for (int u
=0; u
<p
; u
++)
344 for (int v
=0; v
<m
; v
++)
346 for (int w
=0; w
<k
; w
++)
348 double tmpDist
= fabs(phi
[ai(u
,v
,w
,p
,m
,k
)]-Phi
[ai(u
,v
,w
,p
,m
,k
)])
349 / (1.0+fabs(phi
[ai(u
,v
,w
,p
,m
,k
)]));
355 //Dist2=max(max((abs(rho-Rho))./(1+abs(rho))));
357 for (int u
=0; u
<m
; u
++)
359 for (int v
=0; v
<m
; v
++)
361 for (int w
=0; w
<k
; w
++)
363 double tmpDist
= fabs(rho
[ai(u
,v
,w
,m
,m
,k
)]-Rho
[ai(u
,v
,w
,m
,m
,k
)])
364 / (1.0+fabs(rho
[ai(u
,v
,w
,m
,m
,k
)]));
370 //Dist3=max(max((abs(pi-Pi))./(1+abs(pi))));
372 for (int u
=0; u
<n
; u
++)
374 for (int v
=0; v
<k
; v
++)
376 double tmpDist
= fabs(pi
[v
]-Pi
[v
]) / (1.0+fabs(pi
[v
]));
381 //dist2=max([max(Dist1),max(Dist2),max(Dist3)]);
404 gsl_matrix_free(matrix
);
405 gsl_permutation_free(permutation
);