Commit | Line | Data |
---|---|---|
4fed76cc BA |
#' EMGLLF
#'
#' Entry point of the EMGLLF procedure: forwards its arguments either to the
#' compiled routine registered as "EMGLLF" in package valse, or to the pure-R
#' implementation \code{.EMGLLF_R}.
#'
#' @param phiInit Initial value of the renormalized mean parameter
#' @param rhoInit Initial value of the renormalized variance parameter
#' @param piInit Initial value of the proportions
#' @param gamInit Initial value of the posterior probabilities of each sample
#' @param mini Minimal number of iterations of the EM algorithm
#' @param maxi Maximal number of iterations of the EM algorithm
#' @param gamma Power of the proportions in the penalty (adaptive Lasso)
#' @param lambda Value of the Lasso regularization parameter
#' @param X Regressors
#' @param Y Response
#' @param tau Threshold used to accept convergence
#' @param fast If TRUE (default) the compiled routine is invoked through
#'   \code{.Call}; otherwise the pure-R implementation is used
#'
#' @return A list ... phi,rho,pi,LLF,S,affec:
#'   phi : renormalized mean parameter, computed by the EM
#'   rho : renormalized variance parameter, computed by the EM
#'   pi : renormalized proportions parameter, computed by the EM
#'   LLF : log-likelihood associated with this sample, for the estimated
#'     parameter values (the pure-R implementation names this element "llh")
#'   S : ... affec : ...
#'
#' @export
EMGLLF <- function(phiInit, rhoInit, piInit, gamInit,
  mini, maxi, gamma, lambda, X, Y, tau, fast=TRUE)
{
  if (!fast)
  {
    # Pure-R implementation
    result <- .EMGLLF_R(phiInit, rhoInit, piInit, gamInit,
      mini, maxi, gamma, lambda, X, Y, tau)
  }
  else
  {
    # Compiled implementation: output buffers are sized from the problem
    # dimensions and passed after the inputs, followed by the dimensions.
    nbSamples <- nrow(X)           # number of samples
    nbCovariates <- ncol(X)        # number of covariates
    respSize <- ncol(Y)            # size of Y (multivariate response)
    nbComponents <- length(piInit) # number of mixture components
    result <- .Call("EMGLLF",
      phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, X, Y, tau,
      phi=double(nbCovariates*respSize*nbComponents),
      rho=double(respSize*respSize*nbComponents),
      pi=double(nbComponents),
      LLF=double(maxi),
      S=double(nbCovariates*respSize*nbComponents),
      affec=integer(nbSamples),
      nbSamples, nbCovariates, respSize, nbComponents,
      PACKAGE="valse")
  }
  result
}
aa480ac1 BA |
46 | |
# R version - slow but easy to read
#
# Penalized EM iterations for a mixture of k regression components.
#
# Arguments (same meaning as in EMGLLF):
#   phiInit  p x m x k array, initial renormalized mean parameter
#   rhoInit  m x m x k array, initial renormalized variance parameter
#            (only its diagonal entries are updated here)
#   piInit   length-k vector of initial proportions
#   gamInit  n x k matrix of initial posterior probabilities
#   mini     minimal number of iterations
#   maxi     maximal number of iterations (no iteration is run when maxi is 0)
#   gamma    power of the proportions in the penalty (adaptive Lasso)
#   lambda   Lasso regularization parameter
#   X        n x p matrix of regressors
#   Y        n x m matrix of responses
#   tau      convergence threshold: once at least mini iterations are done,
#            the loop stops as soon as dist < tau and dist2 < sqrt(tau)
#
# Returns a list with elements
#   phi, rho, pi : last parameter estimates
#   llh          : last value of -(log-likelihood)/n + lambda * penalty
#   S            : p x m x k array of the quantities thresholded when
#                  updating phi
#   affec        : for each sample, index of the component with the largest
#                  posterior probability
.EMGLLF_R <- function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau)
{
  # Matrix dimensions
  n <- dim(X)[1]
  p <- dim(phiInit)[1]
  m <- dim(phiInit)[2]
  k <- dim(phiInit)[3]

  # Outputs
  phi <- phiInit
  rho <- rhoInit
  pi <- piInit
  llh <- -Inf
  S <- array(0, dim=c(p,m,k))

  # Algorithm variables
  gam <- gamInit
  Gram2 <- array(0, dim=c(p,p,k)) # weighted Gram matrices t(X2) %*% X2
  ps2 <- array(0, dim=c(p,m,k))   # weighted cross-products t(X2) %*% Y2
  nY2 <- matrix(0, nrow=m, ncol=k) # squared norms of the weighted responses
  EPS <- 1e-15

  for (ite in seq_len(maxi))
  {
    # Remember last pi,rho,phi values for exit condition in the end of loop
    Phi <- phi
    Rho <- rho
    Pi <- pi

    # Quantities depending on X, Y and the current posterior weights
    for (r in seq_len(k))
    {
      w <- sqrt(gam[,r])
      X2 <- w * X # row i of X scaled by sqrt(gam[i,r])
      Y2 <- w * Y # row i of Y scaled by sqrt(gam[i,r])
      ps2[,,r] <- crossprod(X2, Y2)
      Gram2[,,r] <- crossprod(X2)
      nY2[,r] <- colSums(Y2^2)
    }

    ##########
    # M step #
    ##########

    # For pi
    b <- vapply(seq_len(k), function(r) sum(abs(phi[,,r])), numeric(1))
    gam2 <- colSums(gam)
    a <- sum(gam2 * log(pi))

    # Shrink the step until all proportions are non-negative
    kk <- 0
    repeat
    {
      pi2 <- pi + 0.1^kk * ((1/n)*gam2 - pi)
      kk <- kk + 1
      if (all(pi2 >= 0))
        break
    }

    # t(m): largest value in the grid 0.1^k such that the objective is
    # decreasing or constant
    objOld <- -a/n + lambda * sum(pi^gamma * b)
    while (kk < 1000 &&
      objOld < -sum(gam2 * log(pi2))/n + lambda * sum(pi2^gamma * b))
    {
      pi2 <- pi + 0.1^kk * (1/n*gam2 - pi)
      kk <- kk + 1
    }
    t <- 0.1^kk
    piStep <- pi + t*(pi2-pi)
    pi <- piStep / sum(piStep)

    # For rho (diagonal entries only)
    for (r in seq_len(k))
    {
      for (mm in seq_len(m))
      {
        # sum_i Y2[i,mm] * <X2[i,], phi[,mm,r]> == <phi[,mm,r], ps2[,mm,r]>
        ps <- sum(phi[,mm,r] * ps2[,mm,r])
        rho[mm,mm,r] <- (ps + sqrt(ps^2 + 4*nY2[mm,r]*gam2[r])) / (2*nY2[mm,r])
      }
    }

    # For phi: coordinate-wise soft-thresholding; updates are sequential, each
    # coordinate uses the values already refreshed during this sweep
    for (r in seq_len(k))
    {
      thresh <- n*lambda*(pi[r]^gamma)
      for (j in seq_len(p))
      {
        for (mm in seq_len(m))
        {
          s <- -rho[mm,mm,r]*ps2[j,mm,r] + sum(phi[-j,mm,r] * Gram2[j,-j,r])
          S[j,mm,r] <- s
          if (abs(s) <= thresh)
            phi[j,mm,r] <- 0
          else if (s > thresh)
            phi[j,mm,r] <- (thresh - s) / Gram2[j,j,r]
          else
            phi[j,mm,r] <- -(thresh + s) / Gram2[j,j,r]
        }
      }
    }

    ##########
    # E step #
    ##########

    # Unnormalized posterior weights. Slices are rebuilt with explicit
    # dimensions so that m == 1 or p == 1 do not drop to plain vectors
    # (det() requires a matrix).
    for (r in seq_len(k))
    {
      rhoR <- matrix(rho[,,r], nrow=m, ncol=m)
      phiR <- matrix(phi[,,r], nrow=p, ncol=m)
      resid <- Y %*% rhoR - X %*% phiR
      gam[,r] <- pi[r] * exp(-0.5 * rowSums(resid^2)) * det(rhoR)
    }
    sumGam <- rowSums(gam)
    sumLogLLH <- sum(log(sumGam) - log((2*base::pi)^(m/2)))
    # Normalize rows; rows whose total is ~0 are left untouched (already ~= 0)
    ok <- sumGam > EPS
    gam[ok,] <- gam[ok,,drop=FALSE] / sumGam[ok]

    sumPen <- sum(pi^gamma * b)
    last_llh <- llh
    llh <- -sumLogLLH/n + lambda*sumPen
    # Relative change of the objective (signed) and of the parameters
    dist <- if (ite == 1) llh else (llh-last_llh) / (1+abs(llh))
    Dist1 <- max( (abs(phi-Phi)) / (1+abs(phi)) )
    Dist2 <- max( (abs(rho-Rho)) / (1+abs(rho)) )
    Dist3 <- max( (abs(pi-Pi)) / (1+abs(Pi)) )
    dist2 <- max(Dist1,Dist2,Dist3)

    # Stop once the minimal number of iterations is reached AND both changes
    # are below their thresholds (the previous test was inverted and left the
    # loop while the iterates were still moving).
    if (ite >= mini && dist < tau && dist2 < sqrt(tau))
      break
  }

  affec <- apply(gam, 1, which.max)
  list( "phi"=phi, "rho"=rho, "pi"=pi, "llh"=llh, "S"=S, "affec"=affec )
}