fix memory leaks on EMGLLF, test OK for EMGrank
[valse.git] / src / test / generate_test_data / helpers / EMGLLF.m
function[phi,rho,pi,LLF,S] = EMGLLF(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau)
%EMGLLF Penalized EM algorithm for a k-component mixture of Gaussian regressions.
%
%   [phi,rho,pi,LLF,S] = EMGLLF(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau)
%   alternates two steps, starting from the given initial values:
%     - an M step, which updates the proportions pi by a line search, the
%       diagonal of rho in closed form, and phi by one cyclic
%       coordinate-descent sweep with soft-thresholding;
%     - an E step, which updates the posterior weights gam.
%   Component r has density
%     det(rho(:,:,r)) * exp(-0.5*||Y(i,:)*rho(:,:,r) - X(i,:)*phi(:,:,r)||^2) / (2*pi)^(m/2).
%
%   Inputs:
%     phiInit - p x m x k initial coefficients
%     rhoInit - m x m x k initial rho matrices (only their diagonals are updated)
%     piInit  - 1 x k initial mixture proportions (row vector)
%     gamInit - n x k initial posterior weights
%     mini    - minimum number of iterations
%     maxi    - maximum number of iterations
%     gamma   - exponent of the proportions in the penalty weight pi(r)^gamma
%     lambda  - penalty weight
%     X       - n x p covariates
%     Y       - n x m responses
%     tau     - convergence threshold
%
%   Outputs:
%     phi - p x m x k final coefficients
%     rho - m x m x k final rho matrices
%     pi  - k x 1 final proportions (column vector, unlike piInit)
%     LLF - maxi x 1 penalized objective
%             -(1/n)*loglik + lambda*sum_r pi(r)^gamma*||phi_r||_1
%           for each iteration. Entries past the last iteration stay 0.
%     S   - p x m x k quantities computed in the last coordinate-descent sweep
%
%   Iteration continues while fewer than mini iterations have run, or while
%   fewer than maxi have run and either the (signed) relative change of LLF
%   is >= tau or the largest relative parameter change is >= sqrt(tau).

    % 'pi' is an output name and shadows the builtin, so rebuild the constant
    PI = 4.0 * atan(1.0);

    % Matrix dimensions
    n = size(X, 1);
    [p,m,k] = size(phiInit);

    % Initialize outputs
    phi = phiInit;
    rho = rhoInit;
    pi = piInit;
    LLF = zeros(maxi,1);
    S = zeros(p,m,k);

    % Working variables (index order is always n,p,m,k)
    gam = gamInit;
    Gram2 = zeros(p,p,k);
    ps2 = zeros(p,m,k);
    b = zeros(k,1);
    X2 = zeros(n,p,k);
    Y2 = zeros(n,m,k);
    ps = zeros(m,k);
    nY2 = zeros(m,k);
    ps1 = zeros(n,m,k);
    nY21 = zeros(n,m,k);
    Gam = zeros(n,k);
    EPS = 1e-15;
    gaussConst = (2*PI)^(m/2);

    dist = 0;
    dist2 = 0;
    ite = 1;

    while ite<=mini || (ite<=maxi && (dist>=tau || dist2>=sqrt(tau)))

        % Parameters of the previous iteration, used by the convergence test
        Phi = phi;
        Rho = rho;
        Pi = pi;

        % Rows of X and Y weighted by sqrt(gam(:,r)), plus the associated
        % cross-products X2'*Y2 and Gram matrices X2'*X2
        for r=1:k
            sqrtGam = sqrt(gam(:,r));
            for mm=1:m
                Y2(:,mm,r) = sqrtGam .* Y(:,mm);
            end
            for i=1:n
                X2(i,:,r) = X(i,:) .* sqrtGam(i);
            end
            for mm=1:m
                ps2(:,mm,r) = transpose(X2(:,:,r)) * Y2(:,mm,r);
            end
            for j=1:p
                for s=1:p
                    Gram2(j,s,r) = dot(X2(:,j,r), X2(:,s,r));
                end
            end
        end

        %%%%%%%%%%
        % M step %
        %%%%%%%%%%

        % Proportions pi. b(r) = ||phi(:,:,r)||_1, taken before this
        % iteration's phi update.
        for r=1:k
            b(r) = sum(sum(abs(phi(:,:,r))));
        end
        gam2 = sum(gam,1);
        a = sum(gam*transpose(log(pi)));

        % Shrink the step 0.1^kk until no candidate proportion is negative
        % (NaN entries do not count as negative)
        kk = 0;
        pi2AllPositive = false;
        while ~pi2AllPositive
            pi2 = pi + 0.1^kk * ((1/n)*gam2 - pi);
            pi2AllPositive = ~any(pi2 < 0);
            kk = kk+1;
        end

        % t(m): the largest value on the grid 0.1^kk for which the penalized
        % objective is decreasing or constant (at most 1000 grid steps).
        % NOTE(review): on exit pi2 was built with step 0.1^(kk-1) while
        % t = 0.1^kk, so the applied move is compounded; kept unchanged,
        % confirm intent.
        while (-1/n*a+lambda*((pi.^gamma)*b))<(-1/n*gam2*transpose(log(pi2))+lambda.*(pi2.^gamma)*b) && kk<1000
            pi2 = pi+0.1^kk*(1/n*gam2-pi);
            kk = kk+1;
        end
        t = 0.1^(kk);
        pi = (pi+t*(pi2-pi)) / sum(pi+t*(pi2-pi));

        % Diagonal of rho (closed form, using the current phi)
        for r=1:k
            for mm=1:m
                for i=1:n
                    ps1(i,mm,r) = Y2(i,mm,r) * dot(X2(i,:,r), phi(:,mm,r));
                    nY21(i,mm,r) = (Y2(i,mm,r))^2;
                end
                ps(mm,r) = sum(ps1(:,mm,r));
                nY2(mm,r) = sum(nY21(:,mm,r));
                rho(mm,mm,r) = ((ps(mm,r)+sqrt(ps(mm,r)^2+4*nY2(mm,r)*(gam2(r))))/(2*nY2(mm,r)));
            end
        end

        % phi: one cyclic coordinate-descent sweep. Each coordinate is
        % soft-thresholded at n*lambda*pi(r)^gamma and already-updated
        % coordinates are used immediately.
        for r=1:k
            thresh = n*lambda*(pi(r)^gamma);
            for j=1:p
                for mm=1:m
                    S(j,mm,r) = -rho(mm,mm,r)*ps2(j,mm,r) + dot(phi(1:j-1,mm,r),Gram2(j,1:j-1,r)')...
                        + dot(phi(j+1:p,mm,r),Gram2(j,j+1:p,r)');
                    if abs(S(j,mm,r)) <= thresh
                        phi(j,mm,r) = 0;
                    elseif S(j,mm,r) > thresh
                        phi(j,mm,r) = (thresh-S(j,mm,r))/Gram2(j,j,r);
                    else
                        phi(j,mm,r) = -(thresh+S(j,mm,r))/Gram2(j,j,r);
                    end
                end
            end
        end

        %%%%%%%%%%
        % E step %
        %%%%%%%%%%

        sumLogLLF2 = 0.0;
        for i=1:n
            % Squared residual norm of observation i under each component
            dotProducts = zeros(k,1);
            for r=1:k
                resid = Y(i,:)*rho(:,:,r) - X(i,:)*phi(:,:,r);
                dotProducts(r) = resid * transpose(resid);
            end
            % Rescale every exponential by exp(shift), so that the largest
            % exponential factor equals 1 and Gam(i,:) cannot all underflow
            shift = 0.5*min(dotProducts);

            sumLLF1 = 0.0;
            for r=1:k
                Gam(i,r) = pi(r)*det(rho(:,:,r))*exp(-0.5*dotProducts(r) + shift);
                sumLLF1 = sumLLF1 + Gam(i,r)/gaussConst;
            end
            % sumLLF1 carries the exp(shift) factor. Remove it so the
            % log-likelihood does not depend on the numerical rescaling.
            sumLogLLF2 = sumLogLLF2 + log(sumLLF1) - shift;

            % Posterior weights (the shift cancels in the normalization)
            sumGamI = sum(Gam(i,:));
            if sumGamI > EPS
                gam(i,:) = Gam(i,:) / sumGamI;
            else
                gam(i,:) = zeros(1,k);
            end
        end

        % Penalty sum_r pi(r)^gamma * b(r), with b from the start of the iteration
        sumPen = 0.0;
        for r=1:k
            sumPen = sumPen + pi(r).^gamma .* b(r);
        end
        LLF(ite) = -(1/n)*sumLogLLF2 + lambda*sumPen;

        % Signed relative change of the objective
        if ite == 1
            dist = LLF(ite);
        else
            dist = (LLF(ite)-LLF(ite-1))/(1+abs(LLF(ite)));
        end

        % Largest relative change of the parameters
        Dist1 = max(max(max((abs(phi-Phi))./(1+abs(phi)))));
        Dist2 = max(max(max((abs(rho-Rho))./(1+abs(rho)))));
        Dist3 = max(max((abs(pi-Pi))./(1+abs(Pi))));
        dist2 = max([Dist1,Dist2,Dist3]);

        ite = ite+1;
    end

    pi = transpose(pi);

end