fix test; EMGLLF.c != EMGLLF.R now...
[valse.git] / pkg / R / EMGLLF_R.R
index 362d0dc..09ae2e3 100644 (file)
@@ -17,7 +17,6 @@ EMGLLF_R = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,ta
        gam = gamInit
        Gram2 = array(0, dim=c(p,p,k))
        ps2 = array(0, dim=c(p,m,k))
-       b = rep(0, k)
        X2 = array(0, dim=c(n,p,k))
        Y2 = array(0, dim=c(n,m,k))
        EPS = 1e-15
@@ -108,34 +107,34 @@ EMGLLF_R = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,ta
                #Etape E #
                ##########
 
-               sumLogLLH2 = 0
+               # Precompute det(rho[,,r]) for r in 1...k
+               detRho = sapply(1:k, function(r) det(rho[,,r]))
+
+               sumLogLLH = 0
                for (i in 1:n)
                {
                        # Update gam[,]
-                       sumLLH1 = 0
                        sumGamI = 0
                        for (r in 1:k)
                        {
-                               gam[i,r] = pi[r] * exp(-0.5*sum( (Y[i,]%*%rho[,,r]-X[i,]%*%phi[,,r])^2 ))
-                                       * det(rho[,,r])
-                               sumLLH1 = sumLLH1 + gam[i,r] / (2*base::pi)^(m/2)
+                               gam[i,r] = pi[r]*exp(-0.5*sum((Y[i,]%*%rho[,,r]-X[i,]%*%phi[,,r])^2))*detRho[r]
                                sumGamI = sumGamI + gam[i,r]
                        }
-                       sumLogLLH2 = sumLogLLH2 + log(sumLLH1)
-                       if(sumGamI > EPS) #else: gam[i,] is already ~=0
+                       sumLogLLH = sumLogLLH + log(sumGamI) - log((2*base::pi)^(m/2))
+                       if (sumGamI > EPS) #else: gam[i,] is already ~=0
                                gam[i,] = gam[i,] / sumGamI
                }
 
                sumPen = sum(pi^gamma * b)
                last_llh = llh
-               llh = -sumLogLLH2/n + lambda*sumPen
+               llh = -sumLogLLH/n + lambda*sumPen
                dist = ifelse( ite == 1, llh, (llh-last_llh) / (1+abs(llh)) )
                Dist1 = max( (abs(phi-Phi)) / (1+abs(phi)) )
                Dist2 = max( (abs(rho-Rho)) / (1+abs(rho)) )
                Dist3 = max( (abs(pi-Pi)) / (1+abs(Pi)) )
                dist2 = max(Dist1,Dist2,Dist3)
 
-               if (ite>=mini && (dist>= tau || dist2 >= sqrt(tau)))
+               if (ite >= mini && (dist >= tau || dist2 >= sqrt(tau)))
                        break
        }