fix a few weird things
authorBenjamin Auder <benjamin.auder@somewhere>
Fri, 14 Apr 2017 22:29:28 +0000 (00:29 +0200)
committerBenjamin Auder <benjamin.auder@somewhere>
Fri, 14 Apr 2017 22:29:28 +0000 (00:29 +0200)
pkg/R/EMGLLF.R

index 0a279f0..b71a128 100644 (file)
@@ -45,21 +45,13 @@ EMGLLF <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda,
 
 # R version - slow but easy to read
 .EMGLLF_R <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, 
-  X2, Y, eps)
+  X, Y, eps)
 {
-  # Matrix dimensions
+  # Matrix dimensions: NOTE: phiInit *must* be an array (even if p==1)
   n <- dim(Y)[1]
-  if (length(dim(phiInit)) == 2) {
-    p <- 1
-    m <- dim(phiInit)[1]
-    k <- dim(phiInit)[2]
-  } else {
-    p <- dim(phiInit)[1]
-    m <- dim(phiInit)[2]
-    k <- dim(phiInit)[3]
-  }
-  X <- matrix(nrow = n, ncol = p)
-  X[1:n, 1:p] <- X2
+  p <- dim(phiInit)[1]
+  m <- dim(phiInit)[2]
+  k <- dim(phiInit)[3]
 
   # Outputs
   phi <- array(NA, dim = c(p, m, k))
@@ -159,22 +151,22 @@ EMGLLF <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda,
       }
     }
 
-    ######## E step#
+    ## E step
 
     # Precompute det(rho[,,r]) for r in 1...k
     detRho <- sapply(1:k, function(r) det(rho[, , r]))
-    gam1 <- matrix(0, nrow = n, ncol = k)
     for (i in 1:n)
     {
       # Update gam[,]
       for (r in 1:k)
       {
-        gam1[i, r] <- pi[r] * exp(-0.5
+        gam[i, r] <- pi[r] * exp(-0.5
           * sum((Y[i, ] %*% rho[, , r] - X[i, ] %*% phi[, , r])^2)) * detRho[r]
       }
     }
-    gam <- gam1 / rowSums(gam1)
-    sumLogLLH <- sum(log(rowSums(gam)) - log((2 * base::pi)^(m/2)))
+    norm_fact <- rowSums(gam)
+    gam <- gam / norm_fact
+    sumLogLLH <- sum(log(norm_fact) - log((2 * base::pi)^(m/2)))
     sumPen <- sum(pi^gamma * b)
     last_llh <- llh
     llh <- -sumLogLLH/n + lambda * sumPen
@@ -184,7 +176,7 @@ EMGLLF <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda,
     Dist3 <- max((abs(pi - Pi))/(1 + abs(Pi)))
     dist2 <- max(Dist1, Dist2, Dist3)
 
-    if (ite >= mini && (dist >= eps || dist2 >= sqrt(eps))) 
+    if (ite >= mini && (dist >= eps || dist2 >= sqrt(eps)))
       break
   }