2 #include <gsl/gsl_linalg.h>
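4 // EMGLLF: EM algorithm for a mixture of k multivariate Gaussian regressions
// (m-dimensional response Y on p covariates X, n samples), expressed in the
// renormalized parameters phi (means) and rho (variances), with a Lasso penalty
// of strength lambda whose weights are the mixture proportions pi raised to the
// power gamma (adaptive Lasso). Starting from phiInit/rhoInit/piInit/gamInit, it
// runs at least mini and at most maxi iterations. It stops early once the
// relative change of the penalized criterion is below tau and the relative
// parameter change is below sqrt(tau). The final estimates are written to phi,
// rho and pi; LLF[ite] receives -1/n * sum(log-likelihood) + lambda * penalty
// for each iteration.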
7 const Real
* phiInit
, // parametre initial de moyenne renormalisé
8 const Real
* rhoInit
, // parametre initial de variance renormalisé
9 const Real
* piInit
, // parametre initial des proportions
10 const Real
* gamInit
, // paramètre initial des probabilités a posteriori de chaque échantillon
11 Int mini
, // nombre minimal d'itérations dans l'algorithme EM
12 Int maxi
, // nombre maximal d'itérations dans l'algorithme EM
13 Real gamma
, // valeur de gamma : puissance des proportions dans la pénalisation pour un Lasso adaptatif
14 Real lambda
, // valeur du paramètre de régularisation du Lasso
15 const Real
* X
, // régresseurs
16 const Real
* Y
, // réponse
17 Real tau
, // seuil pour accepter la convergence
18 // OUT parameters (all pointers, to be modified)
19 Real
* phi
, // parametre de moyenne renormalisé, calculé par l'EM
20 Real
* rho
, // parametre de variance renormalisé, calculé par l'EM
21 Real
* pi
, // parametre des proportions renormalisé, calculé par l'EM
22 Real
* LLF
, // log vraisemblance associé à cet échantillon, pour les valeurs estimées des paramètres
24 // additional size parameters
25 mwSize n
, // nombre d'echantillons
26 mwSize p
, // nombre de covariables
27 mwSize m
, // taille de Y (multivarié)
28 mwSize k
) // nombre de composantes dans le mélange
31 copyArray(phiInit
, phi
, p
*m
*k
);
32 copyArray(rhoInit
, rho
, m
*m
*k
);
33 copyArray(piInit
, pi
, k
);
35 //S is already allocated, and doesn't need to be 'zeroed'
37 //Other local variables
38 //NOTE: variables order is always [maxi],n,p,m,k
39 Real
* gam
= (Real
*)malloc(n
*k
*sizeof(Real
));
40 copyArray(gamInit
, gam
, n
*k
);
41 Real
* b
= (Real
*)malloc(k
*sizeof(Real
));
42 Real
* Phi
= (Real
*)malloc(p
*m
*k
*sizeof(Real
));
43 Real
* Rho
= (Real
*)malloc(m
*m
*k
*sizeof(Real
));
44 Real
* Pi
= (Real
*)malloc(k
*sizeof(Real
));
45 Real
* gam2
= (Real
*)malloc(k
*sizeof(Real
));
46 Real
* pi2
= (Real
*)malloc(k
*sizeof(Real
));
47 Real
* Gram2
= (Real
*)malloc(p
*p
*k
*sizeof(Real
));
48 Real
* ps
= (Real
*)malloc(m
*k
*sizeof(Real
));
49 Real
* nY2
= (Real
*)malloc(m
*k
*sizeof(Real
));
50 Real
* ps1
= (Real
*)malloc(n
*m
*k
*sizeof(Real
));
51 Real
* ps2
= (Real
*)malloc(p
*m
*k
*sizeof(Real
));
52 Real
* nY21
= (Real
*)malloc(n
*m
*k
*sizeof(Real
));
53 Real
* Gam
= (Real
*)malloc(n
*k
*sizeof(Real
));
54 Real
* X2
= (Real
*)malloc(n
*p
*k
*sizeof(Real
));
55 Real
* Y2
= (Real
*)malloc(n
*m
*k
*sizeof(Real
));
56 gsl_matrix
* matrix
= gsl_matrix_alloc(m
, m
);
57 gsl_permutation
* permutation
= gsl_permutation_alloc(m
);
58 Real
* YiRhoR
= (Real
*)malloc(m
*sizeof(Real
));
59 Real
* XiPhiR
= (Real
*)malloc(m
*sizeof(Real
));
64 Real
* dotProducts
= (Real
*)malloc(k
*sizeof(Real
));
66 while (ite
< mini
|| (ite
< maxi
&& (dist
>= tau
|| dist2
>= sqrt(tau
))))
68 copyArray(phi
, Phi
, p
*m
*k
);
69 copyArray(rho
, Rho
, m
*m
*k
);
72 // Calculs associes a Y et X
73 for (mwSize r
=0; r
<k
; r
++)
75 for (mwSize mm
=0; mm
<m
; mm
++)
77 //Y2(:,mm,r)=sqrt(gam(:,r)).*transpose(Y(mm,:));
78 for (mwSize u
=0; u
<n
; u
++)
79 Y2
[u
*m
*k
+mm
*k
+r
] = sqrt(gam
[u
*k
+r
]) * Y
[u
*m
+mm
];
81 for (mwSize i
=0; i
<n
; i
++)
83 //X2(i,:,r)=X(i,:).*sqrt(gam(i,r));
84 for (mwSize u
=0; u
<p
; u
++)
85 X2
[i
*p
*k
+u
*k
+r
] = sqrt(gam
[i
*k
+r
]) * X
[i
*p
+u
];
87 for (mwSize mm
=0; mm
<m
; mm
++)
89 //ps2(:,mm,r)=transpose(X2(:,:,r))*Y2(:,mm,r);
90 for (mwSize u
=0; u
<p
; u
++)
92 Real dotProduct
= 0.0;
93 for (mwSize v
=0; v
<n
; v
++)
94 dotProduct
+= X2
[v
*p
*k
+u
*k
+r
] * Y2
[v
*m
*k
+mm
*k
+r
];
95 ps2
[u
*m
*k
+mm
*k
+r
] = dotProduct
;
98 for (mwSize j
=0; j
<p
; j
++)
100 for (mwSize s
=0; s
<p
; s
++)
102 //Gram2(j,s,r)=transpose(X2(:,j,r))*(X2(:,s,r));
103 Real dotProduct
= 0.0;
104 for (mwSize u
=0; u
<n
; u
++)
105 dotProduct
+= X2
[u
*p
*k
+j
*k
+r
] * X2
[u
*p
*k
+s
*k
+r
];
106 Gram2
[j
*p
*k
+s
*k
+r
] = dotProduct
;
116 for (mwSize r
=0; r
<k
; r
++)
118 //b(r) = sum(sum(abs(phi(:,:,r))));
119 Real sumAbsPhi
= 0.0;
120 for (mwSize u
=0; u
<p
; u
++)
121 for (mwSize v
=0; v
<m
; v
++)
122 sumAbsPhi
+= fabs(phi
[u
*m
*k
+v
*k
+r
]);
126 for (mwSize u
=0; u
<k
; u
++)
128 Real sumOnColumn
= 0.0;
129 for (mwSize v
=0; v
<n
; v
++)
130 sumOnColumn
+= gam
[v
*k
+u
];
131 gam2
[u
] = sumOnColumn
;
133 //a=sum(gam*transpose(log(pi)));
135 for (mwSize u
=0; u
<n
; u
++)
137 Real dotProduct
= 0.0;
138 for (mwSize v
=0; v
<k
; v
++)
139 dotProduct
+= gam
[u
*k
+v
] * log(pi
[v
]);
143 //tant que les proportions sont negatives
145 int pi2AllPositive
= 0;
147 while (!pi2AllPositive
)
149 //pi2(:)=pi(:)+0.1^kk*(1/n*gam2(:)-pi(:));
150 for (mwSize r
=0; r
<k
; r
++)
151 pi2
[r
] = pi
[r
] + pow(0.1,kk
) * (invN
*gam2
[r
] - pi
[r
]);
153 for (mwSize r
=0; r
<k
; r
++)
164 //t(m) la plus grande valeur dans la grille O.1^k tel que ce soit décroissante ou constante
166 Real piPowGammaDotB
= 0.0;
167 for (mwSize v
=0; v
<k
; v
++)
168 piPowGammaDotB
+= pow(pi
[v
],gamma
) * b
[v
];
170 Real pi2PowGammaDotB
= 0.0;
171 for (mwSize v
=0; v
<k
; v
++)
172 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
173 //transpose(gam2)*log(pi2)
174 Real prodGam2logPi2
= 0.0;
175 for (mwSize v
=0; v
<k
; v
++)
176 prodGam2logPi2
+= gam2
[v
] * log(pi2
[v
]);
177 while (-invN
*a
+ lambda
*piPowGammaDotB
< -invN
*prodGam2logPi2
+ lambda
*pi2PowGammaDotB
&& kk
<1000)
179 //pi2=pi+0.1^kk*(1/n*gam2-pi);
180 for (mwSize v
=0; v
<k
; v
++)
181 pi2
[v
] = pi
[v
] + pow(0.1,kk
) * (invN
*gam2
[v
] - pi
[v
]);
182 //pi2 was updated, so we recompute pi2PowGammaDotB and prodGam2logPi2
183 pi2PowGammaDotB
= 0.0;
184 for (mwSize v
=0; v
<k
; v
++)
185 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
186 prodGam2logPi2
= 0.0;
187 for (mwSize v
=0; v
<k
; v
++)
188 prodGam2logPi2
+= gam2
[v
] * log(pi2
[v
]);
191 Real t
= pow(0.1,kk
);
193 Real sumPiPlusTbyDiff
= 0.0;
194 for (mwSize v
=0; v
<k
; v
++)
195 sumPiPlusTbyDiff
+= (pi
[v
] + t
*(pi2
[v
] - pi
[v
]));
196 //pi=(pi+t*(pi2-pi))/sum(pi+t*(pi2-pi));
197 for (mwSize v
=0; v
<k
; v
++)
198 pi
[v
] = (pi
[v
] + t
*(pi2
[v
] - pi
[v
])) / sumPiPlusTbyDiff
;
201 for (mwSize r
=0; r
<k
; r
++)
203 for (mwSize mm
=0; mm
<m
; mm
++)
205 for (mwSize i
=0; i
<n
; i
++)
207 //< X2(i,:,r) , phi(:,mm,r) >
208 Real dotProduct
= 0.0;
209 for (mwSize u
=0; u
<p
; u
++)
210 dotProduct
+= X2
[i
*p
*k
+u
*k
+r
] * phi
[u
*m
*k
+mm
*k
+r
];
211 //ps1(i,mm,r)=Y2(i,mm,r)*dot(X2(i,:,r),phi(:,mm,r));
212 ps1
[i
*m
*k
+mm
*k
+r
] = Y2
[i
*m
*k
+mm
*k
+r
] * dotProduct
;
213 nY21
[i
*m
*k
+mm
*k
+r
] = Y2
[i
*m
*k
+mm
*k
+r
] * Y2
[i
*m
*k
+mm
*k
+r
];
215 //ps(mm,r)=sum(ps1(:,mm,r));
217 for (mwSize u
=0; u
<n
; u
++)
218 sumPs1
+= ps1
[u
*m
*k
+mm
*k
+r
];
220 //nY2(mm,r)=sum(nY21(:,mm,r));
222 for (mwSize u
=0; u
<n
; u
++)
223 sumNy21
+= nY21
[u
*m
*k
+mm
*k
+r
];
224 nY2
[mm
*k
+r
] = sumNy21
;
225 //rho(mm,mm,r)=((ps(mm,r)+sqrt(ps(mm,r)^2+4*nY2(mm,r)*(gam2(r))))/(2*nY2(mm,r)));
226 rho
[mm
*m
*k
+mm
*k
+r
] = ( ps
[mm
*k
+r
] + sqrt( ps
[mm
*k
+r
]*ps
[mm
*k
+r
]
227 + 4*nY2
[mm
*k
+r
] * (gam2
[r
]) ) ) / (2*nY2
[mm
*k
+r
]);
230 for (mwSize r
=0; r
<k
; r
++)
232 for (mwSize j
=0; j
<p
; j
++)
234 for (mwSize mm
=0; mm
<m
; mm
++)
236 //sum(phi(1:j-1,mm,r).*transpose(Gram2(j,1:j-1,r)))+sum(phi(j+1:p,mm,r).*transpose(Gram2(j,j+1:p,r)))
237 Real dotPhiGram2
= 0.0;
238 for (mwSize u
=0; u
<j
; u
++)
239 dotPhiGram2
+= phi
[u
*m
*k
+mm
*k
+r
] * Gram2
[j
*p
*k
+u
*k
+r
];
240 for (mwSize u
=j
+1; u
<p
; u
++)
241 dotPhiGram2
+= phi
[u
*m
*k
+mm
*k
+r
] * Gram2
[j
*p
*k
+u
*k
+r
];
242 //S(j,r,mm)=-rho(mm,mm,r)*ps2(j,mm,r)+sum(phi(1:j-1,mm,r).*transpose(Gram2(j,1:j-1,r)))
243 // +sum(phi(j+1:p,mm,r).*transpose(Gram2(j,j+1:p,r)));
244 S
[j
*m
*k
+mm
*k
+r
] = -rho
[mm
*m
*k
+mm
*k
+r
] * ps2
[j
*m
*k
+mm
*k
+r
] + dotPhiGram2
;
245 if (fabs(S
[j
*m
*k
+mm
*k
+r
]) <= n
*lambda
*pow(pi
[r
],gamma
))
246 phi
[j
*m
*k
+mm
*k
+r
] = 0;
247 else if (S
[j
*m
*k
+mm
*k
+r
] > n
*lambda
*pow(pi
[r
],gamma
))
248 phi
[j
*m
*k
+mm
*k
+r
] = (n
*lambda
*pow(pi
[r
],gamma
) - S
[j
*m
*k
+mm
*k
+r
])
249 / Gram2
[j
*p
*k
+j
*k
+r
];
251 phi
[j
*m
*k
+mm
*k
+r
] = -(n
*lambda
*pow(pi
[r
],gamma
) + S
[j
*m
*k
+mm
*k
+r
])
252 / Gram2
[j
*p
*k
+j
*k
+r
];
262 Real sumLogLLF2
= 0.0;
263 for (mwSize i
=0; i
<n
; i
++)
267 Real minDotProduct
= INFINITY
;
269 for (mwSize r
=0; r
<k
; r
++)
272 //Gam(i,r) = Pi(r) * det(Rho(:,:,r)) * exp( -1/2 * (Y(i,:)*Rho(:,:,r) - X(i,:)...
273 // *phi(:,:,r)) * transpose( Y(i,:)*Rho(:,:,r) - X(i,:)*phi(:,:,r) ) );
274 //split in several sub-steps
276 //compute Y(i,:)*rho(:,:,r)
277 for (mwSize u
=0; u
<m
; u
++)
280 for (mwSize v
=0; v
<m
; v
++)
281 YiRhoR
[u
] += Y
[i
*m
+v
] * rho
[v
*m
*k
+u
*k
+r
];
284 //compute X(i,:)*phi(:,:,r)
285 for (mwSize u
=0; u
<m
; u
++)
288 for (mwSize v
=0; v
<p
; v
++)
289 XiPhiR
[u
] += X
[i
*p
+v
] * phi
[v
*m
*k
+u
*k
+r
];
292 // compute dotProduct < Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) . Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) >
293 dotProducts
[r
] = 0.0;
294 for (mwSize u
=0; u
<m
; u
++)
295 dotProducts
[r
] += (YiRhoR
[u
]-XiPhiR
[u
]) * (YiRhoR
[u
]-XiPhiR
[u
]);
296 if (dotProducts
[r
] < minDotProduct
)
297 minDotProduct
= dotProducts
[r
];
299 Real shift
= 0.5*minDotProduct
;
300 for (mwSize r
=0; r
<k
; r
++)
302 //compute det(rho(:,:,r)) [TODO: avoid re-computations]
303 for (mwSize u
=0; u
<m
; u
++)
305 for (mwSize v
=0; v
<m
; v
++)
306 matrix
->data
[u
*m
+v
] = rho
[u
*m
*k
+v
*k
+r
];
308 gsl_linalg_LU_decomp(matrix
, permutation
, &signum
);
309 Real detRhoR
= gsl_linalg_LU_det(matrix
, signum
);
311 Gam
[i
*k
+r
] = pi
[r
] * detRhoR
* exp(-0.5*dotProducts
[r
] + shift
);
312 sumLLF1
+= Gam
[i
*k
+r
] / pow(2*M_PI
,m
/2.0);
313 sumGamI
+= Gam
[i
*k
+r
];
315 sumLogLLF2
+= log(sumLLF1
);
316 for (mwSize r
=0; r
<k
; r
++)
318 //gam(i,r)=Gam(i,r)/sum(Gam(i,:));
319 gam
[i
*k
+r
] = sumGamI
> EPS
320 ? Gam
[i
*k
+r
] / sumGamI
327 for (mwSize r
=0; r
<k
; r
++)
328 sumPen
+= pow(pi
[r
],gamma
) * b
[r
];
329 //LLF(ite)=-1/n*sum(log(LLF2(ite,:)))+lambda*sum(pen(ite,:));
330 LLF
[ite
] = -invN
* sumLogLLF2
+ lambda
* sumPen
;
334 dist
= (LLF
[ite
] - LLF
[ite
-1]) / (1.0 + fabs(LLF
[ite
]));
336 //Dist1=max(max((abs(phi-Phi))./(1+abs(phi))));
338 for (mwSize u
=0; u
<p
; u
++)
340 for (mwSize v
=0; v
<m
; v
++)
342 for (mwSize w
=0; w
<k
; w
++)
344 Real tmpDist
= fabs(phi
[u
*m
*k
+v
*k
+w
]-Phi
[u
*m
*k
+v
*k
+w
])
345 / (1.0+fabs(phi
[u
*m
*k
+v
*k
+w
]));
351 //Dist2=max(max((abs(rho-Rho))./(1+abs(rho))));
353 for (mwSize u
=0; u
<m
; u
++)
355 for (mwSize v
=0; v
<m
; v
++)
357 for (mwSize w
=0; w
<k
; w
++)
359 Real tmpDist
= fabs(rho
[u
*m
*k
+v
*k
+w
]-Rho
[u
*m
*k
+v
*k
+w
])
360 / (1.0+fabs(rho
[u
*m
*k
+v
*k
+w
]));
366 //Dist3=max(max((abs(pi-Pi))./(1+abs(Pi))));
368 for (mwSize u
=0; u
<n
; u
++)
370 for (mwSize v
=0; v
<k
; v
++)
372 Real tmpDist
= fabs(pi
[v
]-Pi
[v
]) / (1.0+fabs(pi
[v
]));
377 //dist2=max([max(Dist1),max(Dist2),max(Dist3)]);
400 gsl_matrix_free(matrix
);
401 gsl_permutation_free(permutation
);