initial commit
[valse.git] / ProcLassoMLE / EMGLLF.m
... / ...
CommitLineData
function [phi,rho,pi,LLF,S] = EMGLLF(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau)
%EMGLLF  EM algorithm for a mixture of k Gaussian linear regressions with an
%   L1 (Lasso) penalty on the coefficients, weighted by pi.^gamma.
%
%   [phi,rho,pi,LLF,S] = EMGLLF(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau)
%
%   Inputs:
%     phiInit  p x m x k  initial coefficients (one p x m slice per component)
%     rhoInit  m x m x k  initial rho matrices; the M step only rewrites their diagonals
%     piInit   initial mixture proportions; assumed to be a 1 x k row vector
%              (required by gam*transpose(log(pi)) below) -- TODO confirm with callers
%     gamInit  n x k      initial responsibilities (rows are normalized in the E step)
%     mini     minimum number of iterations
%     maxi     maximum number of iterations (also the length of LLF)
%     gamma    exponent applied to pi in the penalty lambda*(pi.^gamma)*b
%     lambda   penalty weight
%     X        n x p design matrix
%     Y        n x m response matrix
%     tau      convergence threshold: after mini iterations the loop continues while
%              the relative LLF change is >= tau or the largest relative parameter
%              change is >= sqrt(tau)
%
%   Outputs:
%     phi      p x m x k  estimated coefficients
%     rho      m x m x k  estimated rho matrices
%     pi       proportions, transposed before returning (k x 1 for a row piInit)
%     LLF      maxi x 1   per iteration: -(1/n)*log-likelihood + lambda*penalty;
%              entries after the last executed iteration stay 0
%     S        p x m x k  values compared with the soft threshold in the last phi update

    % Matrix dimensions. The output "pi" shadows the built-in constant, so the
    % value of pi is recomputed here under another name.
    PI = 4.0 * atan(1.0);
    n = size(X, 1);
    [p,m,k] = size(phiInit);

    % Initialize outputs
    phi = phiInit;
    rho = rhoInit;
    pi = piInit;
    LLF = zeros(maxi,1);
    S = zeros(p,m,k);

    % Other local variables
    % NOTE: variable index order is always n,p,m,k
    gam = gamInit;
    Gram2 = zeros(p,p,k);
    ps2 = zeros(p,m,k);
    b = zeros(k,1);
    X2 = zeros(n,p,k);
    Y2 = zeros(n,m,k);
    dist = 0;
    dist2 = 0;
    ite = 1;
    ps = zeros(m,k);
    nY2 = zeros(m,k);
    ps1 = zeros(n,m,k);
    nY21 = zeros(n,m,k);
    Gam = zeros(n,k);
    detRho = zeros(k,1);
    EPS = 1e-15;

    while ite<=mini || (ite<=maxi && (dist>=tau || dist2>=sqrt(tau)))

        % Keep the previous iterate to measure parameter changes
        Phi = phi;
        Rho = rho;
        Pi = pi;

        % Computations on Y and X: rows are weighted by sqrt(gam(:,r))
        for r=1:k
            for mm=1:m
                Y2(:,mm,r) = sqrt(gam(:,r)) .* Y(:,mm);
            end
            for i=1:n
                X2(i,:,r) = X(i,:) .* sqrt(gam(i,r));
            end
            for mm=1:m
                ps2(:,mm,r) = transpose(X2(:,:,r)) * Y2(:,mm,r);
            end
            for j=1:p
                for s=1:p
                    Gram2(j,s,r) = dot(X2(:,j,r), X2(:,s,r));
                end
            end
        end

        %%%%%%%%%%
        % M step %
        %%%%%%%%%%

        % Update pi. b(r) is the L1 norm of phi(:,:,r) before this iteration's phi update.
        for r=1:k
            b(r) = sum(sum(abs(phi(:,:,r))));
        end
        gam2 = sum(gam,1);
        a = sum(gam*transpose(log(pi)));

        % Shrink the step towards gam2/n by powers of 0.1 while any proportion is negative
        kk = 0;
        pi2AllPositive = false;
        while ~pi2AllPositive
            pi2 = pi + 0.1^kk * ((1/n)*gam2 - pi);
            pi2AllPositive = true;
            for r=1:k
                if pi2(r) < 0
                    pi2AllPositive = false;
                    break;
                end
            end
            kk = kk+1;
        end

        % t(m): the largest value on the grid 0.1^kk such that the penalized
        % objective is decreasing or constant (at most 1000 grid steps).
        % NOTE(review): pi2 is built with step 0.1^(kk-1) and the update below then
        % applies a further factor t = 0.1^kk -- confirm this compounding is intended.
        while (-1/n*a+lambda*((pi.^gamma)*b))<(-1/n*gam2*transpose(log(pi2))+lambda.*(pi2.^gamma)*b) && kk<1000
            pi2 = pi+0.1^kk*(1/n*gam2-pi);
            kk = kk+1;
        end
        t = 0.1^(kk);
        pi = (pi+t*(pi2-pi)) / sum(pi+t*(pi2-pi));

        % Update the diagonal of rho
        for r=1:k
            for mm=1:m
                for i=1:n
                    ps1(i,mm,r) = Y2(i,mm,r) * dot(X2(i,:,r), phi(:,mm,r));
                    nY21(i,mm,r) = (Y2(i,mm,r))^2;
                end
                ps(mm,r) = sum(ps1(:,mm,r));
                nY2(mm,r) = sum(nY21(:,mm,r));
                % Larger root of nY2*x^2 - ps*x - gam2(r) = 0
                rho(mm,mm,r) = ((ps(mm,r)+sqrt(ps(mm,r)^2+4*nY2(mm,r)*(gam2(r))))/(2*nY2(mm,r)));
            end
        end

        % Update phi coordinate by coordinate with soft thresholding; entries
        % updated earlier in this sweep are used immediately.
        for r=1:k
            for j=1:p
                for mm=1:m
                    % S excludes the contribution of coefficient j itself
                    S(j,mm,r) = -rho(mm,mm,r)*ps2(j,mm,r) + dot(phi(1:j-1,mm,r),Gram2(j,1:j-1,r)')...
                        + dot(phi(j+1:p,mm,r),Gram2(j,j+1:p,r)');
                    threshold = n*lambda*(pi(r)^gamma);
                    if abs(S(j,mm,r)) <= threshold
                        phi(j,mm,r) = 0;
                    elseif S(j,mm,r) > threshold
                        phi(j,mm,r) = (threshold-S(j,mm,r))/Gram2(j,j,r);
                    else
                        phi(j,mm,r) = -(threshold+S(j,mm,r))/Gram2(j,j,r);
                    end
                end
            end
        end

        %%%%%%%%%%
        % E step %
        %%%%%%%%%%

        % det(rho(:,:,r)) does not depend on i: compute it once per component
        for r=1:k
            detRho(r) = det(rho(:,:,r));
        end

        sumLogLLF2 = 0.0;
        for i=1:n
            % Squared residual norms ||Y(i,:)*rho_r - X(i,:)*phi_r||^2
            dotProducts = zeros(k,1);
            for r=1:k
                residual = Y(i,:)*rho(:,:,r) - X(i,:)*phi(:,:,r);
                dotProducts(r) = residual * transpose(residual);
            end
            % Adding shift keeps every exponent <= 0 with the largest equal to 0,
            % so the weights do not all underflow to zero.
            shift = 0.5*min(dotProducts);

            % Gam(i,:) holds the shifted (unnormalized) component weights
            sumLLF1 = 0.0;
            for r=1:k
                Gam(i,r) = pi(r)*detRho(r)*exp(-0.5*dotProducts(r) + shift);
                sumLLF1 = sumLLF1 + Gam(i,r)/(2*PI)^(m/2);
            end
            % sumLLF1 carries a factor exp(shift); remove it from the log-likelihood
            sumLogLLF2 = sumLogLLF2 + log(sumLLF1) - shift;

            % Normalize into responsibilities (zero row if all weights are negligible)
            sumGamI = sum(Gam(i,:));
            if sumGamI > EPS
                gam(i,:) = Gam(i,:) / sumGamI;
            else
                gam(i,:) = zeros(1,k);
            end
        end

        % Penalty: updated pi with b (L1 norms of phi before this iteration's update)
        sumPen = 0.0;
        for r=1:k
            sumPen = sumPen + pi(r).^gamma .* b(r);
        end
        LLF(ite) = -(1/n)*sumLogLLF2 + lambda*sumPen;

        % Relative change of LLF.
        % NOTE(review): dist is signed, so for tau > 0 a decreasing LLF never
        % satisfies dist>=tau and the loop relies on dist2 -- confirm whether
        % abs() was intended.
        if ite == 1
            dist = LLF(ite);
        else
            dist = (LLF(ite)-LLF(ite-1))/(1+abs(LLF(ite)));
        end

        % Largest relative change of the parameters
        Dist1 = max(max(max((abs(phi-Phi))./(1+abs(phi)))));
        Dist2 = max(max(max((abs(rho-Rho))./(1+abs(rho)))));
        Dist3 = max(max((abs(pi-Pi))./(1+abs(Pi))));
        dist2 = max([Dist1,Dist2,Dist3]);

        ite = ite+1;
    end

    pi = transpose(pi);

end