#' Compute μ
#'
#' Estimate the normalized columns μ of the β matrix parameter in a mixture of
#' logistic regression models, with a spectral method described in the package
#' vignette.
#'
#' @name computeMu
#'
#' @param X Matrix of input data (size nxd)
#' @param Y Vector of binary outputs (size n)
#' @param optargs List of optional arguments:
#'   \itemize{
#'     \item 'jd_method', joint diagonalization method from the package
#'       jointDiag: 'uwedge' (default) or 'jedi'.
#'     \item 'jd_nvects', number of random vectors for joint diagonalization,
#'       or 0 (default) to use the d vectors of the canonical basis.
#'     \item 'M', list of the empirical moments of order 1, 2, 3; computed
#'       with \code{computeMoments(X, Y)} if not provided.
#'     \item 'K', number of populations, an integer between 2 and d;
#'       estimated from the singular values of M2 if not given.
#'   }
#'
#' @return The estimated normalized parameters as columns of a matrix μ of
#'   size dxK
#'
#' @seealso \code{multiRun} to estimate statistics based on μ,
#'   and \code{generateSampleIO} for I/O random generation.
#'
#' @examples
#' io <- generateSampleIO(10000, 1/2, matrix(c(1,0,0,1),ncol=2), c(0,0), "probit")
#' mu <- computeMu(io$X, io$Y, list(K=2)) # or just X and Y for estimated K
#'
#' @export
computeMu <- function(X, Y, optargs = list()) {
  if (!is.matrix(X) || !is.numeric(X) || anyNA(X))
    stop("X: real matrix, no NA")
  n <- nrow(X)
  d <- ncol(X)
  # anyNA() must come before the 0/1 test: with NA in Y that test yields NA
  # and if() would abort with an uninformative "missing value" error.
  if (!is.numeric(Y) || length(Y) != n || anyNA(Y) || any(Y != 0 & Y != 1))
    stop("Y: vector of 0 and 1, size nrow(X), no NA")
  if (!is.list(optargs))
    stop("optargs: list")

  # Step 0: obtain the empirically estimated moments tensor, and K
  M <- if (is.null(optargs$M)) computeMoments(X, Y) else optargs$M
  K <- optargs$K
  if (is.null(K)) {
    # TODO: improve this basic heuristic
    # svd() returns singular values in decreasing order; a ratio
    # sigma[i] / sigma[i+1] > 3 marks a gap right after the i-th value, i.e.
    # an estimated rank of i: take the first such i. NaN ratios (0/0 on a
    # rank-deficient M2) count as "no gap" instead of breaking the if().
    sigma <- svd(M[[2]])$d
    large_ratio <- abs(sigma[-length(sigma)] / sigma[-1]) > 3
    large_ratio[is.na(large_ratio)] <- FALSE
    K <- if (any(large_ratio)) max(2, which.max(large_ratio)) else d
  } else if (!is.numeric(K) || length(K) != 1 || is.na(K) ||
             K != round(K) || K < 2 || K > d) {
    stop("K: integer >= 2, <= d")
  }

  # Step 1: generate a family of matrices T(I,I,rho) to joint-diagonalize, to
  # increase robustness: rho runs over the canonical basis (jd_nvects = 0,
  # the default) or over jd_nvects random normalized vectors.
  jd_nvects <- if (is.null(optargs$jd_nvects)) 0 else optargs$jd_nvects
  fixed_design <- (jd_nvects == 0)
  if (fixed_design)
    jd_nvects <- d
  M2_t <- array(dim = c(d, d, jd_nvects))
  for (i in seq_len(jd_nvects)) {
    rho <- if (fixed_design) {
      c(rep(0, i - 1), 1, rep(0, d - i))
    } else {
      normalize(rnorm(d))
    }
    M2_t[, , i] <- .T_I_I_w(M[[3]], rho)
  }

  # Step 2: obtain factors u_i (and their inverse) from the joint
  # diagonalization of M2_t
  jd_method <- if (is.null(optargs$jd_method)) "uwedge" else optargs$jd_method
  V <- if (jd_nvects > 1) {
    # NOTE: increasing itermax does not help to converge, thus we suppress
    # warnings
    jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
    if (jd_method == "uwedge") jd$B else MASS::ginv(jd$A)
  } else {
    eigen(M2_t[, , 1])$vectors
  }

  # Step 3: obtain final factors from joint diagonalization of T(I,I,u_i)
  M2_t <- array(dim = c(d, d, K))
  for (i in seq_len(K))
    M2_t[, , i] <- .T_I_I_w(M[[3]], V[, i])
  jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
  U <- if (jd_method == "uwedge") MASS::ginv(jd$B) else jd$A
  mu <- normalize(U[, seq_len(K)])

  # M1 also writes M1 = sum_k coeff_k * mu_k, where coeff_k >= 0
  # ==> search decomposition of vector M1 onto the (truncated) basis mu
  # (size dxK): linear system mu %*% C = M1 with C of size K, solved as
  # C = ginv(mu) %*% M1; columns with a negative coefficient are flipped.
  C <- MASS::ginv(mu) %*% M[[1]]
  flip <- as.vector(C) < 0
  mu[, flip] <- -mu[, flip]
  mu
}