Commit | Line | Data |
---|---|---|
cbd88fe5 BA |
#' Compute μ
#'
#' Estimate the normalized columns μ of the β matrix parameter in a mixture of
#' logistic regression models, with a spectral method described in the package
#' vignette.
#'
#' @param X Matrix of input data (size nxd)
#' @param Y Vector of binary outputs (size n)
#' @param optargs List of optional arguments:
#'   \itemize{
#'     \item 'jd_method', joint diagonalization method from the package
#'       jointDiag: 'uwedge' (default) or 'jedi'.
#'     \item 'jd_nvects', number of random vectors for joint-diagonalization
#'       (or 0, the default, to use the d vectors of the canonical basis)
#'     \item 'M', list of the moments of order 1,2,3 (M[[1]], M[[2]], M[[3]]):
#'       computed with computeMoments(X, Y) if not provided.
#'     \item 'K', number of populations, an integer between 2 and d. If not
#'       given, it is estimated from the singular values of M2: K is the first
#'       index k where sigma_k / sigma_(k+1) > 3 (at least 2), or d if there
#'       is no such gap.
#'   }
#'
#' @return The estimated normalized parameters as columns of a matrix μ of
#'   size dxK
#'
#' @seealso \code{multiRun} to estimate statistics based on μ,
#'   and \code{generateSampleIO} for I/O random generation.
#'
#' @examples
#' io <- generateSampleIO(10000, 1/2, matrix(c(1,0,0,1),ncol=2), c(0,0), "probit")
#' mu <- computeMu(io$X, io$Y, list(K=2)) #or just X and Y for estimated K
#'
#' @export
computeMu <- function(X, Y, optargs = list())
{
  if (!is.matrix(X) || !is.numeric(X) || anyNA(X))
    stop("X: real matrix, no NA")
  n <- nrow(X)
  d <- ncol(X)
  # anyNA() is tested before the 0/1 check: with an NA in Y the expression
  # any(Y != 0 & Y != 1) is NA, and if() would fail with an unrelated error
  if (!is.numeric(Y) || length(Y) != n || anyNA(Y) || any(Y != 0 & Y != 1))
    stop("Y: vector of 0 and 1, size nrow(X), no NA")
  if (!is.list(optargs))
    stop("optargs: list")

  # Step 0: Obtain the empirically estimated moments tensor, estimate also K
  # ([[ ]] instead of $ avoids accidental partial matching of option names)
  M <- if (is.null(optargs[["M"]])) computeMoments(X, Y) else optargs[["M"]]
  K <- optargs[["K"]]
  if (is.null(K))
  {
    # TODO: improve this basic heuristic
    # Singular values are sorted decreasingly: K is the position of the first
    # large gap sigma_k / sigma_(k+1) > 3; 0/0 ratios (NaN) are not gaps.
    sing_vals <- svd(M[[2]])$d
    ratios <- abs(sing_vals[-length(sing_vals)] / sing_vals[-1])
    large_ratio <- !is.na(ratios) & ratios > 3
    # which.max() on a logical vector returns the index of the first TRUE
    K <- if (any(large_ratio)) max(2, which.max(large_ratio)) else d
  }
  else if (!is.numeric(K) || length(K) != 1 || is.na(K) ||
           K != round(K) || K < 2 || K > d)
  {
    stop("K: integer >= 2, <= d")
  }

  # Step 1: generate a family of matrices T(I,I,rho) to joint-diagonalize, to
  # increase robustness: either the d canonical basis vectors (jd_nvects == 0,
  # the default) or jd_nvects random vectors passed through normalize()
  jd_nvects <-
    if (is.null(optargs[["jd_nvects"]])) 0 else optargs[["jd_nvects"]]
  fixed_design <- (jd_nvects == 0)
  if (fixed_design)
    jd_nvects <- d
  M2_t <- array(dim = c(d, d, jd_nvects))
  for (i in seq_len(jd_nvects))
  {
    rho <-
      if (fixed_design) c(rep(0, i - 1), 1, rep(0, d - i)) else normalize(rnorm(d))
    M2_t[, , i] <- .T_I_I_w(M[[3]], rho)
  }

  # Step 2: obtain factors u_i (and their inverse) from the joint
  # diagonalization of M2_t
  jd_method <-
    if (is.null(optargs[["jd_method"]])) "uwedge" else optargs[["jd_method"]]
  V <-
    if (jd_nvects > 1) {
      # NOTE: increasing itermax does not help to converge, thus we suppress warnings
      jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
      if (jd_method == "uwedge") jd$B else MASS::ginv(jd$A)
    } else {
      eigen(M2_t[, , 1])$vectors
    }

  # Step 3: obtain final factors from joint diagonalization of T(I,I,u_i),
  # i = 1..K
  M2_t <- array(dim = c(d, d, K))
  for (i in seq_len(K))
    M2_t[, , i] <- .T_I_I_w(M[[3]], V[, i])
  jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
  U <- if (jd_method == "uwedge") MASS::ginv(jd$B) else jd$A
  # drop = FALSE keeps a d x K matrix even in the degenerate K == 1 case
  mu <- normalize(U[, seq_len(K), drop = FALSE])

  # M1 also writes M1 = sum_k coeff_k * mu_k, where coeff_k >= 0
  # ==> search decomposition of vector M1 onto the (truncated) basis mu (dxK)
  # This is a linear system mu %*% C = M1 with C of size K ==> C = psinv(mu) %*% M1
  # Columns whose coefficient is negative are flipped so that all coeff_k >= 0.
  C <- MASS::ginv(mu) %*% M[[1]]
  flip <- which(as.vector(C) < 0)
  mu[, flip] <- -mu[, flip]
  mu
}