From: Benjamin Auder Date: Fri, 14 Apr 2017 22:29:28 +0000 (+0200) Subject: fix a few weird things X-Git-Url: https://git.auder.net/%7B%7B%20path%28%27mixstore_store_usecase_upsert%27%2C%20%7B%20pkgid:%20pkg.id%20%7D%29%20%7D%7D?a=commitdiff_plain;h=c8baa02224d4ee13889e964b56b9a67280c1816f;p=valse.git fix a few weird things --- diff --git a/pkg/R/EMGLLF.R b/pkg/R/EMGLLF.R index 0a279f0..b71a128 100644 --- a/pkg/R/EMGLLF.R +++ b/pkg/R/EMGLLF.R @@ -45,21 +45,13 @@ EMGLLF <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, # R version - slow but easy to read .EMGLLF_R <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, - X2, Y, eps) + X, Y, eps) { - # Matrix dimensions + # Matrix dimensions: NOTE: phiInit *must* be an array (even if p==1) n <- dim(Y)[1] - if (length(dim(phiInit)) == 2) { - p <- 1 - m <- dim(phiInit)[1] - k <- dim(phiInit)[2] - } else { - p <- dim(phiInit)[1] - m <- dim(phiInit)[2] - k <- dim(phiInit)[3] - } - X <- matrix(nrow = n, ncol = p) - X[1:n, 1:p] <- X2 + p <- dim(phiInit)[1] + m <- dim(phiInit)[2] + k <- dim(phiInit)[3] # Outputs phi <- array(NA, dim = c(p, m, k)) @@ -159,22 +151,22 @@ EMGLLF <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, } } - ######## E step# + ## E step # Precompute det(rho[,,r]) for r in 1...k detRho <- sapply(1:k, function(r) det(rho[, , r])) - gam1 <- matrix(0, nrow = n, ncol = k) for (i in 1:n) { # Update gam[,] for (r in 1:k) { - gam1[i, r] <- pi[r] * exp(-0.5 + gam[i, r] <- pi[r] * exp(-0.5 * sum((Y[i, ] %*% rho[, , r] - X[i, ] %*% phi[, , r])^2)) * detRho[r] } } - gam <- gam1 / rowSums(gam1) - sumLogLLH <- sum(log(rowSums(gam)) - log((2 * base::pi)^(m/2))) + norm_fact <- rowSums(gam) + gam <- gam / norm_fact + sumLogLLH <- sum(log(norm_fact) - log((2 * base::pi)^(m/2))) sumPen <- sum(pi^gamma * b) last_llh <- llh llh <- -sumLogLLH/n + lambda * sumPen @@ -184,7 +176,7 @@ EMGLLF <- function(phiInit, rhoInit, piInit, gamInit, mini, maxi, gamma, lambda, Dist3 <- max((abs(pi - Pi))/(1 + abs(Pi))) dist2 <- max(Dist1, Dist2, Dist3) - if (ite >= mini && (dist >= eps || dist2 >= sqrt(eps))) + if (ite >= mini && (dist >= eps || dist2 >= sqrt(eps))) break }