Commit | Line | Data |
---|---|---|
cbd88fe5 BA |
#' Compute μ
#'
#' Estimate the normalized columns μ of the β matrix parameter in a mixture of
#' logistic regression models, with a spectral method described in the package
#' vignette.
#'
#' @param X Matrix of input data (size nxd)
#' @param Y Vector of binary outputs (size n)
#' @param optargs List of optional arguments:
#'   \itemize{
#'     \item 'jd_method', joint diagonalization method from the package
#'       jointDiag: 'uwedge' (default) or 'jedi'.
#'     \item 'jd_nvects', number of random vectors for joint-diagonalization
#'       (0, the default, uses the d vectors of the canonical basis instead)
#'     \item 'M', list of the moments of order 1, 2, 3: computed with
#'       \code{computeMoments} if not provided.
#'     \item 'K', number of populations (estimated from the singular values of
#'       M2 if not given); must be an integer between 1 and d.
#'   }
#'
#' @return The estimated normalized parameters as columns of a matrix μ of
#'   size dxK
#'
#' @seealso \code{multiRun} to estimate statistics based on μ,
#'   and \code{generateSampleIO} for I/O random generation.
#'
#' @examples
#' io <- generateSampleIO(10000, 1/2, matrix(c(1,0,0,1),ncol=2), c(0,0), "probit")
#' mu <- computeMu(io$X, io$Y, list(K=2)) #or just X and Y for estimated K
#' @export
computeMu <- function(X, Y, optargs = list()) {
  if (!is.matrix(X) || !is.numeric(X) || anyNA(X)) {
    stop("X: real matrix, no NA")
  }
  n <- nrow(X)
  d <- ncol(X)
  # anyNA() is tested first: a NA in Y would otherwise make the condition NA
  # and if() would fail with "missing value where TRUE/FALSE needed".
  if (!is.numeric(Y) || length(Y) != n || anyNA(Y) || any(Y != 0 & Y != 1)) {
    stop("Y: vector of 0 and 1, size nrow(X), no NA")
  }
  if (!is.list(optargs)) {
    stop("optargs: list")
  }

  # Step 0: obtain the empirically estimated moments (list M1, M2, M3),
  # and estimate K if it was not supplied
  M <- if (is.null(optargs$M)) computeMoments(X, Y) else optargs$M
  K <- optargs$K
  if (is.null(K)) {
    # TODO: improve this basic heuristic.
    # Look for the first sharp drop (ratio > 3) between consecutive singular
    # values of M2: its position is taken as the rank of M2, i.e. K.
    sv <- svd(M[[2]])$d
    large_ratio <- abs(sv[-d] / sv[-1]) > 3
    # 0/0 (two null singular values) yields NaN: not a "large" ratio
    large_ratio[is.na(large_ratio)] <- FALSE
    # which.max() on a logical vector returns the index of the first TRUE
    K <- if (any(large_ratio)) max(2, which.max(large_ratio)) else d
  }
  if (!is.numeric(K) || length(K) != 1 || is.na(K) || K != round(K) ||
      K < 1 || K > d) {
    stop("K: integer between 1 and ncol(X)")
  }

  # Step 1: generate a family of matrices T(I,I,rho) to joint-diagonalize,
  # to increase robustness. rho runs over the canonical basis (fixed design)
  # unless a positive number of random directions is requested.
  jd_nvects <- if (is.null(optargs$jd_nvects)) 0 else optargs$jd_nvects
  fixed_design <- (jd_nvects == 0)
  if (fixed_design) {
    jd_nvects <- d
  }
  M2_t <- array(dim = c(d, d, jd_nvects))
  for (i in seq_len(jd_nvects)) {
    rho <- if (fixed_design) {
      c(rep(0, i - 1), 1, rep(0, d - i))
    } else {
      normalize(rnorm(d))
    }
    M2_t[, , i] <- .T_I_I_w(M[[3]], rho)
  }

  # Step 2: obtain factors u_i (and their inverse) from the joint
  # diagonalization of M2_t. MASS::ginv() is used rather than solve()
  # (presumably to cope with near-singular matrices).
  jd_method <- if (is.null(optargs$jd_method)) "uwedge" else optargs$jd_method
  V <- if (jd_nvects > 1) {
    # NOTE: increasing itermax does not help to converge, thus we suppress
    # warnings
    jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
    if (jd_method == "uwedge") jd$B else MASS::ginv(jd$A)
  } else {
    eigen(M2_t[, , 1])$vectors
  }

  # Step 3: obtain final factors from joint diagonalization of T(I,I,u_i)
  M2_t <- array(dim = c(d, d, K))
  for (i in seq_len(K)) {
    M2_t[, , i] <- .T_I_I_w(M[[3]], V[, i])
  }
  jd <- suppressWarnings(jointDiag::ajd(M2_t, method = jd_method))
  U <- if (jd_method == "uwedge") MASS::ginv(jd$B) else jd$A
  # drop = FALSE / as.matrix keep a d x 1 matrix when K == 1, since the sign
  # correction below indexes columns
  mu <- as.matrix(normalize(U[, seq_len(K), drop = FALSE]))

  # M1 also writes M1 = sum_k coeff_k * μ_k, where coeff_k >= 0
  # ==> search decomposition of vector M1 onto the (truncated) basis μ (dxK).
  # This is a linear system μ %*% C = M1 with C of size K
  # ==> C = pseudo-inverse(μ) %*% M1; flip the columns whose coefficient is < 0.
  C <- MASS::ginv(mu) %*% M[[1]]
  flip <- as.vector(C) < 0
  mu[, flip] <- -mu[, flip]
  mu
}