Title: | Deep Neural Networks for Survival Analysis Using 'torch' |
Version: | 0.6.0 |
Description: | Provides deep learning models for right-censored survival data using the 'torch' backend. Supports multiple loss functions, including Cox partial likelihood, L2-penalized Cox, time-dependent Cox, and accelerated failure time (AFT) loss. Offers a formula-based interface, built-in support for cross-validation, hyperparameter tuning, survival curve plotting, and evaluation metrics such as the C-index, Brier score, and integrated Brier score. For methodological details, see Kvamme et al. (2019) https://www.jmlr.org/papers/v20/18-424.html. |
License: | MIT + file LICENSE |
Encoding: | UTF-8 |
Depends: | R (≥ 4.1.0) |
Imports: | torch, survival, stats, utils, tibble, dplyr, purrr, tidyr, ggplot2, methods, rsample, cli, glue |
Suggests: | testthat (≥ 3.0.0), knitr, rmarkdown |
RoxygenNote: | 7.3.2 |
Config/testthat/edition: | 3 |
BugReports: | https://github.com/ielbadisy/survdnn/issues |
URL: | https://github.com/ielbadisy/survdnn |
NeedsCompilation: | no |
Packaged: | 2025-07-21 12:44:39 UTC; imad-el-badisy |
Author: | Imad EL BADISY [aut, cre] |
Maintainer: | Imad EL BADISY <elbadisyimad@gmail.com> |
Repository: | CRAN |
Date/Publication: | 2025-07-22 11:11:46 UTC |
Brier Score for Right-Censored Survival Data at a Fixed Time
Description
Computes the Brier score at a fixed time point using inverse probability of censoring weights (IPCW).
Usage
brier(object, pre_sp, t_star)
Arguments
object |
A 'Surv' object with observed time and status. |
pre_sp |
A numeric vector of predicted survival probabilities at 't_star'. |
t_star |
The evaluation time point. |
Value
A single numeric value representing the Brier score.
Examples
library(survival)
data(veteran, package = "survival")
mod <- survdnn(Surv(time, status) ~
age + karno + celltype, data = veteran, epochs = 50, verbose = FALSE)
pred <- predict(mod, newdata = veteran, type = "survival", times = c(30, 90, 180))
y <- model.response(model.frame(mod$formula, veteran))
survdnn::brier(y, pre_sp = pred[["t=90"]], t_star = 90)
Build a Deep Neural Network for Survival Analysis
Description
Constructs a multilayer perceptron (MLP) with batch normalization, activation functions, and dropout. Used internally by [survdnn()] to define the model architecture.
Usage
build_dnn(input_dim, hidden, activation = "relu", output_dim = 1L)
Arguments
input_dim |
Integer. Number of input features. |
Integer vector. Sizes of the hidden layers (e.g., c(32, 16)). | |
activation |
Character. Name of the activation function to use in each layer. Supported options: '"relu"', '"leaky_relu"', '"tanh"', '"sigmoid"', '"gelu"', '"elu"', '"softplus"'. |
output_dim |
Integer. Output layer dimension (default = 1). |
Value
A 'nn_sequential' object representing the network.
Examples
net <- build_dnn(10, hidden = c(64, 32), activation = "relu")
Concordance Index from a Survival Probability Matrix
Description
Computes the time-dependent concordance index (C-index) from a predicted survival matrix at a fixed time point. The risk is computed as '1 - S(t_star)'.
Usage
cindex_survmat(object, predicted, t_star = NULL)
Arguments
object |
A 'Surv' object representing the observed survival data. |
predicted |
A data frame or matrix of predicted survival probabilities. Each column corresponds to a time point (e.g., 't=90', 't=180'). |
t_star |
A numeric time point corresponding to one of the columns in 'predicted'. If 'NULL', the last column is used. |
Value
A single numeric value representing the C-index.
Examples
library(survival)
data(veteran, package = "survival")
mod <- survdnn(Surv(time, status) ~
age + karno + celltype, data = veteran, epochs = 50, verbose = FALSE)
pred <- predict(mod, newdata = veteran, type = "survival", times = c(30, 90, 180))
y <- model.response(model.frame(mod$formula, veteran))
cindex_survmat(y, pred, t_star = 180)
K-Fold Cross-Validation for survdnn Models
Description
Performs cross-validation for a 'survdnn' model using the specified evaluation metrics.
Usage
cv_survdnn(
formula,
data,
times,
metrics = c("cindex", "ibs"),
folds = 5,
.seed = NULL,
...
)
Arguments
formula |
A survival formula, e.g., 'Surv(time, status) ~ x1 + x2'. |
data |
A data frame. |
times |
A numeric vector of evaluation time points. |
metrics |
A character vector: any of '"cindex"', '"brier"', '"ibs"'. |
folds |
Integer. Number of folds to use. |
.seed |
Optional. Set random seed for reproducibility. |
... |
Additional arguments passed to [survdnn()]. |
Value
A tibble containing metric values per fold and (optionally) per time point.
Examples
library(survival)
data(veteran)
cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(30, 90, 180),
metrics = "ibs",
folds = 3,
.seed = 42,
hidden = c(16, 8),
epochs = 5
)
Evaluate a survdnn Model Using Survival Metrics
Description
Computes evaluation metrics for a fitted 'survdnn' model at one or more time points. Supported metrics include the concordance index ('"cindex"'), Brier score ('"brier"'), and integrated Brier score ('"ibs"').
Usage
evaluate_survdnn(
model,
metrics = c("cindex", "brier", "ibs"),
times,
newdata = NULL
)
Arguments
model |
A fitted 'survdnn' model object. |
metrics |
A character vector of metric names: '"cindex"', '"brier"', '"ibs"'. |
times |
A numeric vector of evaluation time points. |
newdata |
Optional. A data frame on which to evaluate the model. Defaults to training data. |
Value
A tibble with evaluation results, containing at least 'metric', 'value', and possibly 'time'.
Examples
library(survival)
data(veteran)
mod <- survdnn(Surv(time, status) ~ age + karno + celltype,
data = veteran, epochs = 5, verbose = FALSE)
evaluate_survdnn(mod, metrics = c("cindex", "ibs"), times = c(30, 90, 180))
evaluate_survdnn(mod, metrics = "brier", times = c(30, 90, 180))
Grid Search for survdnn Hyperparameters
Description
Performs grid search over user-specified hyperparameters and evaluates performance on a validation set.
Usage
gridsearch_survdnn(
formula,
train,
valid,
times,
metrics = c("cindex", "ibs"),
param_grid,
.seed = 42
)
Arguments
formula |
A survival formula (e.g., 'Surv(time, status) ~ .') |
train |
Training dataset |
valid |
Validation dataset |
times |
Evaluation time points (numeric vector) |
metrics |
Evaluation metrics (character vector): any of "cindex", "ibs", "brier" |
param_grid |
A named list of hyperparameters ('hidden', 'lr', 'activation', 'epochs', 'loss') |
.seed |
Optional random seed for reproducibility |
Value
A tibble with configurations and their validation metrics
Examples
library(survdnn)
library(survival)
set.seed(123)
# Simulate small dataset
n <- 300
x1 <- rnorm(n); x2 <- rbinom(n, 1, 0.5)
time <- rexp(n, rate = 0.1)
status <- rbinom(n, 1, 0.7)
df <- data.frame(time, status, x1, x2)
# Split into training and validation
idx <- sample(seq_len(n), 0.7 * n)
train <- df[idx, ]
valid <- df[-idx, ]
# Define formula and param grid
formula <- Surv(time, status) ~ x1 + x2
param_grid <- list(
hidden = list(c(16, 8), c(32, 16)),
lr = c(1e-3),
activation = c("relu"),
epochs = c(100),
loss = c("cox", "coxtime")
)
# Run grid search
results <- gridsearch_survdnn(
formula = formula,
train = train,
valid = valid,
times = c(10, 20, 30),
metrics = c("cindex", "ibs"),
param_grid = param_grid
)
# View summary
dplyr::group_by(results, hidden, lr, activation, epochs, loss, metric) |>
dplyr::summarise(mean = mean(value, na.rm = TRUE), .groups = "drop")
Integrated Brier Score (IBS) from a Survival Probability Matrix
Description
Computes the Integrated Brier Score (IBS) over a set of evaluation time points, using trapezoidal integration and IPCW adjustment for right-censoring.
Usage
ibs_survmat(object, sp_matrix, times)
Arguments
object |
A 'Surv' object with observed time and status. |
sp_matrix |
A data frame or matrix of predicted survival probabilities. Each column corresponds to a time point in 'times'. |
times |
A numeric vector of time points. Must match the columns of 'sp_matrix'. |
Value
A single numeric value representing the integrated Brier score.
Examples
set.seed(123)
library(survival)
data(veteran, package = "survival")
idx <- sample(nrow(veteran), 0.7 * nrow(veteran))
train <- veteran[idx, ]; test <- veteran[-idx, ]
mod <- survdnn(Surv(time, status) ~
age + karno + celltype, data = train, epochs = 50, verbose = FALSE)
pred <- predict(mod, newdata = test, times = c(30, 90, 180), type = "survival")
y_test <- model.response(model.frame(mod$formula, test))
ibs_survmat(y_test, sp_matrix = pred, times = c(30, 90, 180))
Plot survdnn Survival Curves using ggplot2
Description
Visualizes survival curves predicted by a fitted 'survdnn' model. Curves can be grouped by a categorical variable in 'newdata' and optionally display only the group-wise means or overlay them.
Usage
## S3 method for class 'survdnn'
plot(
x,
newdata = NULL,
times = 1:365,
group_by = NULL,
plot_mean_only = FALSE,
add_mean = TRUE,
alpha = 0.3,
mean_lwd = 1.3,
mean_lty = 1,
...
)
Arguments
x |
A fitted 'survdnn' model object. |
newdata |
Optional data frame for prediction (defaults to training data). |
times |
A numeric vector of time points at which to compute survival probabilities. |
group_by |
Optional name of a column in 'newdata' used to color and group curves. |
plot_mean_only |
Logical; if 'TRUE', plots only the mean survival curve per group. |
add_mean |
Logical; if 'TRUE', adds mean curves to the individual lines. |
alpha |
Alpha transparency for individual curves (ignored if 'plot_mean_only = TRUE'). |
mean_lwd |
Line width for mean survival curves. |
mean_lty |
Line type for mean survival curves. |
... |
Reserved for future use. |
Value
A 'ggplot' object.
Examples
library(survival)
data(veteran)
set.seed(42)
mod <- survdnn(Surv(time, status) ~ age + karno + celltype, data = veteran,
hidden = c(16, 8), epochs = 100, verbose = FALSE)
plot(mod, group_by = "celltype", times = 1:300)
Predict from a survdnn Model
Description
Generate predictions from a fitted 'survdnn' model for new data. Supports linear predictors, survival probabilities at specified time points, or cumulative risk estimates.
Usage
## S3 method for class 'survdnn'
predict(object, newdata, times = NULL, type = c("survival", "lp", "risk"), ...)
Arguments
object |
An object of class '"survdnn"' returned by [survdnn()]. |
newdata |
A data frame of new observations to predict on. |
times |
Numeric vector of time points at which to compute survival or risk probabilities. Required if 'type = "survival"' or 'type = "risk"'. |
type |
Character string specifying the type of prediction to return:
|
... |
Currently ignored (for future extensions). |
Value
A numeric vector (if 'type = "lp"' or '"risk"'), or a data frame (if 'type = "survival"') with one row per observation and one column per 'times'.
Examples
library(survival)
data(veteran, package = "survival")
# Fit survdnn with Cox loss
mod <- survdnn(Surv(time, status) ~ age + karno + celltype, data = veteran,
loss = "cox", epochs = 50, verbose = FALSE)
# Linear predictor (log-risk)
predict(mod, newdata = veteran, type = "lp")[1:5]
# Survival probabilities at selected times
predict(mod, newdata = veteran, type = "survival", times = c(30, 90, 180))[1:5, ]
# Cumulative risk at 180 days
predict(mod, newdata = veteran, type = "risk", times = 180)[1:5]
Print a survdnn Model
Description
Pretty prints a fitted 'survdnn' model. Displays the formula, network architecture, training configuration, and final training loss.
Usage
## S3 method for class 'survdnn'
print(x, ...)
Arguments
x |
An object of class '"survdnn"', returned by [survdnn()]. |
... |
Ignored (for future compatibility). |
Value
The model object, invisibly.
Examples
library(survival)
data(veteran, package = "survival")
mod <- survdnn(Surv(time, status) ~
age + karno + celltype, data = veteran, epochs = 20, verbose = FALSE)
print(mod)
Summarize Cross-Validation Results from survdnn
Description
Computes mean, standard deviation, and confidence intervals for metrics from cross-validation.
Usage
summarize_cv_survdnn(cv_results, by_time = TRUE, conf_level = 0.95)
Arguments
cv_results |
A tibble returned by [cv_survdnn()]. |
by_time |
Logical. Whether to stratify results by 'time' (if present). |
conf_level |
Confidence level for the intervals (default: 0.95). |
Value
A tibble summarizing mean, sd, and confidence bounds per metric (and per time if applicable).
Examples
library(survival)
data(veteran)
res <- cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(30, 90, 180, 270),
metrics = c("cindex", "ibs"),
folds = 3,
.seed = 42,
hidden = c(16, 8),
epochs = 5
)
summarize_cv_survdnn(res)
Summarize survdnn Tuning Results
Description
Aggregates cross-validation results from 'tune_survdnn(return = "all")' by configuration, metric, and optionally by time point.
Usage
summarize_tune_survdnn(tuning_results, by_time = TRUE)
Arguments
tuning_results |
The full tibble returned by 'tune_survdnn(..., return = "all")'. |
by_time |
Logical; whether to group and summarize separately by time points. |
Value
A summarized tibble with mean and standard deviation of performance metrics.
Summarize a Deep Survival Neural Network Model
Description
Provides a structured summary of a fitted 'survdnn' model, including the network architecture, training configuration, and data characteristics. The summary is printed automatically with a styled header and sectioned output using {cli} and base formatting. The object is returned invisibly.
Usage
## S3 method for class 'survdnn'
summary(object, ...)
Arguments
object |
An object of class '"survdnn"' returned by the [survdnn()] function. |
... |
Currently ignored (for future compatibility). |
Value
Invisibly returns an object of class '"summary.survdnn"'.
Examples
set.seed(42)
sim_data <- data.frame(
age = rnorm(100, 60, 10),
sex = factor(sample(c("male", "female"), 100, TRUE)),
trt = factor(sample(c("A", "B"), 100, TRUE)),
time = rexp(100, 0.05),
status = rbinom(100, 1, 0.7)
)
mod <- survdnn(Surv(time, status) ~ age + sex + trt, data = sim_data, epochs = 50, verbose = FALSE)
summary(mod)
Fit a Deep Neural Network for Survival Analysis
Description
Trains a deep neural network (DNN) to model right-censored survival data using one of the predefined loss functions: Cox, AFT, or Coxtime.
Usage
survdnn(
formula,
data,
hidden = c(32L, 16L),
activation = "relu",
lr = 1e-04,
epochs = 300L,
loss = c("cox", "cox_l2", "aft", "coxtime"),
verbose = TRUE
)
Arguments
formula |
A survival formula of the form 'Surv(time, status) ~ predictors'. |
data |
A data frame containing the variables in the model. |
Integer vector. Sizes of the hidden layers (default: c(32, 16)). | |
activation |
Character string specifying the activation function to use in each layer. Supported options: '"relu"', '"leaky_relu"', '"tanh"', '"sigmoid"', '"gelu"', '"elu"', '"softplus"'. |
lr |
Learning rate for the Adam optimizer (default: '1e-4'). |
epochs |
Number of training epochs (default: 300). |
loss |
Character name of the loss function to use. One of '"cox"', '"cox_l2"', '"aft"', or '"coxtime"'. |
verbose |
Logical; whether to print loss progress every 50 epochs (default: TRUE). |
Value
An object of class '"survdnn"' containing:
- model
Trained 'nn_module' object.
- formula
Original survival formula.
- data
Training data used for fitting.
- xnames
Predictor variable names.
- x_center
Column means of predictors.
- x_scale
Column standard deviations of predictors.
- loss_history
Vector of loss values per epoch.
- final_loss
Final training loss.
- loss
Loss function name used ("cox", "aft", etc.).
- activation
Activation function used.
- hidden
Hidden layer sizes.
- lr
Learning rate.
- epochs
Number of training epochs.
Examples
set.seed(123)
df <- data.frame(
time = rexp(100, rate = 0.1),
status = rbinom(100, 1, 0.7),
x1 = rnorm(100),
x2 = rbinom(100, 1, 0.5)
)
mod <- survdnn(Surv(time, status) ~ x1 + x2, data = df, epochs = 5
, loss = "cox", verbose = FALSE)
mod$final_loss
Loss Functions for survdnn Models
Description
These functions define various loss functions used internally by 'survdnn()' for training deep neural networks on right-censored survival data.
Usage
cox_loss(pred, true)
cox_l2_loss(pred, true, lambda = 1e-04)
aft_loss(pred, true)
coxtime_loss(pred, true)
Arguments
pred |
A tensor of predicted values (typically linear predictors or log-times). |
true |
A tensor with two columns: observed time and status (1 = event, 0 = censored). |
lambda |
Regularization parameter for 'cox_l2_loss' (default: '1e-4'). |
Value
A scalar 'torch_tensor' representing the loss value.
Supported Losses
- **Cox partial likelihood loss** ('cox_loss'): Negative partial log-likelihood used in proportional hazards modeling. - **L2-penalized Cox loss** ('cox_l2_loss'): Adds L2 regularization to the Cox loss. - **Accelerated Failure Time (AFT) loss** ('aft_loss'): Mean squared error between predicted and log-transformed event times, applied to uncensored observations only. - **CoxTime loss** ('coxtime_loss'): Implements the partial likelihood loss from Kvamme & Borgan (2019), used in Cox-Time models.
Examples
# Used internally by survdnn()
Tune Hyperparameters for a survdnn Model via Cross-Validation
Description
Performs k-fold cross-validation over a user-defined hyperparameter grid and selects the best configuration according to the specified evaluation metric.
Usage
tune_survdnn(
formula,
data,
times,
metrics = "cindex",
param_grid,
folds = 3,
.seed = 42,
refit = FALSE,
return = c("all", "summary", "best_model")
)
Arguments
formula |
A survival formula, e.g., 'Surv(time, status) ~ x1 + x2'. |
data |
A data frame. |
times |
A numeric vector of evaluation time points. |
metrics |
A character vector of evaluation metrics: "cindex", "brier", or "ibs". Only the first metric is used for model selection. |
param_grid |
A named list defining hyperparameter combinations to evaluate. Required names: 'hidden', 'lr', 'activation', 'epochs', 'loss'. |
folds |
Number of cross-validation folds (default: 3). |
.seed |
Optional seed for reproducibility (default: 42). |
refit |
Logical. If TRUE, refits the best model on the full dataset. |
return |
One of "all", "summary", or "best_model":
|
Value
A tibble or model object depending on the 'return' value.