fusedTree is a prediction model that integrates a set of low‑dimensional, established clinical variables with high‑dimensional, noisy omics variables. It fits (generalized) linear regression models in each leaf node of a tree, enabling both interpretability and flexibility in handling complex data structures.
Note: Tree construction must be done externally (e.g., with the rpart package in R).
For full methodological details, see the preprint.
# CRAN (when available)
install.packages("fusedTree")
# Development version from GitHub
remotes::install_github("JeroenGoedhart/fusedTree")
We illustrate the model for a continuous response. The simulated data has a nonlinear relationship with clinical variables and a linear relationship with omics variables.
library(fusedTree)
if (!requireNamespace("rpart", quietly = TRUE)) install.packages("rpart")
library(rpart)
set.seed(10)
p <- 5 # Number of omics variables
p_Clin <- 5 # Number of clinical variables
N <- 100 # Sample size
# Nonlinear function of clinical variables
g <- function(z) {
  15 * sin(pi * z[,1] * z[,2]) +
    10 * (z[,3] - 0.5)^2 +
    2 * exp(z[,4]) +
    2 * z[,5]
}
# Clinical and omics covariates
Z <- as.data.frame(matrix(runif(N * p_Clin), nrow = N))
X <- matrix(rnorm(N * p), nrow = N)
betas <- c(1, -1, 3, 2, -2)
# Response: nonlinear clinical + linear omics + noise
Y <- as.vector(g(Z) + X %*% betas + rnorm(N))
Thus, the response is generated by a nonlinear clinical part and a separate linear omics part, so the omics effects do not vary with the clinical variables. The omics regressions in the different leaf nodes of the tree should therefore benefit from strong fusion.
dat <- cbind.data.frame(Y, Z)
rp <- rpart(
  Y ~ ., data = dat,
  control = rpart.control(xval = 5, minbucket = 10),
  model = TRUE
)
# Post-prune the tree
cp <- rp$cptable[which.min(rp$cptable[, "xerror"]), "CP"]
Treefit <- prune(rp, cp = cp)
plot(Treefit, main = "Clinical-variable Tree")
text(Treefit, use.n = TRUE)
Before fitting the model, it’s useful to understand how fusedTree internally represents the data to enable leaf-specific regression. Each leaf node of the tree gets its own (generalized) linear regression model. To support this, two large design matrices are constructed:
Clinical design matrix (Clinical): A binary intercept indicator matrix of size N × (# of leaf nodes). Each column corresponds to a leaf node, with entries equal to 1 if an observation falls into that node and 0 otherwise.
Omics design matrix (Omics): A matrix of size N × (p × # of leaf nodes), where p is the number of omics variables. For each leaf node, the corresponding block of columns contains the omics values only for the observations in that node; entries are 0 elsewhere.
These matrices are created automatically during model fitting, but you can inspect them yourself using the Dat_Tree() function:
Dat_fusedTree <- Dat_Tree(Tree = Treefit, X = X, Z = Z, LinVars = FALSE)
# Clinical design matrix: indicator for node membership
head(Dat_fusedTree$Clinical)
#> N2 N6 N7
#> 1 0 1 0
#> 2 1 0 0
#> 3 0 1 0
#> 4 0 1 0
#> 5 1 0 0
#> 6 1 0 0
# Omics design matrix: omics data distributed across nodes
head(Dat_fusedTree$Omics)
#> x1_N2 x1_N6 x1_N7 x2_N2 x2_N6 x2_N7 x3_N2
#> [1,] 0.0000000 1.0778926 0 0.0000000 -0.886788 0 0.0000000
#> [2,] 0.9317812 0.0000000 0 1.2711460 0.000000 0 -1.5233846
#> [3,] 0.0000000 -1.4607939 0 0.0000000 -1.605085 0 0.0000000
#> [4,] 0.0000000 -0.9060756 0 0.0000000 1.122273 0 0.0000000
#> [5,] -0.6803478 0.0000000 0 2.1584386 0.000000 0 -0.2874329
#> [6,] 1.0631660 0.0000000 0 0.4282466 0.000000 0 -0.4353083
#> x3_N6 x3_N7 x4_N2 x4_N6 x4_N7 x5_N2 x5_N6 x5_N7
#> [1,] 1.1639675 0 0.00000000 -0.3121347 0 0.0000000 -0.8658204 0
#> [2,] 0.0000000 0 -0.69877530 0.0000000 0 0.8254939 0.0000000 0
#> [3,] -2.5183351 0 0.00000000 -2.6438498 0 0.0000000 -0.8001323 0
#> [4,] -0.7075292 0 0.00000000 0.8250224 0 0.0000000 0.9758301 0
#> [5,] 0.0000000 0 0.30692631 0.0000000 0 2.7000755 0.0000000 0
#> [6,] 0.0000000 0 -0.05803946 0.0000000 0 -0.1353896 0.0000000 0
Note: You do not need to create these matrices manually; this step is handled internally by the fusedTree() function. However, visualizing them can help you better understand how the model applies fusion across leaf-specific regressions.
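To make the layout concrete, the following base R sketch builds both matrices by hand for a toy example. The leaf labels (leaf_toy) and omics values (X_toy) are made up for illustration, and the code only mimics the structure of the Dat_Tree() output shown above; it is not the package's implementation.

```r
# Hypothetical leaf membership for six observations (not taken from the fitted tree)
leaves   <- c("N2", "N6", "N7")
leaf_toy <- c("N6", "N2", "N6", "N6", "N2", "N7")

# Toy omics matrix with two variables
X_toy <- matrix(1:12, nrow = 6, dimnames = list(NULL, c("x1", "x2")))

# Clinical design matrix: one 0/1 indicator column per leaf node
Clinical_toy <- sapply(leaves, function(l) as.numeric(leaf_toy == l))

# Omics design matrix: every omics variable multiplied by every leaf indicator,
# so a value is kept only in the column of the leaf the observation falls into
Omics_toy <- do.call(cbind, lapply(colnames(X_toy), function(v) {
  block <- Clinical_toy * X_toy[, v]
  colnames(block) <- paste(v, leaves, sep = "_")
  block
}))

dim(Clinical_toy) # 6 x 3
dim(Omics_toy)    # 6 x 6, i.e. N x (p * number of leaf nodes)
```

Each row of Omics_toy has nonzero entries only in the columns belonging to that observation's leaf, which is what allows a separate omics coefficient per leaf node.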
Create balanced cross‑validation folds across the leaf nodes. Folds are balanced with respect to the proportion of observations in each leaf node and, for binary and survival data, also with respect to the outcome.
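The idea of balancing folds over leaf nodes can be sketched in base R as follows. This is a minimal illustration with made-up leaf labels (leaf_toy) and a hypothetical fold count K; it is not how CVfoldsTree() is implemented and it ignores the additional balancing on the outcome.

```r
set.seed(1)

# Hypothetical leaf labels for 27 observations (12, 9, and 6 per leaf)
leaf_toy <- rep(c("N2", "N6", "N7"), times = c(12, 9, 6))
K <- 3 # number of folds

fold_id <- integer(length(leaf_toy))
for (l in unique(leaf_toy)) {
  idx <- which(leaf_toy == l)
  # Shuffle the observations within the leaf, then deal out fold labels 1..K in turn
  shuffled <- idx[sample.int(length(idx))]
  fold_id[shuffled] <- rep_len(seq_len(K), length(idx))
}

# Each fold contains the same share of every leaf node
table(leaf_toy, fold_id)
```

Assigning folds within each leaf separately guarantees that every training set still contains observations from all leaf nodes, so each leaf-specific regression can be fitted in every fold.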
set.seed(11)
folds <- CVfoldsTree(Y = Y, Tree = Treefit, Z = Z, model = "linear")
optPenalties <- PenOpt(
  Tree = Treefit,
  X = X,
  Y = Y,
  Z = Z,
  model = "linear",
  lambdaInit = 10,
  alphaInit = 10,
  loss = "loglik",
  LinVars = FALSE,
  folds = folds,
  multistart = FALSE # TRUE yields more stable but slower results
)
#> Tuning fusedTree with fusion penalty
optPenalties
#> lambda alpha
#> 1.490862e-13 3.843843e+12
As expected, the fusion penalty alpha is tuned to a very large value. The standard ridge penalty lambda is very small because of the low-dimensional simulation setting.
fit <- fusedTree(
  Tree = Treefit,
  X = X,
  Y = Y,
  Z = Z,
  LinVars = FALSE,
  model = "linear",
  lambda = optPenalties[1],
  alpha = optPenalties[2]
)
#> Fit fusedTree with fusion penalty
# View results
fit$Effects # Omics effects per leaf
#> N2 N6 N7 x1_N2 x1_N6 x1_N7 x2_N2
#> 9.1434885 11.5976204 18.2735435 0.6656824 0.6656824 0.6656824 -0.9519646
#> x2_N6 x2_N7 x3_N2 x3_N6 x3_N7 x4_N2 x4_N6
#> -0.9519646 -0.9519646 3.1750430 3.1750430 3.1750430 1.7737451 1.7737451
#> x4_N7 x5_N2 x5_N6 x5_N7
#> 1.7737451 -1.9979752 -1.9979752 -1.9979752
plot(fit$Tree) # Underlying tree structure
fit$Pars # Model parameters
#> Model LinVar Alpha Lambda
#> alpha linear FALSE 3.843843e+12 1.490862e-13
Because of the strong fusion penalty, the estimated omics effects across leaf nodes are (nearly) identical. However, some bias remains in the omics effect estimates due to the tree’s limited ability to capture the nonlinear structure in the clinical variables. Since the leaf-node-specific intercepts (representing the clinical contribution) and the omics effects are estimated jointly, bias in the intercepts propagates into the omics coefficients.
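The effect of a large alpha can be reproduced in a small, self-contained sketch. It assumes a quadratic fusion penalty that shrinks the leaf-specific coefficients of an omics variable toward their common mean, added to a standard ridge penalty. This form is an assumption made for illustration; the exact penalty used by fusedTree is described in the preprint. All names (D, Omega, fit_toy) are hypothetical.

```r
set.seed(1)
n <- 100
leaf <- rep(1:2, each = n / 2) # two hypothetical leaf nodes
x <- rnorm(n)                  # a single omics variable
y <- 2 * x + rnorm(n)          # the same true omics effect in both leaves

# Leaf-specific design: column m holds x for observations in leaf m, 0 elsewhere
D <- cbind(x * (leaf == 1), x * (leaf == 2))

# Assumed penalty: lambda * sum(b^2) + alpha * sum((b - mean(b))^2),
# where the second term equals alpha * t(b) %*% Omega %*% b
Omega <- diag(2) - matrix(1 / 2, 2, 2)
fit_toy <- function(lambda, alpha) {
  drop(solve(crossprod(D) + lambda * diag(2) + alpha * Omega, crossprod(D, y)))
}

b_free  <- fit_toy(lambda = 0, alpha = 0)   # separate slope per leaf
b_fused <- fit_toy(lambda = 0, alpha = 1e6) # slopes pulled to a (nearly) common value
rbind(b_free, b_fused)
```

With alpha set to zero the two slopes differ through sampling noise alone; with a very large alpha they effectively collapse to one shared estimate, which is the behavior seen in the fitted effects above.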
# Simulate test set
N_test <- 50
Z_test <- as.data.frame(matrix(runif(N_test * p_Clin), nrow = N_test))
X_test <- matrix(rnorm(N_test * p), nrow = N_test)
Y_test <- as.vector(g(Z_test) + X_test %*% betas + rnorm(N_test))
# Generate predictions
Preds <- predict(fit, newX = X_test, newY = Y_test, newZ = Z_test)
PMSE <- mean((Preds$Resp - Preds$Ypred)^2)
PMSE
#> [1] 15.03962
Below is a short example showing how to use fusedTree for binary outcomes. We simulate a binary response using a logistic model, with omics effects shared across leaf nodes.
# Load package
library(fusedTree)
if (!requireNamespace("rpart", quietly = TRUE)) install.packages("rpart")
# Settings
set.seed(13)
N <- 300
p <- 5
p_Clin <- 5
# Simulate data
Z <- as.data.frame(matrix(runif(N * p_Clin), nrow = N)) # clinical variables
X <- matrix(rnorm(N * p), nrow = N) # omics variables
betas <- c(1, -1, 3, 2, -2)
eta <- 15 * sin(pi * Z[,1] * Z[,2]) - 10 * (Z[,3] - 0.5)^2 -
  2 * exp(Z[,4]) - 2 * Z[,5] + X %*% betas
prob <- 1 / (1 + exp(-eta))
Y <- rbinom(N, size = 1, prob = prob)
# Fit tree using only clinical variables
dat <- data.frame(Y = Y, Z)
rp <- rpart::rpart(Y ~ ., data = dat,
  control = rpart::rpart.control(xval = 10, minbucket = 10),
  method = "class", model = TRUE)
cp <- rp$cptable[,1][which.min(rp$cptable[,4])]
Treefit <- rpart::prune(rp, cp = cp)
plot(Treefit)
We then tune the penalties and fit the fusedTree model:
# Create folds
set.seed(30)
folds <- CVfoldsTree(Y = Y, Tree = Treefit, Z = Z,
  model = "logistic", nrepeat = 1)
# Tune hyperparameters
optPenalties <- PenOpt(Tree = Treefit, X = X, Y = Y, Z = Z,
  model = "logistic",
  lambdaInit = 10, alphaInit = 10,
  loss = "loglik",
  LinVars = FALSE,
  folds = folds,
  multistart = TRUE) # slower
#> Tuning fusedTree with fusion penalty
optPenalties
#> lambda alpha
#> 0.2211904 141.7153124
# Fit fusedTree
fit_bin <- fusedTree(Tree = Treefit, X = X, Y = Y, Z = Z,
  LinVars = FALSE, model = "logistic",
  lambda = optPenalties[1],
  alpha = optPenalties[2],
  verbose = TRUE) # prints progress of IRLS algorithm
#> Fit fusedTree with fusion penalty
#> Iteration 1 log likelihood equals: -101.096
#> Iteration 2 log likelihood equals: -85.313
#> Iteration 3 log likelihood equals: -81.537
#> Iteration 4 log likelihood equals: -81.180
#> Iteration 5 log likelihood equals: -81.176
#> Iteration 6 log likelihood equals: -81.176
#> Iteration 7 log likelihood equals: -81.176
#> IRLS converged at iteration 7
fit_bin$Effects
#> N2 N6 N7 x1_N2 x1_N6 x1_N7 x2_N2
#> -1.5850973 -2.1764195 3.4338211 0.6738147 0.6737848 0.6949362 -0.3938863
#> x2_N6 x2_N7 x3_N2 x3_N6 x3_N7 x4_N2 x4_N6
#> -0.3700743 -0.3992398 1.2956706 1.2868135 1.2815384 1.2211976 1.2232870
#> x4_N7 x5_N2 x5_N6 x5_N7
#> 1.2347340 -1.2410471 -1.2507165 -1.2090639
Finally, we simulate test data and evaluate classification performance:
# Simulate test data
N_test <- 50
Z_test <- as.data.frame(matrix(runif(N_test * p_Clin), nrow = N_test))
X_test <- matrix(rnorm(N_test * p), nrow = N_test)
eta_test <- 15 * sin(pi * Z_test[,1] * Z_test[,2]) - 10 * (Z_test[,3] - 0.5)^2 -
  2 * exp(Z_test[,4]) - 2 * Z_test[,5] + X_test %*% betas
prob_test <- 1 / (1 + exp(-eta_test))
Y_test <- rbinom(N_test, size = 1, prob = prob_test)
# Predict
Preds <- predict(fit_bin, newX = X_test, newY = Y_test, newZ = Z_test)
# AUC
if (!requireNamespace("pROC", quietly = TRUE)) install.packages("pROC")
library(pROC)
#> Type 'citation("pROC")' for a citation.
#>
#> Attaching package: 'pROC'
#> The following objects are masked from 'package:stats':
#>
#> cov, smooth, var
auc_result <- pROC::auc(Y_test, Preds$Ypred)
#> Setting levels: control = 0, case = 1
#> Setting direction: controls < cases
auc_result
#> Area under the curve: 0.9328
This example demonstrates how to apply fusedTree to binary classification problems using logistic regression and prediction based on the estimated fused model.
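For readers who want to see what the AUC measures, it can also be computed from ranks in a few lines of base R. The helper auc_rank() and the toy labels and scores below are hypothetical and only illustrate the Mann-Whitney formulation; the example above relies on pROC.

```r
# Rank-based (Mann-Whitney) AUC; assumes y is coded 0/1
auc_rank <- function(y, score) {
  n1 <- sum(y == 1)
  n0 <- sum(y == 0)
  r <- rank(score) # average ranks handle ties
  (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Made-up labels and predicted probabilities
y_toy <- c(0, 0, 1, 1)
score_toy <- c(0.1, 0.4, 0.35, 0.8)
auc_rank(y_toy, score_toy)
#> [1] 0.75
```

The value 0.75 means that three of the four case-control pairs are ordered correctly by the score.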
fusedTree provides:
See the paper for applications to survival outcomes, and further methodological details.