Adjustments + bugs fixing
[morpheus.git] / pkg / R / computeMu.R
CommitLineData
cbd88fe5
BA
#' Compute μ
#'
#' Estimate the normalized columns μ of the β matrix parameter in a mixture of
#' logistic regression models, with a spectral method described in the package
#' vignette.
#'
#' @name computeMu
#'
#' @param X Matrix of input data (size nxd): real values, no NA
#' @param Y Vector of binary outputs (size n): only 0 and 1, no NA
#' @param optargs List of optional arguments:
#'   \itemize{
#'     \item 'jd_method', joint diagonalization method passed to
#'       \code{jointDiag::ajd}: 'uwedge' (default) or 'jedi'.
#'     \item 'jd_nvects', number of random unit vectors used to build the
#'       matrices to joint-diagonalize; 0 (default) uses the d vectors of the
#'       canonical basis instead. Random vectors are drawn with \code{rnorm},
#'       so set a seed for reproducible results when jd_nvects > 0.
#'     \item 'M', list of the moments of order 1, 2, 3 (as returned by
#'       \code{computeMoments}); computed from X and Y if not provided.
#'     \item 'K', number of populations: an integer between 2 and d. If not
#'       given, it is estimated from the singular values of M2 (heuristic):
#'       position of the first large gap (ratio > 3) between consecutive
#'       singular values, at least 2, or d when there is no such gap.
#'   }
#'
#' @return The estimated normalized parameters as columns of a matrix μ of size
#'   dxK. A column is negated when its coefficient in the (pseudo-inverse)
#'   decomposition of the first moment M1 onto μ is negative.
#'
#' @seealso \code{multiRun} to estimate statistics based on μ,
#'   and \code{generateSampleIO} for I/O random generation.
#'
#' @examples
#' io <- generateSampleIO(10000, 1/2, matrix(c(1,0,0,1),ncol=2), c(0,0), "probit")
#' μ <- computeMu(io$X, io$Y, list(K=2)) #or just X and Y for estimated K
#'
#' @export
computeMu <- function(X, Y, optargs = list())
{
  if (!is.matrix(X) || !is.numeric(X) || any(is.na(X)))
    stop("X: real matrix, no NA")
  n <- nrow(X)
  d <- ncol(X)
  # %in% never returns NA, so NA entries in Y are rejected with the intended
  # message instead of turning the `if` condition itself into NA.
  if (!is.numeric(Y) || length(Y) != n || !all(Y %in% c(0, 1)))
    stop("Y: vector of 0 and 1, size nrow(X), no NA")
  if (!is.list(optargs))
    stop("optargs: list")

  # Step 0: obtain the empirically estimated moments, estimate also K
  M <- if (is.null(optargs[["M"]])) computeMoments(X, Y) else optargs[["M"]]
  K <- optargs[["K"]]
  if (is.null(K))
  {
    # TODO: improve this basic heuristic
    # K = index of the first large gap between consecutive singular values of
    # M2 (i.e. its numerical rank), at least 2; d if no such gap exists.
    # Comparing with a product instead of a ratio avoids 0/0 = NaN.
    sv <- svd(M[[2]])$d
    gaps <- which(sv[-length(sv)] > 3 * sv[-1])
    K <- if (length(gaps) > 0) max(2, gaps[1]) else d
  }
  else if (!is.numeric(K) || length(K) != 1 || is.na(K) ||
           K != round(K) || K < 2 || K > d)
  {
    stop("K: integer >= 2, <= d")
  }

  # Step 1: generate a family of matrices to joint-diagonalize, to increase
  # robustness: .T_I_I_w(M3, rho) for rho in the canonical basis (default,
  # jd_nvects = 0) or for jd_nvects random unit vectors.
  jd_nvects <- if (is.null(optargs[["jd_nvects"]])) 0 else optargs[["jd_nvects"]]
  fixed_design <- (jd_nvects == 0)
  if (fixed_design)
    jd_nvects <- d
  M2_t <- array(dim = c(d, d, jd_nvects))
  for (i in seq_len(jd_nvects))
  {
    rho <- if (fixed_design) c(rep(0, i - 1), 1, rep(0, d - i)) else normalize(rnorm(d))
    M2_t[, , i] <- .T_I_I_w(M[[3]], rho)
  }

  # Step 2: obtain factors u_i (and their inverse) from the joint
  # diagonalization of M2_t
  jd_method <- if (is.null(optargs[["jd_method"]])) "uwedge" else optargs[["jd_method"]]
  V <-
    if (jd_nvects > 1) {
      # NOTE: increasing itermax does not help to converge, thus we suppress warnings
      jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
      if (jd_method == "uwedge") jd$B else MASS::ginv(jd$A)
    } else {
      eigen(M2_t[, , 1])$vectors
    }

  # Step 3: obtain final factors from the joint diagonalization of
  # T(I,I,u_i), using the first K factors from step 2
  M2_t <- array(dim = c(d, d, K))
  for (i in seq_len(K))
    M2_t[, , i] <- .T_I_I_w(M[[3]], V[, i])
  jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
  U <- if (jd_method == "uwedge") MASS::ginv(jd$B) else jd$A
  μ <- normalize(U[, seq_len(K)])

  # M1 also writes M1 = sum_k coeff_k * μ_k, where coeff_k >= 0
  # ==> search decomposition of vector M1 onto the (truncated) basis μ (dxK):
  # the linear system μ %*% C = M1 gives C = pseudo-inverse(μ) %*% M1.
  # Columns with a negative coefficient are negated to restore coeff_k >= 0.
  C <- MASS::ginv(μ) %*% M[[1]]
  flip <- which(C < 0)
  μ[, flip] <- -μ[, flip]
  μ
}