update
[morpheus.git] / pkg / tests / testthat / test-optimParams.R
CommitLineData
naive_f <- function(link, M1,M2,M3, p,β,b)
{
  # Reference (deliberately loop-based) evaluation of the moment-matching
  # objective: the sum of squared differences between the target moments
  # M1 (length d), M2 (d x d) and M3 (d x d x d) and the model moments
  #   sum_k p[k] * .G(link, r, λ[k], b[k]) * (β[,k] outer-power r),  r = 1,2,3
  # where λ[k] is the Euclidean norm of column k of β. `.G` is defined in the
  # package; its exact semantics are not visible here.
  #
  # link: passed through unchanged to .G
  # M1, M2, M3: target moments of order 1, 2, 3
  # p: K mixture weights; β: d x K matrix; b: K intercepts
  # Returns a single numeric value (sum of squared residuals).
  d <- length(M1)
  K <- length(p)
  λ <- sqrt(colSums(β^2))

  # w[k, r] = p[k] * .G(link, r, λ[k], b[k]), computed once per (k, r)
  w <- matrix(0, nrow=K, ncol=3)
  for (k in seq_len(K))
  {
    for (r in c(1, 2, 3))
      w[k, r] <- p[k] * .G(link, r, λ[k], b[k])
  }

  # Self tensor products of each column of β:
  #   βx2[i,j,k] = β[i,k]*β[j,k],  βx3[i,j,l,k] = (β[i,k]*β[j,k])*β[l,k]
  βx2 <- array(0, dim=c(d,d,K))
  βx3 <- array(0, dim=c(d,d,d,K))
  for (k in seq_len(K))
  {
    col <- β[,k]
    βx2[,,k] <- col %o% col
    βx3[,,,k] <- col %o% col %o% col
  }

  # Squared residual of one moment entry; the weighted sum over components
  # is accumulated left to right starting from 0.
  sq_err <- function(weights, coefs, target)
    (Reduce('+', weights * coefs, 0) - target)^2

  # Residuals are accumulated in the order: order-1 entry i, then for each j
  # the order-2 entry (i,j) followed by all order-3 entries (i,j,l).
  total <- 0
  for (i in seq_len(d))
  {
    total <- total + sq_err(w[,1], β[i,], M1[i])
    for (j in seq_len(d))
    {
      total <- total + sq_err(w[,2], βx2[i,j,], M2[i,j])
      for (l in seq_len(d))
        total <- total + sq_err(w[,3], βx3[i,j,l,], M3[i,j,l])
    }
  }
  total
}
47
ab35f610
BA
# TODO: understand why the tolerances below must be so large
# (the discrepancies should be of order 1e-6 or 1e-7).
test_that("naive computation provides the same result as vectorized computations",
{
  h <- 1e-7 # step size for the finite-difference gradient check
  n <- 10
  for (dK in list( c(2,2), c(5,3)))
  {
    d <- dK[1]
    K <- dK[2]

    M1 <- runif(d, -1, 1)
    M2 <- matrix(runif(d^2, -1, 1), ncol=d)
    M3 <- array(runif(d^3, -1, 1), dim=c(d,d,d))

    for (link in c("logit","probit"))
    {
      # X and Y are unused here (W not re-computed)
      op <- optimParams(X=matrix(runif(n*d),ncol=d), Y=rbinom(n,1,.5),
        K, link, M=list(M1,M2,M3))
      op$W <- diag(d + d^2 + d^3)

      # Free parameters packed in x: p[1:(K-1)], β (d x K), b (K)
      nvar <- (2+d)*K - 1
      for (var in seq_len(nvar))
      {
        p <- runif(K, 0, 1)
        p <- p / sum(p)
        β <- matrix(runif(d*K,-5,5),ncol=K)
        b <- runif(K, -5, 5)
        # seq_len(K-1) is empty when K == 1 (1:(K-1) would wrongly give p[1])
        x <- c(p[seq_len(K-1)], as.double(β), b)

        # Test function values (TODO: tolerance 1 is way too high)
        expect_equal( op$f(x)[1], naive_f(link,M1,M2,M3, p,β,b), tolerance=1 )

        # Central finite difference ~= gradient value along coordinate `var`
        # (truncation error O(h^2) instead of O(h) for a forward difference)
        dir_h <- rep(0, nvar)
        dir_h[var] <- h
        fd <- ((op$f(x+dir_h) - op$f(x-dir_h)) / (2*h))[1]
        expect_equal( op$grad_f(x)[var], fd, tolerance=0.5 )
      }
    }
  }
})
2b3a6af5
BA
88
test_that("W computed in C and in R are the same",
{
  # Builds the covariance matrix of the moment conditions (Omega) in plain R
  # and checks that the C routine "Compute_Omega" returns the same matrix.
  tol <- 1e-8
  n <- 10
  for (dK in list( c(2,2))) #, c(5,3)))
  {
    d <- dK[1]
    K <- dK[2]
    link <- if (d == 2) "logit" else "probit"
    θ <- list(
      p=rep(1/K,K),
      β=matrix(runif(d*K),ncol=K),
      b=rep(0,K))
    io <- generateSampleIO(n, θ$p, θ$β, θ$b, link)
    X <- io$X
    Y <- io$Y
    dd <- d + d^2 + d^3
    p <- θ$p
    β <- θ$β
    λ <- sqrt(colSums(β^2))
    b <- θ$b

    # Model moments M = (M1, vec(M2), vec(M3)) under the parameters θ
    β2 <- apply(β, 2, function(col) col %o% col)
    β3 <- apply(β, 2, function(col) col %o% col %o% col)
    M <- c(
      β %*% (p * .G(link,1,λ,b)),
      β2 %*% (p * .G(link,2,λ,b)),
      β3 %*% (p * .G(link,3,λ,b)))

    # Per-observation moment statistics, one row per observation
    Id <- as.double(diag(d))
    E <- diag(d)
    v1 <- Y * X
    v2 <- Y * t( apply(X, 1, function(Xi) Xi %o% Xi - Id) )
    v3 <- Y * t( apply(X, 1, function(Xi) { return (Xi %o% Xi %o% Xi
      - Reduce('+', lapply(seq_len(d), function(j)
        as.double(Xi %o% E[j,] %o% E[j,])), rep(0, d*d*d))
      - Reduce('+', lapply(seq_len(d), function(j)
        as.double(E[j,] %o% Xi %o% E[j,])), rep(0, d*d*d))
      - Reduce('+', lapply(seq_len(d), function(j)
        as.double(E[j,] %o% E[j,] %o% Xi)), rep(0, d*d*d))) } ) )

    # Omega1 = (1/n) * sum_i g_i g_i^T, with g_i = (v1[i,], v2[i,], v3[i,]) - M
    Gmat <- cbind(v1, v2, v3) - matrix(M, nrow=n, ncol=dd, byrow=TRUE)
    Omega1 <- crossprod(Gmat) / n

    W <- matrix(0, nrow=dd, ncol=dd)
    Omega2 <- matrix( .C("Compute_Omega",
      X=as.double(X), Y=as.integer(Y), M=as.double(M),
      pnc=as.integer(1), pn=as.integer(n), pd=as.integer(d),
      W=as.double(W), PACKAGE="morpheus")$W, nrow=dd, ncol=dd )

    # Entrywise agreement. (Comparing min and max of the difference only
    # checked that the difference was constant, not that it was zero.)
    expect_lt(max(abs(Omega1 - Omega2)), tol)
  }
})