# Source exported from a git blame view (columns: Commit | Line | Data).
#' EM algorithm for a Lasso-type penalised mixture of Gaussian regressions.
#'
#' Component r has coefficients phi[, , r] (p x m), a matrix rho[, , r]
#' (m x m; only its diagonal is updated here) and a proportion pi[r]. The
#' E step weights observation i for component r proportionally to
#'   pi[r] * det(rho[, , r]) *
#'     exp(-0.5 * ||Y[i, ] %*% rho[, , r] - X[i, ] %*% phi[, , r]||^2).
#'
#' @param phiInit Initial coefficients, p x m x k array (lower-dimensional
#'   input is reshaped to that size).
#' @param rhoInit Initial rho, m x m x k array (reshaped likewise).
#' @param piInit Initial proportions; k is taken from its length.
#' @param gamInit Initial posterior weights, n x k matrix.
#' @param mini Minimum number of iterations: the stopping rule is only
#'   checked from iteration `mini` on.
#' @param maxi Maximum number of iterations.
#' @param gamma Exponent applied to the proportions in the penalty.
#' @param lambda Penalty weight.
#' @param X Covariates, n x p (coerced with as.matrix()).
#' @param Y Responses, n x m (coerced with as.matrix(); a vector means m = 1).
#' @param tau Tolerance: stop when the relative objective change is below
#'   tau and every relative parameter change is below sqrt(tau).
#' @return List with phi, rho, pi, llh (penalised negative mean
#'   log-likelihood of the last iteration; its penalty uses the coefficients
#'   from the start of that iteration), S (last coordinate-descent
#'   statistics) and affec (index of the most probable component of each
#'   observation).
EMGLLF_R <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma,
                     lambda, X, Y, tau) {
  # Dimensions
  X <- as.matrix(X)
  Y <- as.matrix(Y)
  n <- nrow(X)
  p <- ncol(X)
  m <- ncol(Y)
  k <- length(piInit)

  # Outputs. Inputs that are already 3-D are kept as-is (dimnames included);
  # lower-dimensional ones (p == 1, m == 1 or k == 1) are reshaped.
  phi <- phiInit
  if (length(dim(phi)) != 3L) phi <- array(phi, dim = c(p, m, k))
  rho <- rhoInit
  if (length(dim(rho)) != 3L) rho <- array(rho, dim = c(m, m, k))
  prop <- piInit # not called `pi`, so base::pi is not masked
  llh <- -Inf
  S <- array(0, dim = c(p, m, k))

  # Algorithm variables
  gam <- gamInit
  if (!is.matrix(gam)) gam <- matrix(gam, nrow = n, ncol = k)
  Xw <- vector("list", k) # X with row i scaled by sqrt(gam[i, r])
  Yw <- vector("list", k) # Y with row i scaled by sqrt(gam[i, r])
  ps2 <- array(0, dim = c(p, m, k))
  Gram2 <- array(0, dim = c(p, p, k))
  log_norm <- (m / 2) * log(2 * base::pi)

  for (ite in seq_len(maxi)) {
    # Previous values, for the stopping rule at the end of the iteration
    phi_old <- phi
    rho_old <- rho
    prop_old <- prop

    # Weighted data and its cross-products, per component
    for (r in seq_len(k)) {
      w <- sqrt(gam[, r])
      Xw[[r]] <- w * X # w recycles down the columns: row i times w[i]
      Yw[[r]] <- w * Y
      ps2[, , r] <- crossprod(Xw[[r]], Yw[[r]])
      Gram2[, , r] <- crossprod(Xw[[r]])
    }

    # ---- M step: proportions ----
    b <- vapply(seq_len(k), function(r) sum(abs(phi[, , r])), numeric(1))
    gam2 <- colSums(gam)
    # Penalised objective in the proportions. na.rm = TRUE applies the
    # 0 * log(0) = 0 convention, so a zero-weight component may receive a
    # zero candidate proportion without turning the comparison into NA.
    objective <- function(q) {
      -sum(gam2 * log(q), na.rm = TRUE) / n + lambda * sum(q^gamma * b)
    }

    # Shrink the step 0.1^kk until every candidate proportion is >= 0
    kk <- 0
    repeat {
      pi2 <- prop + 0.1^kk * (gam2 / n - prop)
      kk <- kk + 1
      if (all(pi2 >= 0)) break
    }

    # Keep shrinking while the candidate increases the objective
    current <- objective(prop)
    while (kk < 1000 && current < objective(pi2)) {
      pi2 <- prop + 0.1^kk * (gam2 / n - prop)
      kk <- kk + 1
    }
    pi_step <- prop + 0.1^kk * (pi2 - prop)
    prop <- pi_step / sum(pi_step)

    # ---- M step: diagonal of rho ----
    for (r in seq_len(k)) {
      fitted <- Xw[[r]] %*% matrix(phi[, , r], p, m)
      ps <- colSums(Yw[[r]] * fitted)
      nY2 <- colSums(Yw[[r]]^2)
      # A response column with no weight (nY2 == 0) would give 0 / 0 = NaN;
      # its rho entry keeps the previous value instead.
      for (mm in which(nY2 > 0)) {
        rho[mm, mm, r] <- (ps[mm] + sqrt(ps[mm]^2 + 4 * nY2[mm] * gam2[r])) /
          (2 * nY2[mm])
      }
    }

    # ---- M step: phi by cyclic coordinate descent (soft-thresholding) ----
    # Sequential on purpose: each update reads the already-updated phi[-j].
    for (r in seq_len(k)) {
      thr <- n * lambda * prop[r]^gamma
      for (j in seq_len(p)) {
        for (mm in seq_len(m)) {
          s_jm <- -rho[mm, mm, r] * ps2[j, mm, r] +
            sum(phi[-j, mm, r] * Gram2[j, -j, r])
          S[j, mm, r] <- s_jm
          phi[j, mm, r] <- if (abs(s_jm) <= thr) {
            0
          } else if (s_jm > thr) {
            (thr - s_jm) / Gram2[j, j, r]
          } else {
            -(thr + s_jm) / Gram2[j, j, r]
          }
        }
      }
    }

    # ---- E step: posterior weights, computed in log space ----
    log_dens <- matrix(vapply(seq_len(k), function(r) {
      rho_r <- matrix(rho[, , r], m, m)
      resid <- Y %*% rho_r - X %*% matrix(phi[, , r], p, m)
      log(prop[r]) + log(det(rho_r)) - 0.5 * rowSums(resid^2)
    }, numeric(n)), nrow = n, ncol = k)

    # Log-sum-exp per observation: subtracting the row maximum keeps exp()
    # from underflowing to 0 and losing the weights.
    row_max <- apply(log_dens, 1, max)
    ok <- is.finite(row_max)
    dens <- exp(log_dens - row_max) # row_max recycles down the columns
    total <- rowSums(dens)
    gam[ok, ] <- dens[ok, , drop = FALSE] / total[ok]
    # No component gives these observations a positive density
    gam[!ok, ] <- 0
    sum_log_lik <- sum(ifelse(ok, row_max + log(total), row_max) - log_norm)

    # ---- Stopping rule ----
    last_llh <- llh
    llh <- -sum_log_lik / n + lambda * sum(prop^gamma * b)
    dist <- if (ite == 1) llh else (llh - last_llh) / (1 + abs(llh))
    dist2 <- max(
      abs(phi - phi_old) / (1 + abs(phi)),
      abs(rho - rho_old) / (1 + abs(rho)),
      abs(prop - prop_old) / (1 + abs(prop_old))
    )

    # Converged only when BOTH the objective and the parameters have
    # stabilised; a NaN distance counts as "not converged".
    if (ite >= mini && isTRUE(dist < tau) && isTRUE(dist2 < sqrt(tau))) {
      break
    }
  }

  affec <- apply(gam, 1, which.max)
  list(phi = phi, rho = rho, pi = prop, llh = llh, S = S, affec = affec)
}