| Title: | Pathwise Estimation of Covariate Balancing Propensity Scores |
| Version: | 0.0.1 |
| Description: | Provides pathwise estimation of regularized logistic propensity score models using covariate balancing loss functions rather than maximum likelihood. Regularization paths are fit via the 'adelie' elastic-net solver with a 'glmnet'-like interface and objectives that directly target covariate balance for the ATE and ATT. For details, see Sverdrup & Hastie (2026) <doi:10.48550/arXiv.2602.18577>. |
| License: | MIT + file LICENSE |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.3 |
| LinkingTo: | Rcpp, RcppEigen |
| SystemRequirements: | C++17 |
| Imports: | Rcpp, Matrix, methods |
| Suggests: | testthat (≥ 3.0.0), knitr, rmarkdown |
| URL: | https://github.com/erikcs/balnet |
| BugReports: | https://github.com/erikcs/balnet/issues |
| VignetteBuilder: | knitr |
| NeedsCompilation: | yes |
| Packaged: | 2026-03-31 02:29:15 UTC; erikcs |
| Author: | Erik Sverdrup [aut, cre], Trevor Hastie [aut], James Yang [ctb] (adelie core author) |
| Maintainer: | Erik Sverdrup <erik.sverdrup@monash.edu> |
| Repository: | CRAN |
| Date/Publication: | 2026-04-03 09:00:16 UTC |
Pathwise estimation of covariate balancing propensity scores.
Description
Fits regularized logistic regression models using covariate balancing loss functions, targeting the ATE, ATT, or treated/control means.
Usage
balnet(
X,
W,
target = c("ATE", "ATT", "treated", "control"),
sample.weights = NULL,
max.imbalance = NULL,
nlambda = 100L,
lambda.min.ratio = 0.01,
lambda = NULL,
penalty.factor = NULL,
groups = NULL,
alpha = 1,
standardize = TRUE,
tol = 1e-07,
maxit = as.integer(1e+05),
verbose = FALSE,
num.threads = 1L,
...
)
Arguments
X |
A numeric matrix or data frame with pre-treatment covariates. |
W |
Treatment vector (0 = control, 1 = treated). |
target |
The target estimand. Default is "ATE". |
sample.weights |
Optional sample weights. If |
max.imbalance |
Optional upper bound on the standardized covariate imbalance. For lasso penalization
( |
nlambda |
Number of values for |
lambda.min.ratio |
Ratio of smallest to largest lambda. Default is 1e-2. |
lambda |
Optional |
penalty.factor |
Penalty factor per feature. Default is 1 (i.e., each feature receives the same penalty). |
groups |
Optional list of group indices for group penalization. |
alpha |
Elastic net mixing parameter. Default is 1 (lasso); 0 corresponds to ridge. |
standardize |
Whether to standardize the input matrix. Should only be |
tol |
Coordinate descent convergence tolerance. Default is 1e-7. |
maxit |
Maximum number of coordinate descent iterations. Default is 1e5. |
verbose |
Whether to display information during fitting. Default is |
num.threads |
Number of threads to use. Default is 1. |
... |
Additional internal arguments passed to the solver. |
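The interplay between nlambda and lambda.min.ratio can be illustrated with a short base R sketch. It assumes a glmnet-style, log-spaced decreasing grid, which this documentation does not confirm for balnet. The value lambda.max is a hypothetical placeholder, since the largest lambda is normally derived from the data.

```r
# Hypothetical largest penalty; in practice this is data-dependent.
lambda.max <- 0.5
nlambda <- 100L
lambda.min.ratio <- 0.01

# Log-spaced, decreasing grid from lambda.max down to
# lambda.max * lambda.min.ratio (assumed glmnet-style spacing).
lambda.path <- exp(seq(log(lambda.max),
                       log(lambda.max * lambda.min.ratio),
                       length.out = nlambda))
range(lambda.path)
```

Under this assumption, a smaller lambda.min.ratio extends the path toward less regularized fits, while nlambda only controls how finely the same range is sampled.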
Details
This function aims to find balancing weights \hat\gamma_i, using logistic propensity scores,
that balance covariate means to a target vector, i.e.,
\frac{1}{n} \sum_{i=1}^n \hat\gamma_i X_i = \bar X_{\mathrm{target}}.
With lasso regularization (alpha = 1), imbalance is controlled in the \ell_\infty sense,
allowing absolute slack of at most \lambda per covariate.
For target = "ATE", two logistic models are fit, one per arm, with
\hat\gamma_i^{(1)} = \frac{W_i}{\hat e^{(1)}(X_i)}, \quad
\hat\gamma_i^{(0)} = \frac{1 - W_i}{1 - \hat e^{(0)}(X_i)}, \quad
\bar X_{\mathrm{target}} = \frac{1}{n} \sum_{i=1}^n X_i.
Here, \hat e^{(w)}(X_i) is the fitted propensity score for arm w.
For target = "ATT", the control arm is reweighted so that its covariate means match those of the treated units:
\hat\gamma_i = (1 - W_i) \frac{\hat e^{(0)}(X_i)}{1 - \hat e^{(0)}(X_i)}, \quad
\bar X_{\mathrm{target}} = \frac{1}{\sum W_i} \sum_{i=1}^n W_i X_i.
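To make the balance condition concrete, the following base R sketch fits an unpenalized logistic model for the treated arm by minimizing a calibration-type loss, W exp(-eta) + (1 - W) eta, and checks that the resulting weights W / e reproduce the full-sample covariate means. The specific loss is an assumption chosen for illustration; the source only states that covariate balancing losses are used, and this sketch does not call balnet or apply any penalty.

```r
set.seed(1)
n <- 500
p <- 3
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, plogis(0.5 * X[, 1]))
Xi <- cbind(1, X)  # design matrix with intercept

# Assumed calibration-type loss for the treated arm (illustrative only).
loss <- function(theta) {
  eta <- drop(Xi %*% theta)
  mean(W * exp(-eta) + (1 - W) * eta)
}
# Its gradient; setting it to zero yields the balance equations.
grad <- function(theta) {
  eta <- drop(Xi %*% theta)
  colMeans((-W * exp(-eta) + (1 - W)) * Xi)
}

opt <- optim(rep(0, p + 1), loss, grad, method = "BFGS",
             control = list(reltol = 1e-12, maxit = 1000))

# Fitted propensity scores and treated-arm balancing weights.
e.hat <- plogis(drop(Xi %*% opt$par))
gamma1 <- W / e.hat

# Weighted treated means minus the full-sample target means.
imbalance <- colMeans(gamma1 * X) - colMeans(X)
max.abs.imbalance <- max(abs(imbalance))
max.abs.imbalance
```

Without a penalty, the first-order conditions of this loss force the imbalance to zero up to optimizer tolerance. With a lasso penalty, the same conditions would instead hold only up to a slack governed by lambda, which is the behavior described above.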
Value
A fitted balnet object.
References
Sverdrup, Erik and Trevor Hastie. "balnet: Pathwise Estimation of Covariate Balancing Propensity Scores". arXiv preprint, arXiv:2602.18577, 2026.
Examples
# Simulate data with confounding.
n <- 2000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1.5 + exp(X[, 2] + X[, 3])))
Y <- W + 2 * log(1 + exp(X[, 1] + X[, 2] + X[, 3])) + rnorm(n)
# Fit model targeting the ATE = E[Y(1)] - E[Y(0)].
# Two logistic models are fit: one for treated, one for control.
fit <- balnet(X, W, target = "ATE")
# Print path summary.
print(fit)
# Visualize the path.
plot(fit)
# Plot the standardized covariate imbalance at given lambda.
# Note: lambda = 0 selects the final lambda in the sequence. Scalar values
# are applied to both arms.
plot(fit, lambda = 0)
# Predict propensity scores at end of lambda path.
W.hat <- predict(fit, X, lambda = 0)
# Get balancing weights at end of lambda path.
ipw.weights <- balweights(fit, lambda = 0)
# Estimate ATE using balancing weights.
mean(Y * (ipw.weights$treated - ipw.weights$control))
Low-level fit function for the adelie CBPS solver.
Description
Low-level fit function for the adelie CBPS solver.
Usage
balnet.fit(
stan,
y,
weights = NULL,
target_scale = 1,
lambda = NULL,
lmda_path_size = 100L,
min_ratio = 0.01,
penalty = NULL,
groups = NULL,
alpha = 1,
irls_max_iters = as.integer(10000),
irls_tol = 1e-07,
max_iters = as.integer(1e+05),
tol = 1e-07,
newton_max_iters = 1000L,
newton_tol = 1e-12,
screen_rule = c("pivot", "strong"),
max_screen_size = NULL,
max_active_size = NULL,
pivot_subset_ratio = 0.1,
pivot_subset_min = 1L,
pivot_slack_ratio = 1.25,
progress_bar = FALSE,
progress_bar_prefix = "",
n_threads = 1L
)
Arguments
stan |
List containing the standardized feature matrix along with mean and scales. |
y |
The 0/1 outcome. |
weights |
Sample weights. |
target_scale |
Gradient scaling for glm. |
lambda |
Optional |
lmda_path_size |
Number of values for |
min_ratio |
Ratio between smallest and largest value of lambda. Default is 1e-2. |
penalty |
Penalty factor per feature. Default is 1. |
groups |
List of group indices. By default, each variable is its own group. |
alpha |
Elastic net mixing parameter. Default is 1 (lasso). 0 is ridge. |
irls_max_iters |
Maximum number of IRLS iterations, default is 1e4. |
irls_tol |
IRLS convergence tolerance, default is 1e-7. |
max_iters |
Maximum total number of coordinate descent iterations at each BASIL step, default is 1e5. |
tol |
Coordinate descent convergence tolerance, default 1e-7. |
newton_max_iters |
Maximum number of iterations for the BCD update, default 1000. |
newton_tol |
Convergence tolerance for the BCD update, default 1e-12. |
screen_rule |
Screen rule, with default |
max_screen_size |
Maximum number of screen groups. |
max_active_size |
Maximum number of active groups. |
pivot_subset_ratio |
Subset ratio of pivot rule. |
pivot_subset_min |
Minimum subset of pivot rule. |
pivot_slack_ratio |
Slack ratio of pivot rule. |
progress_bar |
Progress bar. Default is |
progress_bar_prefix |
Progress bar prefix. Default is none. |
n_threads |
Number of threads, default 1. |
Value
A balnet.fit object.
Extract balancing weights from a balnet object.
Description
Retrieves the estimated balancing weights \hat{\gamma}.
Under unconfoundedness, these correspond to inverse probability weights (IPW)
for standard treatment effect estimands.
Usage
balweights(object, lambda = NULL, ...)
## S3 method for class 'balnet'
balweights(object, lambda = NULL, ...)
## S3 method for class 'cv.balnet'
balweights(object, lambda = "lambda.min", ...)
Arguments
object |
A |
lambda |
Value(s) of the penalty parameter
|
... |
Additional arguments (currently ignored). |
Value
Estimated balancing weights. For contrast fits (target = "ATE" or "ATT"), a list with entries for each arm is returned.
Examples
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Extract balancing weights.
wts <- balweights(fit, lambda = 0)
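As a point of comparison, ATT-type odds weights can be formed by hand from any propensity estimate. The sketch below swaps in a maximum likelihood logistic regression via glm(), which is not the balancing loss used by balnet, so the reweighted control means only approximately match the treated means. All object names are illustrative.

```r
set.seed(42)
n <- 1000
X <- matrix(rnorm(n * 2), n, 2)
W <- rbinom(n, 1, plogis(X[, 1] - 0.5))

# Swapped-in estimator: maximum likelihood logistic regression.
e.hat <- fitted(glm(W ~ X, family = binomial()))

# ATT-type odds weights: nonzero only for control units.
gamma.att <- (1 - W) * e.hat / (1 - e.hat)

# Reweighted control means versus the treated means they should match.
X.control.weighted <- colSums(gamma.att * X) / sum(gamma.att)
X.treated <- colMeans(X[W == 1, , drop = FALSE])
rbind(X.control.weighted, X.treated)
```

The gap between the two rows is the residual imbalance that a balancing loss is designed to control directly.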
Extract coefficients from a balnet object.
Description
Extract coefficients from a balnet object.
Usage
## S3 method for class 'balnet'
coef(object, lambda = NULL, ...)
Arguments
object |
A |
lambda |
Value(s) of the penalty parameter
|
... |
Additional arguments (currently ignored). |
Value
Estimated logistic coefficients (for dual-arm fits, returns a list with entries for each arm).
Examples
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Extract coefficients.
coefs <- coef(fit)
Extract coefficients from a balnet.fit object.
Description
Extract coefficients from a balnet.fit object.
Usage
## S3 method for class 'balnet.fit'
coef(object, lambda = NULL, ...)
Arguments
object |
A |
lambda |
Value(s) for the penalty parameter. If NULL (default), the
lambda path on which |
... |
Additional arguments (currently ignored). |
Value
Coefficients.
Extract coefficients from a cv.balnet object.
Description
Extract coefficients from a cv.balnet object.
Usage
## S3 method for class 'cv.balnet'
coef(object, lambda = "lambda.min", ...)
Arguments
object |
A |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
... |
Additional arguments (currently ignored). |
Value
Estimated logistic coefficients (for dual-arm fits, returns a list with entries for each arm).
Examples
n <- 100
p <- 15
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Extract coefficients at cross-validated lambda.
coefs <- coef(cv.fit)
Compute weighted column means and standard deviations.
Description
For a column x of X and a weight column w, this function computes
\mu = \frac{\sum_{i=1}^{n} w_i x_i}{\sum_{i=1}^{n} w_i} and
\sigma^2 = \frac{\sum_{i=1}^{n} w_i (x_i - \mu)^2}{\sum_{i=1}^{n} w_i}.
Usage
col_stats(X, weights = NULL, compute_sd = FALSE, n_threads = 1L)
Arguments
X |
A |
weights |
A |
compute_sd |
Whether to return the standard deviation. |
Value
L * p matrices of column stats.
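The two formulas in the description can be reproduced in a few lines of base R. This is a plain reference sketch for a single weight vector, not the package's compiled implementation, and it does not cover the case of several weight columns.

```r
set.seed(1)
X <- matrix(rnorm(20 * 3), 20, 3)
w <- runif(20)

# Weighted column means: mu_j = sum_i w_i x_ij / sum_i w_i.
mu <- colSums(w * X) / sum(w)

# Weighted variances around mu, then standard deviations.
Xc <- sweep(X, 2, mu)
sigma <- sqrt(colSums(w * Xc^2) / sum(w))

rbind(mean = mu, sd = sigma)
```

Note that the variance divides by the sum of the weights, so with unit weights it corresponds to the population (1/n) variance rather than the 1/(n - 1) version returned by sd().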
Cross-validation for balnet.
Description
Cross-validation for balnet.
Usage
cv.balnet(
X,
W,
type.measure = c("balance.loss"),
nfolds = 10,
foldid = NULL,
...
)
Arguments
X |
A numeric matrix or data frame with pre-treatment covariates. |
W |
Treatment vector (0: control, 1: treated). |
type.measure |
The loss to minimize for cross-validation. Default is balance loss. |
nfolds |
The number of folds used for cross-validation, default is 10. |
foldid |
An optional |
... |
Arguments for |
Value
A fitted cv.balnet object.
Examples
n <- 100
p <- 15
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATE model.
cv.fit <- cv.balnet(X, W)
# Print CV summary.
print(cv.fit)
# Plot at cross-validated lambda.
plot(cv.fit)
# Predict at cross-validated lambda.
W.hat <- predict(cv.fit, X)
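A reproducible fold assignment can be prepared ahead of time. The sketch below assumes, by analogy with glmnet, that foldid is a vector of integers between 1 and nfolds with one entry per observation; the argument description above does not confirm this.

```r
set.seed(1)
n <- 100
nfolds <- 10

# Balanced random fold assignment: repeat the fold labels to length n,
# then shuffle them.
foldid <- sample(rep(seq_len(nfolds), length.out = n))
table(foldid)
```

Fixing the folds this way makes cross-validation results comparable across repeated calls with different settings.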
Plot diagnostics for a balnet object.
Description
Shows effective sample size (ESS) and percent bias reduction (PBR; reduction
in mean absolute imbalance) along the regularization path, computed from balancing
weights and normalized to percentages. The right-hand axis maps these values
to the coefficient of variation (CV) of the weights.
Supplying the lambda argument displays the standardized covariate imbalance
(\bar X_{\mathrm{weighted}} - \bar X_{\mathrm{target}}) / \sigma_{\mathrm{target}},
computed using the balancing weights at the specified lambda.
Usage
## S3 method for class 'balnet'
plot(x, lambda = NULL, groups = NULL, max = NULL, ...)
Arguments
x |
A |
lambda |
If NULL (default), diagnostics over the lambda path are shown. Otherwise, covariate balance at the provided lambda value is shown (if target = "ATE", lambda can be a 2-vector, one value for arm 0 and one for arm 1). |
groups |
Optional named list of contiguous covariate index ranges to
aggregate into a single variable before computing covariate imbalance
(e.g., |
max |
The number of covariates to display in covariate balance plot. Defaults to all covariates. |
... |
Additional arguments. |
Value
Invisibly returns the information underlying the plot.
Examples
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Plot the five covariates with the largest unweighted imbalance.
plot(fit, lambda = 0, max = 5)
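The diagnostics named in the description can be computed by hand for any weight vector. The sketch below assumes the Kish formula for the effective sample size, which the source does not specify, and uses hypothetical weights rather than output from balnet. It also shows why ESS and the CV of the weights can share one axis: they are linked by ESS / n = 1 / (1 + CV^2).

```r
set.seed(1)
n <- 200
x <- rnorm(n)
w <- rexp(n)  # hypothetical positive weights

# Kish effective sample size as a percentage of n (assumed definition).
ess.pct <- 100 * sum(w)^2 / (n * sum(w^2))

# Coefficient of variation of the weights, using the population sd.
cv <- sqrt(mean((w - mean(w))^2)) / mean(w)

# The identity ESS / n = 1 / (1 + CV^2) ties the two together.
c(ess.pct = ess.pct, from.cv = 100 / (1 + cv^2))

# Standardized imbalance of x: weighted mean minus target mean, over the
# target sd (here the unweighted sample moments serve as the target).
smd <- (sum(w * x) / sum(w) - mean(x)) / sd(x)
smd
```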
Plot diagnostics for a cv.balnet object.
Description
Plot diagnostics for a cv.balnet object.
Usage
## S3 method for class 'cv.balnet'
plot(x, lambda = "lambda.min", ...)
Arguments
x |
A |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
... |
Additional arguments. |
Value
Invisibly returns the information underlying the plot.
Examples
n <- 100
p <- 15
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Plot at cross-validated lambda.
plot(cv.fit)
Predict using a balnet object.
Description
Predict using a balnet object.
Usage
## S3 method for class 'balnet'
predict(object, newdata, lambda = NULL, type = c("response"), ...)
Arguments
object |
A |
newdata |
A numeric matrix. |
lambda |
Value(s) of the penalty parameter
|
type |
The type of predictions. Default is "response" (propensity scores). |
... |
Additional arguments (currently ignored). |
Value
Estimated predictions (for dual-arm fits, returns a list with entries for each arm).
Examples
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Predict propensity scores.
W.hat <- predict(fit, X)
Predict using a balnet.fit object.
Description
Predict using a balnet.fit object.
Usage
## S3 method for class 'balnet.fit'
predict(object, newdata, lambda = NULL, type = c("response", "link"), ...)
Arguments
object |
A balnet.fit object. |
newdata |
A numeric matrix. |
lambda |
Value(s) for the penalty parameter. If NULL (default), the
lambda path on which |
type |
The type of predictions. |
... |
Additional arguments (currently ignored). |
Value
Predictions.
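The two prediction types are related by the inverse logit, assuming that "link" refers to the linear predictor of the logistic model, as is conventional. The values below are hypothetical and the sketch uses only base R.

```r
# Hypothetical linear predictors (the "link" scale).
eta <- c(-2, 0, 1.5)

# The inverse logit maps them to probabilities (the "response" scale).
e.hat <- plogis(eta)
e.hat

# qlogis() inverts the mapping and recovers the linear predictors.
all.equal(qlogis(e.hat), eta)
```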
Predict using a cv.balnet object.
Description
Predict using a cv.balnet object.
Usage
## S3 method for class 'cv.balnet'
predict(object, newdata, lambda = "lambda.min", type = c("response"), ...)
Arguments
object |
A |
newdata |
A numeric matrix. |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
type |
The type of predictions. Default is "response" (propensity scores). |
... |
Additional arguments (currently ignored). |
Value
Estimated predictions (for dual-arm fits, returns a list with entries for each arm).
Examples
n <- 100
p <- 15
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Predict propensity scores at cross-validated lambda.
W.hat <- predict(cv.fit, X)
Print a balnet object.
Description
Print a balnet object.
Usage
## S3 method for class 'balnet'
print(x, digits = max(3L, getOption("digits") - 3L), max = 3, ...)
Arguments
x |
A |
digits |
Number of digits to print. |
max |
Total number of rows to show from the beginning and end of the path. |
... |
Additional print arguments. |
Value
Invisibly returns the printed information.
Examples
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Print path summary.
print(fit)
Print a cv.balnet object.
Description
Print a cv.balnet object.
Usage
## S3 method for class 'cv.balnet'
print(x, digits = max(3L, getOption("digits") - 3L), ...)
Arguments
x |
A |
digits |
Number of digits to print. |
... |
Additional print arguments. |
Value
Invisibly returns the printed information.
Examples
n <- 100
p <- 15
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Print CV summary.
print(cv.fit)
Quickly center and scale a standard dense R matrix.
Description
Quickly center and scale a standard dense R matrix.
Usage
standardize(
X,
weights = NULL,
standardize = TRUE,
inplace = FALSE,
n_threads = 1L
)
Arguments
X |
A numeric R matrix. |
weights |
Sample weights. |
standardize |
Whether to standardize. |
inplace |
Whether to overwrite X. |
n_threads |
Number of threads used. |
Value
A list containing the standardized information.
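For reference, weighted centering and scaling can be written in base R as follows. This is a slow but transparent sketch. The names of the returned list elements are hypothetical and need not match what standardize() returns.

```r
set.seed(1)
X <- matrix(rnorm(50 * 4, mean = 3, sd = 2), 50, 4)
w <- rep(1 / 50, 50)  # uniform sample weights

# Weighted column means and (population-style) standard deviations.
centers <- colSums(w * X) / sum(w)
Xc <- sweep(X, 2, centers)
scales <- sqrt(colSums(w * Xc^2) / sum(w))

# Centered and scaled matrix, bundled with the statistics used
# (hypothetical element names).
stan <- list(X = sweep(Xc, 2, scales, "/"),
             centers = centers,
             scales = scales)

# Each column now has weighted mean 0 and weighted variance 1.
round(colMeans(stan$X), 10)
round(colMeans(stan$X^2), 10)
```

Keeping the centers and scales alongside the matrix is what allows coefficients fitted on the standardized scale to be mapped back to the original covariates.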