Skip to content

Commit

Permalink
Various omega-related updates by working with omega_inv directly rath…
Browse files Browse the repository at this point in the history
…er than omega:

a) avoiding solve() for diagonal matrices.
b) pre-computing v1=v+1.
c) replaced update_omega() with update_omega_inv(), since riwish was basically solve(rwish(v, solve(V)) and only omega_inv rather than omega was ever used.
d) removed argument "b" from update_omega_inv() and used more efficient matrix multiplication/inversion.
e) removed dependency on MCMCpack completely by writing own (faster!) rWish() function.
  • Loading branch information
Keefe-Murphy committed Feb 27, 2022
1 parent 4843483 commit 1660969
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 33 deletions.
2 changes: 1 addition & 1 deletion cspbart/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ Maintainer: Estevão B. Prado <estevao.prado@hotmail.com>
License: GPL (>= 2)
Encoding: UTF-8
LazyData: true
Imports: statmod, stats, MCMCpack, truncnorm, lme4, dbarts
Imports: statmod, stats, truncnorm, lme4, dbarts
RoxygenNote: 7.1.1
2 changes: 0 additions & 2 deletions cspbart/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ export(cspbart)
export(predict_cspbart)
export(sspbart)
export(var_used_trees)
importFrom(MCMCpack,'rdirichlet')
importFrom(MCMCpack,'riwish')
importFrom(dbarts,'makeModelMatrixFromDataFrame')
importFrom(lme4,'lFormula')
importFrom(stats,'as.formula')
Expand Down
20 changes: 7 additions & 13 deletions cspbart/R/cspbart.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#' @export
#' @importFrom MCMCpack 'rdirichlet' 'riwish'
#' @importFrom stats 'rgamma' 'runif' 'dnorm' 'sd' 'rnorm' 'pnorm' 'tapply' 'as.formula' 'terms'
#' @importFrom truncnorm 'rtruncnorm'
#' @importFrom lme4 'lFormula'
Expand Down Expand Up @@ -80,11 +79,10 @@ cspbart = function(formula,
p2 = ncol(x2)
p = p1 + p2
s = rep(1/p2, p2)
Omega = diag(p1)
Omega_inv = solve(Omega)
Omega_inv = diag(p1)
b = rep(0, p1)
V = diag(p1)
v = p1
v1 = p1 + 1
beta_hat = rep(0, p1)
current_partial_residuals = y_scale

Expand Down Expand Up @@ -125,8 +123,7 @@ cspbart = function(formula,
yhat_linear = x1%*%beta_hat

# Update covariance matrix of the linear predictor
Omega = update_omega(beta_hat, b, V, v)
Omega_inv = solve(Omega)
Omega_inv = update_omega_inv(beta_hat, V, v1)

# Start looping through trees
for (j in seq_len(ntrees)) {
Expand Down Expand Up @@ -236,7 +233,6 @@ cspbart = function(formula,


#' @export
#' @importFrom MCMCpack 'rdirichlet' 'riwish'
#' @importFrom stats 'rgamma' 'runif' 'dnorm' 'sd' 'rnorm' 'pnorm' 'tapply' 'as.formula' 'model.matrix'
#' @importFrom truncnorm 'rtruncnorm'
#' @importFrom lme4 'lFormula'
Expand Down Expand Up @@ -299,11 +295,10 @@ cl_cspbart = function(formula,
p2 = ncol(x2)
p = p1 + p2
s = rep(1/p2, p2)
Omega = diag(p1)
Omega_inv = solve(Omega)
Omega_inv = diag(p1)
b = rep(0, p1)
V = diag(p1)
v = p1
v1 = p1 + 1
beta_hat = rep(0, p1)
z = ifelse(y == 0, -3, 3)

Expand Down Expand Up @@ -343,9 +338,8 @@ cl_cspbart = function(formula,
yhat_linear = x1%*%beta_hat

# Update covariance matrix of the linear predictor
Omega = update_omega(beta_hat, b, V, v)
Omega_inv = solve(Omega)

Omega_inv = update_omega_inv(beta_hat, V, v1)

# Start looping through trees
for (j in seq_len(ntrees)) {

Expand Down
14 changes: 11 additions & 3 deletions cspbart/R/parameters_quantities_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,17 @@ update_beta <- function(y, x, sigma2, omega_inv) { # (1) sample from MVN(Q^{-1}b
}
}

update_omega_inv <- function(beta_hat, V, v1) {
rWish(v1, chol2inv(chol(tcrossprod(beta_hat) + V)))
}

update_omega <- function(beta_hat, b, V, v){
out = riwish(v + 1, (beta_hat - b)%*%t(beta_hat - b) + V)
return(out)
rWish <- function(v, S) {
p <- nrow(S)
CC <- chol(S)
Z <- diag(sqrt(rchisq(p, v:(v - p + 1L))), p)
if(p > 1) {
pseq <- seq_len(p - 1L)
Z[rep(p * pseq, pseq) + unlist(lapply(pseq, seq))] <- rnorm(p * (p - 1)/2)
}
crossprod(Z %*% CC)
}
21 changes: 7 additions & 14 deletions cspbart/R/sspbart.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#' @export
#' @importFrom MCMCpack 'rdirichlet' 'riwish'
#' @importFrom stats 'rgamma' 'runif' 'dnorm' 'sd' 'rnorm' 'pnorm' 'tapply' 'as.formula' 'terms'
#' @importFrom truncnorm 'rtruncnorm'
#' @importFrom lme4 'lFormula'
Expand Down Expand Up @@ -82,12 +81,10 @@ sspbart = function(formula,
p = p1 + p2
s = rep(1/p2, p2)
sigma2_b = 10000
Omega = sigma2_b*diag(p1)

Omega_inv = solve(Omega)
Omega_inv = diag(1/sigma2_b, p1)
b = rep(0, p1)
V = diag(p1)
v = p1
v1 = p1 + 1
beta_hat = rep(0, p1)
current_partial_residuals = y_scale

Expand Down Expand Up @@ -128,9 +125,8 @@ sspbart = function(formula,
yhat_linear = x1%*%beta_hat

# Update covariance matrix of the linear predictor
# Omega = update_omega(beta_hat, b, V, v)
# Omega_inv = solve(Omega)

# Omega_inv = update_omega_inv(beta_hat, V, v1)

# Start looping through trees
for (j in seq_len(ntrees)) {

Expand Down Expand Up @@ -239,7 +235,6 @@ sspbart = function(formula,


#' @export
#' @importFrom MCMCpack 'rdirichlet' 'riwish'
#' @importFrom stats 'rgamma' 'runif' 'dnorm' 'sd' 'rnorm' 'pnorm' 'tapply' 'as.formula' 'model.matrix'
#' @importFrom truncnorm 'rtruncnorm'
#' @importFrom lme4 'lFormula'
Expand Down Expand Up @@ -304,11 +299,10 @@ cl_sspbart = function(formula,
p = p1 + p2
s = rep(1/p2, p2)
sigma2_b = 10000
Omega = sigma2_b*diag(p1)
Omega_inv = solve(Omega)
Omega_inv = diag(1/sigma2_b, p1)
b = rep(0, p1)
V = diag(p1)
v = p1
v = p1 + 1
beta_hat = rep(0, p1)
z = ifelse(y == 0, -3, 3)

Expand Down Expand Up @@ -348,8 +342,7 @@ cl_sspbart = function(formula,
yhat_linear = x1%*%beta_hat

# Update covariance matrix of the linear predictor
# Omega = update_omega(beta_hat, b, V, v)
# Omega_inv = solve(Omega)
# Omega_inv = update_omega(beta_hat, V, v1)

# Start looping through trees
for (j in seq_len(ntrees)) {
Expand Down

0 comments on commit 1660969

Please sign in to comment.