Commit | Line | Data |
---|---|---|
4fed76cc BA |
#' EMGLLF
#'
#' Entry point for the EMGLLF EM algorithm. Dispatches either to the compiled
#' routine "EMGLLF" of package "valse" (default) or to the pure R
#' implementation \code{.EMGLLF_R}.
#'
#' @param phiInit an initialization for phi
#' @param rhoInit an initialization for rho
#' @param piInit an initialization for pi
#' @param gamInit initialization for the a posteriori probabilities
#' @param mini integer, minimum number of iterations in the EM algorithm
#'   (no default is defined by this function)
#' @param maxi integer, maximum number of iterations in the EM algorithm
#'   (no default is defined by this function)
#' @param gamma integer for the power in the penalty
#' @param lambda regularization parameter in the Lasso estimation
#' @param X matrix of covariates (of size n*p)
#' @param Y matrix of responses (of size n*m)
#' @param eps real, threshold to say the EM algorithm converges
#' @param fast logical; when TRUE (default) the compiled routine is called,
#'   otherwise the R implementation is used
#'
#' @return With \code{fast=FALSE}, the value returned by \code{.EMGLLF_R}.
#'   Otherwise the value returned by the compiled routine, which is handed the
#'   output buffers phi, rho, pi, LLF, S, affec:
#'   phi : renormalized mean parameter, computed by the EM algorithm
#'   rho : renormalized variance parameter, computed by the EM algorithm
#'   pi : renormalized proportions parameter, computed by the EM algorithm
#'   LLF : log-likelihood associated with this sample, for the estimated
#'     parameter values
#'   S : ... affec : ...
#'
#' @export
EMGLLF <- function(phiInit, rhoInit, piInit, gamInit,
  mini, maxi, gamma, lambda, X, Y, eps, fast=TRUE)
{
  # Pure R implementation requested: delegate and stop here
  if (!fast)
    return(.EMGLLF_R(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, X, Y, eps))

  # Compiled implementation: problem sizes used to allocate the output buffers
  n <- nrow(X)          # number of samples
  p <- ncol(X)          # number of covariates
  m <- ncol(Y)          # size of Y (multivariate)
  k <- length(piInit)   # number of components in the mixture

  .Call("EMGLLF",
    phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, X, Y, eps,
    phi = double(p * m * k),
    rho = double(m * m * k),
    pi = double(k),
    LLF = double(maxi),
    S = double(p * m * k),
    affec = integer(n),
    n, p, m, k,
    PACKAGE = "valse")
}
aa480ac1 BA |
46 | |
# R version - slow but easy to read
#
# Pure R implementation of the EM algorithm used by EMGLLF(fast=FALSE).
#
# Arguments:
#   phiInit  initial phi; either a p x m x k array or an m x k matrix (then p = 1)
#   rhoInit  initial rho, indexed as rho[mm, mm, r]
#   piInit   initial mixture proportions (length k)
#   gamInit  initial a posteriori probabilities, indexed as gam[i, r]
#   mini     minimum number of iterations before the exit test may trigger
#   maxi     maximum number of iterations
#   gamma    power applied to pi in the penalty
#   lambda   regularization parameter
#   X2       covariates, copied into an n x p matrix
#   Y        responses, n x m matrix
#   eps      convergence threshold
#
# Returns list(phi, rho, pi, llh, S), where llh is the penalized negative
# log-likelihood per sample computed at the last iteration.
.EMGLLF_R <- function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X2,Y,eps)
{
  # Matrix dimensions
  n <- dim(Y)[1]
  if (length(dim(phiInit)) == 2) {
    # phiInit supplied as an m x k matrix: single covariate
    p <- 1
    m <- dim(phiInit)[1]
    k <- dim(phiInit)[2]
  } else {
    p <- dim(phiInit)[1]
    m <- dim(phiInit)[2]
    k <- dim(phiInit)[3]
  }
  # Force the covariates into an n x p matrix (also when given as a vector)
  X <- matrix(nrow = n, ncol = p)
  X[1:n, 1:p] <- X2

  # Outputs
  phi <- array(NA, dim = c(p, m, k))
  phi[1:p, , ] <- phiInit
  rho <- rhoInit
  pi <- piInit
  llh <- -Inf
  S <- array(0, dim = c(p, m, k))

  # Algorithm variables
  gam <- gamInit
  Gram2 <- array(0, dim = c(p, p, k))
  ps2 <- array(0, dim = c(p, m, k))
  Xw <- array(0, dim = c(n, p, k)) # rows of X scaled by sqrt(gam), per component
  Yw <- array(0, dim = c(n, m, k)) # rows of Y scaled by sqrt(gam), per component

  for (ite in 1:maxi)
  {
    # Remember last pi,rho,phi values for exit condition in the end of loop
    Phi <- phi
    Rho <- rho
    Pi <- pi

    # Computations associated to X and Y
    for (r in 1:k)
    {
      for (mm in 1:m)
        Yw[, mm, r] <- sqrt(gam[, r]) * Y[, mm]
      for (i in 1:n)
        Xw[i, , r] <- sqrt(gam[i, r]) * X[i, ]
      for (mm in 1:m)
        ps2[, mm, r] <- crossprod(Xw[, , r], Yw[, mm, r])
      for (j in 1:p)
      {
        for (s in 1:p)
          Gram2[j, s, r] <- crossprod(Xw[, j, r], Xw[, s, r])
      }
    }

    #########
    #M step #
    #########

    # For pi
    b <- vapply(1:k, function(r) sum(abs(phi[, , r])), numeric(1))
    gam2 <- colSums(gam)
    a <- sum(gam %*% log(pi))

    # While the proportions are nonpositive
    kk <- 0
    pi2AllPositive <- FALSE
    while (!pi2AllPositive)
    {
      pi2 <- pi + 0.1^kk * ((1 / n) * gam2 - pi)
      pi2AllPositive <- all(pi2 >= 0)
      kk <- kk + 1
    }

    # t(m) is the largest value in the grid 0.1^k such that it is nonincreasing
    while (kk < 1000 && -a / n + lambda * sum(pi^gamma * b) <
      -sum(gam2 * log(pi2)) / n + lambda * sum(pi2^gamma * b))
    {
      pi2 <- pi + 0.1^kk * (1 / n * gam2 - pi)
      kk <- kk + 1
    }
    t <- 0.1^kk
    pi <- (pi + t * (pi2 - pi)) / sum(pi + t * (pi2 - pi))

    # For phi and rho
    for (r in 1:k)
    {
      for (mm in 1:m)
      {
        ps <- 0
        for (i in 1:n)
          ps <- ps + Yw[i, mm, r] * sum(Xw[i, , r] * phi[, mm, r])
        nY2 <- sum(Yw[, mm, r]^2)
        rho[mm, mm, r] <- (ps + sqrt(ps^2 + 4 * nY2 * gam2[r])) / (2 * nY2)
      }
    }

    for (r in 1:k)
    {
      # Soft-thresholding bound; constant over j and mm for a given component
      thresh <- n * lambda * (pi[r]^gamma)
      for (j in 1:p)
      {
        for (mm in 1:m)
        {
          S[j, mm, r] <- -rho[mm, mm, r] * ps2[j, mm, r] +
            sum(phi[-j, mm, r] * Gram2[j, -j, r])
          if (abs(S[j, mm, r]) <= thresh)
            phi[j, mm, r] <- 0
          else if (S[j, mm, r] > thresh)
            phi[j, mm, r] <- (thresh - S[j, mm, r]) / Gram2[j, j, r]
          else
            phi[j, mm, r] <- -(thresh + S[j, mm, r]) / Gram2[j, j, r]
        }
      }
    }

    ########
    #E step#
    ########

    # Precompute det(rho[,,r]) for r in 1...k
    # (as.matrix keeps this valid when m == 1 and the slice drops to a scalar)
    detRho <- vapply(1:k, function(r) det(as.matrix(rho[, , r])), numeric(1))
    gam1 <- matrix(0, nrow = n, ncol = k)
    for (i in 1:n)
    {
      # Unnormalized a posteriori weights
      for (r in 1:k)
      {
        gam1[i, r] <- pi[r] *
          exp(-0.5 * sum((Y[i, ] %*% rho[, , r] - X[i, ] %*% phi[, , r])^2)) * detRho[r]
      }
    }
    # The log-likelihood must use the unnormalized weights: once the rows are
    # normalized they all sum to 1 and their log is identically 0.
    sumLogLLH <- sum(log(rowSums(gam1)) - log((2 * base::pi)^(m / 2)))
    gam <- gam1 / rowSums(gam1)
    sumPen <- sum(pi^gamma * b)
    last_llh <- llh
    llh <- -sumLogLLH / n + lambda * sumPen
    dist <- if (ite == 1) llh else (llh - last_llh) / (1 + abs(llh))
    Dist1 <- max((abs(phi - Phi)) / (1 + abs(phi)))
    Dist2 <- max((abs(rho - Rho)) / (1 + abs(rho)))
    Dist3 <- max((abs(pi - Pi)) / (1 + abs(Pi)))
    dist2 <- max(Dist1, Dist2, Dist3)

    # Stop once the minimum number of iterations is reached AND both
    # criteria are below their thresholds (i.e. the algorithm has converged)
    if (ite >= mini && dist < eps && dist2 < sqrt(eps))
      break
  }

  list("phi" = phi, "rho" = rho, "pi" = pi, "llh" = llh, "S" = S)
}