Commit | Line | Data |
---|---|---|
2279a641 BA |
#' constructionModelesLassoMLE
#'
#' Construct a collection of models with the Lasso-MLE procedure.
#'
#' For each element of \code{S}, the EM algorithm (\code{EMGLLF}) is re-run
#' without penalization (\code{lambda = 0}) restricted to the selected
#' covariates. The estimated coefficients are then expanded back to a
#' (p, m, k) array and the log-likelihood of the resulting model is evaluated.
#'
#' @param phiInit an initialization for phi (array of size p*m*k), obtained from initSmallEM.R
#' @param rhoInit an initialization for rho, obtained from initSmallEM.R
#' @param piInit an initialization for pi, obtained from initSmallEM.R
#' @param gamInit an initialization for gam, obtained from initSmallEM.R
#' @param mini integer, minimum number of iterations in the EM algorithm
#' @param maxi integer, maximum number of iterations in the EM algorithm
#' @param gamma integer for the power in the penalty
#' @param X matrix of covariates (of size n*p)
#' @param Y matrix of responses (of size n*m)
#' @param eps real, threshold to say the EM algorithm converges
#' @param S output of selectVariables.R: a list whose elements have a
#'   \code{selected} field, a list (one entry per covariate) of the selected
#'   response indices
#' @param ncores Number of cores, by default = 3
#' @param fast TRUE to use compiled C code, FALSE for R code only
#' @param verbose TRUE to show some execution traces
#'
#' @return a list with one entry per element of \code{S}: \code{NULL} when no
#'   variable is selected, otherwise a list with fields phi (array p*m*k),
#'   rho, pi and llh = c(log-likelihood, (dimension + m + 1) * k - 1) where
#'   dimension is the total number of selected (covariate, response) pairs
#'
#' @export
constructionModelesLassoMLE <- function(phiInit, rhoInit, piInit, gamInit, mini,
  maxi, gamma, X, Y, eps, S, ncores = 3, fast = TRUE, verbose = FALSE)
{
  if (ncores > 1)
  {
    cl <- parallel::makeCluster(ncores, outfile = "")
    # Always release the worker processes, even if a computation fails
    on.exit(parallel::stopCluster(cl), add = TRUE)
    parallel::clusterExport(cl, envir = environment(),
      varlist = c("phiInit", "rhoInit", "piInit", "gamInit", "mini", "maxi",
        "gamma", "X", "Y", "eps", "S", "ncores", "fast", "verbose"))
  }

  # Individual model computation
  computeAtLambda <- function(lambda)
  {
    if (ncores > 1)
      require("valse") # worker nodes start with an empty environment

    if (verbose)
      print(paste("Computations for lambda=", lambda))

    n <- dim(X)[1]
    p <- dim(phiInit)[1]
    m <- dim(phiInit)[2]
    k <- dim(phiInit)[3]
    sel.lambda <- S[[lambda]]$selected
    # col.sel = which(colSums(sel.lambda)!=0) #if boolean matrix
    col.sel <- which(lengths(sel.lambda) > 0) # if list of selected vars
    if (length(col.sel) == 0)
      return(NULL)

    # lambda == 0 because we compute the EMV: no penalization here.
    # drop = FALSE keeps phi a 3D array and X a matrix even when a single
    # covariate (or a single cluster) is involved.
    res <- EMGLLF(phiInit[col.sel, , , drop = FALSE], rhoInit, piInit, gamInit,
      mini, maxi, gamma, 0, X[, col.sel, drop = FALSE], Y, eps, fast)

    # Expand the estimate back to full size: row j of res$phi corresponds to
    # covariate col.sel[j], whose selected responses are sel.lambda[[col.sel[j]]]
    phiLambda2 <- res$phi
    rhoLambda <- res$rho
    piLambda <- res$pi
    phiLambda <- array(0, dim = c(p, m, k))
    for (j in seq_along(col.sel))
    {
      responses <- sel.lambda[[col.sel[j]]]
      phiLambda[col.sel[j], responses, ] <- phiLambda2[j, responses, ]
    }
    dimension <- length(unlist(sel.lambda))

    # Computation of the loglikelihood (mixture of Gaussians, one per cluster)
    Xsel <- X[, col.sel, drop = FALSE]
    densite <- vector("double", n)
    for (r in seq_len(k))
    {
      # matrix() restores the shapes that [ would drop when m == 1 or when a
      # single covariate is selected
      rho.r <- matrix(rhoLambda[, , r], nrow = m, ncol = m)
      phi.r <- matrix(phiLambda[col.sel, , r], nrow = length(col.sel), ncol = m)
      delta <- Y %*% rho.r - Xsel %*% phi.r
      # rowSums(delta^2) equals diag(tcrossprod(delta)) without the n x n matrix
      densite <- densite + piLambda[r] *
        det(rho.r) / (sqrt(2 * base::pi))^m * exp(-rowSums(delta^2) / 2.0)
    }
    llhLambda <- c(sum(log(densite)), (dimension + m + 1) * k - 1)
    list("phi" = phiLambda, "rho" = rhoLambda, "pi" = piLambda, "llh" = llhLambda)
  }

  # For each lambda, computation of the parameters (the cluster, if any, is
  # stopped by on.exit after the result is computed)
  if (ncores > 1)
    parallel::parLapply(cl, seq_along(S), computeAtLambda)
  else
    lapply(seq_along(S), computeAtLambda)
}