bcbfd3c23c66d9dbca2531e33ad4a69605486a5b
2 #include "constructionModelesLassoMLE.h"
3 #include <gsl/gsl_linalg.h>
5 #include "omp_num_threads.h"
7 // TODO: comment on constructionModelesLassoMLE purpose
8 void constructionModelesLassoMLE(
10 const Real
* phiInit
, // parametre initial de moyenne renormalisé
11 const Real
* rhoInit
, // parametre initial de variance renormalisé
12 const Real
* piInit
, // parametre initial des proportions
13 const Real
* gamInit
, // paramètre initial des probabilités a posteriori de chaque échantillon
14 Int mini
, // nombre minimal d'itérations dans l'algorithme EM
15 Int maxi
, // nombre maximal d'itérations dans l'algorithme EM
16 Real gamma
, // valeur de gamma : puissance des proportions dans la pénalisation pour un Lasso adaptatif
17 const Real
* glambda
, // valeur des paramètres de régularisation du Lasso
18 const Real
* X
, // régresseurs
19 const Real
* Y
, // réponse
20 Real seuil
, // seuil pour prendre en compte une variable
21 Real tau
, // seuil pour accepter la convergence
22 const Int
* A1
, // matrice des coefficients des parametres selectionnes
23 const Int
* A2
, // matrice des coefficients des parametres non selectionnes
25 Real
* phi
, // estimateur ainsi calculé par le Lasso
26 Real
* rho
, // estimateur ainsi calculé par le Lasso
27 Real
* pi
, // estimateur ainsi calculé par le Lasso
28 Real
* lvraisemblance
, // estimateur ainsi calculé par le Lasso
29 // additional size parameters
30 mwSize n
, // taille de l'echantillon
31 mwSize p
, // nombre de covariables
32 mwSize m
, // taille de Y (multivarié)
33 mwSize k
, // nombre de composantes
34 mwSize L
) // taille de glambda
36 //preparation: phi = 0
37 for (mwSize u
=0; u
<p
*m
*k
*L
; u
++)
40 //initiate parallel section
42 omp_set_num_threads(OMP_NUM_THREADS
);
43 #pragma omp parallel default(shared) private(lambdaIndex)
45 #pragma omp for schedule(dynamic,CHUNK_SIZE) nowait
46 for (lambdaIndex
=0; lambdaIndex
<L
; lambdaIndex
++)
48 //~ a = A1(:,1,lambdaIndex);
50 Int
* a
= (Int
*)malloc(p
*sizeof(Int
));
52 for (mwSize j
=0; j
<p
; j
++)
54 if (A1
[j
*(m
+1)*L
+0*L
+lambdaIndex
] != 0)
55 a
[lengthA
++] = A1
[j
*(m
+1)*L
+0*L
+lambdaIndex
] - 1;
61 Real
* Xa
= (Real
*)malloc(n
*lengthA
*sizeof(Real
));
62 for (mwSize i
=0; i
<n
; i
++)
64 for (mwSize j
=0; j
<lengthA
; j
++)
65 Xa
[i
*lengthA
+j
] = X
[i
*p
+a
[j
]];
68 //phia = phiInit(a,:,:)
69 Real
* phia
= (Real
*)malloc(lengthA
*m
*k
*sizeof(Real
));
70 for (mwSize j
=0; j
<lengthA
; j
++)
72 for (mwSize mm
=0; mm
<m
; mm
++)
74 for (mwSize r
=0; r
<k
; r
++)
75 phia
[j
*m
*k
+mm
*k
+r
] = phiInit
[a
[j
]*m
*k
+mm
*k
+r
];
79 //[phiLambda,rhoLambda,piLambda,~,~] = EMGLLF(...
80 // phiInit(a,:,:),rhoInit,piInit,gamInit,mini,maxi,gamma,0,X(:,a),Y,tau);
81 Real
* phiLambda
= (Real
*)malloc(lengthA
*m
*k
*sizeof(Real
));
82 Real
* rhoLambda
= (Real
*)malloc(m
*m
*k
*sizeof(Real
));
83 Real
* piLambda
= (Real
*)malloc(k
*sizeof(Real
));
84 Real
* LLF
= (Real
*)malloc((maxi
+1)*sizeof(Real
));
85 Real
* S
= (Real
*)malloc(lengthA
*m
*k
*sizeof(Real
));
86 EMGLLF(phia
,rhoInit
,piInit
,gamInit
,mini
,maxi
,gamma
,0.0,Xa
,Y
,tau
,
87 phiLambda
,rhoLambda
,piLambda
,LLF
,S
,
95 //~ phi(a(j),:,:,lambdaIndex) = phiLambda(j,:,:);
97 for (mwSize j
=0; j
<lengthA
; j
++)
99 for (mwSize mm
=0; mm
<m
; mm
++)
101 for (mwSize r
=0; r
<k
; r
++)
102 phi
[a
[j
]*m
*k
*L
+mm
*k
*L
+r
*L
+lambdaIndex
] = phiLambda
[j
*m
*k
+mm
*k
+r
];
106 //~ rho(:,:,:,lambdaIndex) = rhoLambda;
107 for (mwSize u
=0; u
<m
; u
++)
109 for (mwSize v
=0; v
<m
; v
++)
111 for (mwSize r
=0; r
<k
; r
++)
112 rho
[u
*m
*k
*L
+v
*k
*L
+r
*L
+lambdaIndex
] = rhoLambda
[u
*m
*k
+v
*k
+r
];
116 //~ pi(:,lambdaIndex) = piLambda;
117 for (mwSize r
=0; r
<k
; r
++)
118 pi
[r
*L
+lambdaIndex
] = piLambda
[r
];
121 mwSize dimension
= 0;
122 Int
* b
= (Int
*)malloc(m
*sizeof(Int
));
123 for (mwSize j
=0; j
<p
; j
++)
125 //~ b = A2(j,2:end,lambdaIndex);
128 for (mwSize mm
=0; mm
<m
; mm
++)
130 if (A2
[j
*(m
+1)*L
+(mm
+1)*L
+lambdaIndex
] != 0)
131 b
[lengthB
++] = A2
[j
*(m
+1)*L
+(mm
+1)*L
+lambdaIndex
] - 1;
134 //~ phi(A2(j,1,lambdaIndex),b,:,lambdaIndex) = 0.0;
138 for (mwSize mm
=0; mm
<lengthB
; mm
++)
140 for (mwSize r
=0; r
<k
; r
++)
141 phi
[(A2
[j
*(m
+1)*L
+0*L
+lambdaIndex
]-1)*m
*k
*L
+ b
[mm
]*k
*L
+ r
*L
+ lambdaIndex
] = 0.0;
145 //~ c = A1(j,2:end,lambdaIndex);
147 //~ dimension = dimension + length(c);
148 for (mwSize mm
=0; mm
<m
; mm
++)
150 if (A1
[j
*(m
+1)*L
+(mm
+1)*L
+lambdaIndex
] != 0)
157 Real
* densite
= (Real
*)calloc(L
*n
,sizeof(Real
));
158 Real sumLogDensit
= 0.0;
159 gsl_matrix
* matrix
= gsl_matrix_alloc(m
, m
);
160 gsl_permutation
* permutation
= gsl_permutation_alloc(m
);
161 Real
* YiRhoR
= (Real
*)malloc(m
*sizeof(Real
));
162 Real
* XiPhiR
= (Real
*)malloc(m
*sizeof(Real
));
163 for (mwSize i
=0; i
<n
; i
++)
166 //~ delta = Y(i,:)*rho(:,:,r,lambdaIndex) - (X(i,a)*(phi(a,:,r,lambdaIndex)));
167 //~ densite(i,lambdaIndex) = densite(i,lambdaIndex) +...
168 //~ pi(r,lambdaIndex)*det(rho(:,:,r,lambdaIndex))/(sqrt(2*PI))^m*exp(-dot(delta,delta)/2.0);
170 for (mwSize r
=0; r
<k
; r
++)
172 //compute det(rho(:,:,r,lambdaIndex)) [TODO: avoid re-computations]
173 for (mwSize u
=0; u
<m
; u
++)
175 for (mwSize v
=0; v
<m
; v
++)
176 matrix
->data
[u
*m
+v
] = rho
[u
*m
*k
*L
+v
*k
*L
+r
*L
+lambdaIndex
];
178 gsl_linalg_LU_decomp(matrix
, permutation
, &signum
);
179 Real detRhoR
= gsl_linalg_LU_det(matrix
, signum
);
181 //compute Y(i,:)*rho(:,:,r,lambdaIndex)
182 for (mwSize u
=0; u
<m
; u
++)
185 for (mwSize v
=0; v
<m
; v
++)
186 YiRhoR
[u
] += Y
[i
*m
+v
] * rho
[v
*m
*k
*L
+u
*k
*L
+r
*L
+lambdaIndex
];
189 //compute X(i,a)*phi(a,:,r,lambdaIndex)
190 for (mwSize u
=0; u
<m
; u
++)
193 for (mwSize v
=0; v
<lengthA
; v
++)
194 XiPhiR
[u
] += X
[i
*p
+a
[v
]] * phi
[a
[v
]*m
*k
*L
+u
*k
*L
+r
*L
+lambdaIndex
];
196 // On peut remplacer X par Xa dans ce dernier calcul, mais je ne sais pas si c'est intéressant ...
198 // compute dotProduct < delta . delta >
199 Real dotProduct
= 0.0;
200 for (mwSize u
=0; u
<m
; u
++)
201 dotProduct
+= (YiRhoR
[u
]-XiPhiR
[u
]) * (YiRhoR
[u
]-XiPhiR
[u
]);
203 densite
[lambdaIndex
*n
+i
] += (pi
[r
*L
+lambdaIndex
]*detRhoR
/pow(sqrt(2.0*M_PI
),m
))*exp(-dotProduct
/2.0);
205 sumLogDensit
+= log(densite
[lambdaIndex
*n
+i
]);
207 lvraisemblance
[lambdaIndex
*2+0] = sumLogDensit
;
208 lvraisemblance
[lambdaIndex
*2+1] = (dimension
+m
+1)*k
-1;
214 gsl_matrix_free(matrix
);
215 gsl_permutation_free(permutation
);