X-Git-Url: https://git.auder.net/?p=valse.git;a=blobdiff_plain;f=src%2Ftest%2Fgenerate_test_data%2FEMGLLF.R;h=272eb6f60dc86e8dc226f36d5f219f7410a4bf6a;hp=7100f293b518dd43947267c5d0db9dc8b3399524;hb=f227455a1604906b255ef366d64c10a93e796983;hpb=d4304982ed571d445017bb7daa031dd9fb453b41 diff --git a/src/test/generate_test_data/EMGLLF.R b/src/test/generate_test_data/EMGLLF.R index 7100f29..272eb6f 100644 --- a/src/test/generate_test_data/EMGLLF.R +++ b/src/test/generate_test_data/EMGLLF.R @@ -27,7 +27,6 @@ EMGLLF = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau) ps = matrix(0, m,k) nY2 = matrix(0, m,k) ps1 = array(0, dim=c(n,m,k)) - nY21 = array(0, dim=c(n,m,k)) Gam = matrix(0, n,k) EPS = 1E-15 @@ -58,8 +57,8 @@ EMGLLF = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau) ########## #pour pi - for (r in 1:k) - b[r] = sum(abs(phi[,,r])) + for (r in 1:k){ + b[r] = sum(abs(phi[,,r]))} gam2 = colSums(gam) a = sum(gam %*% log(pi)) @@ -92,12 +91,9 @@ EMGLLF = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau) for (i in 1:n) { ps1[i,mm,r] = Y2[i,mm,r] * sum(X2[i,,r] * phi[,mm,r]) - nY21[i,mm,r] = Y2[i,mm,r]^2 } ps[mm,r] = sum(ps1[,mm,r]) - nY2[mm,r] = sum(nY21[,mm,r]) - -#TODO: debug rho computation + nY2[mm,r] = sum(Y2[,mm,r]^2) rho[mm,mm,r] = (ps[mm,r]+sqrt(ps[mm,r]^2+4*nY2[mm,r]*gam2[r])) / (2*nY2[mm,r]) } } @@ -107,9 +103,9 @@ EMGLLF = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau) { for (mm in 1:m) { - S[j,mm,r] = -rho[mm,mm,r]*ps2[j,mm,r] + - (if(j>1) sum(phi[1:(j-1),mm,r] * Gram2[j,1:(j-1),r]) else 0) + - (if(j1) sum(phi[1:(j-1),mm,r] * Gram2[j,1:(j-1),r]) else 0) + +# (if(j n*lambda*(pi[r]^gamma)) @@ -128,9 +124,8 @@ EMGLLF = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau) { #precompute sq norms to numerically adjust their values sqNorm2 = rep(0,k) - for (r in 1:k) - sqNorm2[r] = sum( (Y[i,]%*%rho[,,r]-X[i,]%*%phi[,,r])^2 ) - shift = 0.5*min(sqNorm2) + for (r in 1:k){ + sqNorm2[r] = sum( (Y[i,]%*%rho[,,r]-X[i,]%*%phi[,,r])^2 )} #compute Gam(:,:) using shift determined above sumLLF1 = 0.0; @@ -138,7 +133,7 @@ EMGLLF = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau) { #FIXME: numerical problems, because 0 < det(Rho[,,r] < EPS; what to do ?! # consequence: error in while() at line 77 - Gam[i,r] = pi[r] * exp(-0.5*sqNorm2[r] + shift) #* det(rho[,,r]) + Gam[i,r] = pi[r] * exp(-0.5*sqNorm2[r])* det(rho[,,r]) sumLLF1 = sumLLF1 + Gam[i,r] / (2*base::pi)^(m/2) } sumLogLLF2 = sumLogLLF2 + log(sumLLF1) @@ -161,6 +156,7 @@ EMGLLF = function(phiInit,rhoInit,piInit,gamInit,mini,maxi,gamma,lambda,X,Y,tau) ite = ite+1 } - - return(list("phi"=phi, "rho"=rho, "pi"=pi, "LLF"=LLF, "S"=S)) + + affec = apply(gam, 1,which.max) + return(list("phi"=phi, "rho"=rho, "pi"=pi, "LLF"=LLF, "S"=S, "affec" = affec )) }