4 #include <gsl/gsl_linalg.h>
6 // TODO: don't recompute indexes ai(...) and mi(...) when possible
9 const Real
* phiInit
, // parametre initial de moyenne renormalisé
10 const Real
* rhoInit
, // parametre initial de variance renormalisé
11 const Real
* piInit
, // parametre initial des proportions
12 const Real
* gamInit
, // paramètre initial des probabilités a posteriori de chaque échantillon
13 int mini
, // nombre minimal d'itérations dans l'algorithme EM
14 int maxi
, // nombre maximal d'itérations dans l'algorithme EM
15 Real gamma
, // puissance des proportions dans la pénalisation pour un Lasso adaptatif
16 Real lambda
, // valeur du paramètre de régularisation du Lasso
17 const Real
* X
, // régresseurs
18 const Real
* Y
, // réponse
19 Real tau
, // seuil pour accepter la convergence
20 // OUT parameters (all pointers, to be modified)
21 Real
* phi
, // parametre de moyenne renormalisé, calculé par l'EM
22 Real
* rho
, // parametre de variance renormalisé, calculé par l'EM
23 Real
* pi
, // parametre des proportions renormalisé, calculé par l'EM
24 Real
* llh
, // (derniere) log vraisemblance associée à cet échantillon,
25 // pour les valeurs estimées des paramètres
28 // additional size parameters
29 int n
, // nombre d'echantillons
30 int p
, // nombre de covariables
31 int m
, // taille de Y (multivarié)
32 int k
) // nombre de composantes dans le mélange
35 copyArray(phiInit
, phi
, p
*m
*k
);
36 copyArray(rhoInit
, rho
, m
*m
*k
);
37 copyArray(piInit
, pi
, k
);
38 //S is already allocated, and doesn't need to be 'zeroed'
40 //Other local variables: same as in R
41 Real
* gam
= (Real
*)malloc(n
*k
*sizeof(Real
));
42 copyArray(gamInit
, gam
, n
*k
);
43 Real
* Gram2
= (Real
*)malloc(p
*p
*k
*sizeof(Real
));
44 Real
* ps2
= (Real
*)malloc(p
*m
*k
*sizeof(Real
));
45 Real
* b
= (Real
*)malloc(k
*sizeof(Real
));
46 Real
* X2
= (Real
*)malloc(n
*p
*k
*sizeof(Real
));
47 Real
* Y2
= (Real
*)malloc(n
*m
*k
*sizeof(Real
));
49 Real
* pi2
= (Real
*)malloc(k
*sizeof(Real
));
50 const Real EPS
= 1e-15;
51 // Additional (not at this place, in R file)
52 Real
* gam2
= (Real
*)malloc(k
*sizeof(Real
));
53 Real
* sqNorm2
= (Real
*)malloc(k
*sizeof(Real
));
54 Real
* detRho
= (Real
*)malloc(k
*sizeof(Real
));
55 gsl_matrix
* matrix
= gsl_matrix_alloc(m
, m
);
56 gsl_permutation
* permutation
= gsl_permutation_alloc(m
);
57 Real
* YiRhoR
= (Real
*)malloc(m
*sizeof(Real
));
58 Real
* XiPhiR
= (Real
*)malloc(m
*sizeof(Real
));
59 const Real gaussConstM
= pow(2.*M_PI
,m
/2.);
60 Real
* Phi
= (Real
*)malloc(p
*m
*k
*sizeof(Real
));
61 Real
* Rho
= (Real
*)malloc(m
*m
*k
*sizeof(Real
));
62 Real
* Pi
= (Real
*)malloc(k
*sizeof(Real
));
64 for (int ite
=0; ite
<maxi
; ite
++)
66 copyArray(phi
, Phi
, p
*m
*k
);
67 copyArray(rho
, Rho
, m
*m
*k
);
70 // Calculs associés a Y et X
71 for (int r
=0; r
<k
; r
++)
73 for (int mm
=0; mm
<m
; mm
++)
75 //Y2[,mm,r] = sqrt(gam[,r]) * Y[,mm]
76 for (int u
=0; u
<n
; u
++)
77 Y2
[ai(u
,mm
,r
,n
,m
,k
)] = sqrt(gam
[mi(u
,r
,n
,k
)]) * Y
[mi(u
,mm
,n
,m
)];
79 for (int i
=0; i
<n
; i
++)
81 //X2[i,,r] = sqrt(gam[i,r]) * X[i,]
82 for (int u
=0; u
<p
; u
++)
83 X2
[ai(i
,u
,r
,n
,p
,k
)] = sqrt(gam
[mi(i
,r
,n
,k
)]) * X
[mi(i
,u
,n
,p
)];
85 for (int mm
=0; mm
<m
; mm
++)
87 //ps2[,mm,r] = crossprod(X2[,,r],Y2[,mm,r])
88 for (int u
=0; u
<p
; u
++)
91 for (int v
=0; v
<n
; v
++)
92 dotProduct
+= X2
[ai(v
,u
,r
,n
,p
,k
)] * Y2
[ai(v
,mm
,r
,n
,m
,k
)];
93 ps2
[ai(u
,mm
,r
,p
,m
,k
)] = dotProduct
;
96 for (int j
=0; j
<p
; j
++)
98 for (int s
=0; s
<p
; s
++)
100 //Gram2[j,s,r] = crossprod(X2[,j,r], X2[,s,r])
101 Real dotProduct
= 0.;
102 for (int u
=0; u
<n
; u
++)
103 dotProduct
+= X2
[ai(u
,j
,r
,n
,p
,k
)] * X2
[ai(u
,s
,r
,n
,p
,k
)];
104 Gram2
[ai(j
,s
,r
,p
,p
,k
)] = dotProduct
;
114 for (int r
=0; r
<k
; r
++)
116 //b[r] = sum(abs(phi[,,r]))
118 for (int u
=0; u
<p
; u
++)
119 for (int v
=0; v
<m
; v
++)
120 sumAbsPhi
+= fabs(phi
[ai(u
,v
,r
,p
,m
,k
)]);
123 //gam2 = colSums(gam)
124 for (int u
=0; u
<k
; u
++)
126 Real sumOnColumn
= 0.;
127 for (int v
=0; v
<n
; v
++)
128 sumOnColumn
+= gam
[mi(v
,u
,n
,k
)];
129 gam2
[u
] = sumOnColumn
;
131 //a = sum(gam %*% log(pi))
133 for (int u
=0; u
<n
; u
++)
135 Real dotProduct
= 0.;
136 for (int v
=0; v
<k
; v
++)
137 dotProduct
+= gam
[mi(u
,v
,n
,k
)] * log(pi
[v
]);
141 //tant que les proportions sont negatives
145 while (!pi2AllPositive
)
147 //pi2 = pi + 0.1^kk * ((1/n)*gam2 - pi)
148 Real pow_01_kk
= pow(0.1,kk
);
149 for (int r
=0; r
<k
; r
++)
150 pi2
[r
] = pi
[r
] + pow_01_kk
* (invN
*gam2
[r
] - pi
[r
]);
151 //pi2AllPositive = all(pi2 >= 0)
153 for (int r
=0; r
<k
; r
++)
165 Real piPowGammaDotB
= 0.;
166 for (int v
=0; v
<k
; v
++)
167 piPowGammaDotB
+= pow(pi
[v
],gamma
) * b
[v
];
169 Real pi2PowGammaDotB
= 0.;
170 for (int v
=0; v
<k
; v
++)
171 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
172 //sum(gam2 * log(pi2))
173 Real gam2DotLogPi2
= 0.;
174 for (int v
=0; v
<k
; v
++)
175 gam2DotLogPi2
+= gam2
[v
] * log(pi2
[v
]);
177 //t(m) la plus grande valeur dans la grille O.1^k tel que ce soit décroissante ou constante
178 while (-invN
*a
+ lambda
*piPowGammaDotB
< -invN
*gam2DotLogPi2
+ lambda
*pi2PowGammaDotB
181 Real pow_01_kk
= pow(0.1,kk
);
182 //pi2 = pi + 0.1^kk * (1/n*gam2 - pi)
183 for (int v
=0; v
<k
; v
++)
184 pi2
[v
] = pi
[v
] + pow_01_kk
* (invN
*gam2
[v
] - pi
[v
]);
185 //pi2 was updated, so we recompute pi2PowGammaDotB and gam2DotLogPi2
186 pi2PowGammaDotB
= 0.;
187 for (int v
=0; v
<k
; v
++)
188 pi2PowGammaDotB
+= pow(pi2
[v
],gamma
) * b
[v
];
190 for (int v
=0; v
<k
; v
++)
191 gam2DotLogPi2
+= gam2
[v
] * log(pi2
[v
]);
194 Real t
= pow(0.1,kk
);
195 //sum(pi + t*(pi2-pi))
196 Real sumPiPlusTbyDiff
= 0.;
197 for (int v
=0; v
<k
; v
++)
198 sumPiPlusTbyDiff
+= (pi
[v
] + t
*(pi2
[v
] - pi
[v
]));
199 //pi = (pi + t*(pi2-pi)) / sum(pi + t*(pi2-pi))
200 for (int v
=0; v
<k
; v
++)
201 pi
[v
] = (pi
[v
] + t
*(pi2
[v
] - pi
[v
])) / sumPiPlusTbyDiff
;
204 for (int r
=0; r
<k
; r
++)
206 for (int mm
=0; mm
<m
; mm
++)
210 // Compute ps, and nY2 = sum(Y2[,mm,r]^2)
211 for (int i
=0; i
<n
; i
++)
213 //< X2[i,,r] , phi[,mm,r] >
214 Real dotProduct
= 0.;
215 for (int u
=0; u
<p
; u
++)
216 dotProduct
+= X2
[ai(i
,u
,r
,n
,p
,k
)] * phi
[ai(u
,mm
,r
,p
,m
,k
)];
217 //ps = ps + Y2[i,mm,r] * sum(X2[i,,r] * phi[,mm,r])
218 ps
+= Y2
[ai(i
,mm
,r
,n
,m
,k
)] * dotProduct
;
219 nY2
+= Y2
[ai(i
,mm
,r
,n
,m
,k
)] * Y2
[ai(i
,mm
,r
,n
,m
,k
)];
221 //rho[mm,mm,r] = (ps+sqrt(ps^2+4*nY2*gam2[r])) / (2*nY2)
222 rho
[ai(mm
,mm
,r
,m
,m
,k
)] = (ps
+ sqrt(ps
*ps
+ 4*nY2
* gam2
[r
])) / (2*nY2
);
226 for (int r
=0; r
<k
; r
++)
228 for (int j
=0; j
<p
; j
++)
230 for (int mm
=0; mm
<m
; mm
++)
232 //sum(phi[-j,mm,r] * Gram2[j,-j,r])
233 Real phiDotGram2
= 0.;
234 for (int u
=0; u
<p
; u
++)
237 phiDotGram2
+= phi
[ai(u
,mm
,r
,p
,m
,k
)] * Gram2
[ai(j
,u
,r
,p
,p
,k
)];
239 //S[j,mm,r] = -rho[mm,mm,r]*ps2[j,mm,r] + sum(phi[-j,mm,r] * Gram2[j,-j,r])
240 S
[ai(j
,mm
,r
,p
,m
,k
)] = -rho
[ai(mm
,mm
,r
,m
,m
,k
)] * ps2
[ai(j
,mm
,r
,p
,m
,k
)]
242 Real pirPowGamma
= pow(pi
[r
],gamma
);
243 if (fabs(S
[ai(j
,mm
,r
,p
,m
,k
)]) <= n
*lambda
*pirPowGamma
)
244 phi
[ai(j
,mm
,r
,p
,m
,k
)] = 0.;
245 else if (S
[ai(j
,mm
,r
,p
,m
,k
)] > n
*lambda
*pirPowGamma
)
247 phi
[ai(j
,mm
,r
,p
,m
,k
)] = (n
*lambda
*pirPowGamma
- S
[ai(j
,mm
,r
,p
,m
,k
)])
248 / Gram2
[ai(j
,j
,r
,p
,p
,k
)];
252 phi
[ai(j
,mm
,r
,p
,m
,k
)] = -(n
*lambda
*pirPowGamma
+ S
[ai(j
,mm
,r
,p
,m
,k
)])
253 / Gram2
[ai(j
,j
,r
,p
,p
,k
)];
263 // Precompute det(rho[,,r]) for r in 1...k
265 for (int r
=0; r
<k
; r
++)
267 for (int u
=0; u
<m
; u
++)
269 for (int v
=0; v
<m
; v
++)
270 matrix
->data
[u
*m
+v
] = rho
[ai(u
,v
,r
,m
,m
,k
)];
272 gsl_linalg_LU_decomp(matrix
, permutation
, &signum
);
273 detRho
[r
] = gsl_linalg_LU_det(matrix
, signum
);
277 for (int i
=0; i
<n
; i
++)
279 for (int r
=0; r
<k
; r
++)
281 //compute Y[i,]%*%rho[,,r]
282 for (int u
=0; u
<m
; u
++)
285 for (int v
=0; v
<m
; v
++)
286 YiRhoR
[u
] += Y
[mi(i
,v
,n
,m
)] * rho
[ai(v
,u
,r
,m
,m
,k
)];
289 //compute X[i,]%*%phi[,,r]
290 for (int u
=0; u
<m
; u
++)
293 for (int v
=0; v
<p
; v
++)
294 XiPhiR
[u
] += X
[mi(i
,v
,n
,p
)] * phi
[ai(v
,u
,r
,p
,m
,k
)];
297 //compute sq norm || Y(:,i)*rho(:,:,r)-X(i,:)*phi(:,:,r) ||_2^2
299 for (int u
=0; u
<m
; u
++)
300 sqNorm2
[r
] += (YiRhoR
[u
]-XiPhiR
[u
]) * (YiRhoR
[u
]-XiPhiR
[u
]);
304 for (int r
=0; r
<k
; r
++)
306 gam
[mi(i
,r
,n
,k
)] = pi
[r
] * exp(-.5*sqNorm2
[r
]) * detRho
[r
];
307 sumGamI
+= gam
[mi(i
,r
,n
,k
)];
310 sumLogLLH
+= log(sumGamI
) - log(gaussConstM
);
311 if (sumGamI
> EPS
) //else: gam[i,] is already ~=0
313 for (int r
=0; r
<k
; r
++)
314 gam
[mi(i
,r
,n
,k
)] /= sumGamI
;
318 //sumPen = sum(pi^gamma * b)
320 for (int r
=0; r
<k
; r
++)
321 sumPen
+= pow(pi
[r
],gamma
) * b
[r
];
322 Real last_llh
= *llh
;
323 //llh = -sumLogLLH/n + lambda*sumPen
324 *llh
= -invN
* sumLogLLH
+ lambda
* sumPen
;
325 Real dist
= ite
==0 ? *llh
: (*llh
- last_llh
) / (1. + fabs(*llh
));
327 //Dist1 = max( abs(phi-Phi) / (1+abs(phi)) )
329 for (int u
=0; u
<p
; u
++)
331 for (int v
=0; v
<m
; v
++)
333 for (int w
=0; w
<k
; w
++)
335 Real tmpDist
= fabs(phi
[ai(u
,v
,w
,p
,m
,k
)]-Phi
[ai(u
,v
,w
,p
,m
,k
)])
336 / (1.+fabs(phi
[ai(u
,v
,w
,p
,m
,k
)]));
342 //Dist2 = max( (abs(rho-Rho)) / (1+abs(rho)) )
344 for (int u
=0; u
<m
; u
++)
346 for (int v
=0; v
<m
; v
++)
348 for (int w
=0; w
<k
; w
++)
350 Real tmpDist
= fabs(rho
[ai(u
,v
,w
,m
,m
,k
)]-Rho
[ai(u
,v
,w
,m
,m
,k
)])
351 / (1.+fabs(rho
[ai(u
,v
,w
,m
,m
,k
)]));
357 //Dist3 = max( (abs(pi-Pi)) / (1+abs(Pi)))
359 for (int u
=0; u
<n
; u
++)
361 for (int v
=0; v
<k
; v
++)
363 Real tmpDist
= fabs(pi
[v
]-Pi
[v
]) / (1.+fabs(pi
[v
]));
368 //dist2=max([max(Dist1),max(Dist2),max(Dist3)]);
375 if (ite
>= mini
&& (dist
>= tau
|| dist2
>= sqrt(tau
)))
379 //affec = apply(gam, 1, which.max)
380 for (int i
=0; i
<n
; i
++)
384 for (int j
=0; j
<k
; j
++)
386 if (gam
[mi(i
,j
,n
,k
)] > rowMax
)
388 affec
[i
] = j
+1; //R indices start at 1
389 rowMax
= gam
[mi(i
,j
,n
,k
)];
402 gsl_matrix_free(matrix
);
403 gsl_permutation_free(permutation
);