This vignette is a guide to policy_learn() and some of
the associated S3 methods. The purpose of policy_learn() is
to specify a policy learning algorithm and estimate an optimal policy.
For details on the methodology, see the associated paper (Nordland and Holst 2023).
As a general setup, we consider a fixed two-stage problem. We simulate
data using sim_two_stage() and create a
policy_data object using policy_data():
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
action = c("A_1", "A_2"),
baseline = c("B", "BB"),
covariates = list(L = c("L_1", "L_2"),
C = c("C_1", "C_2")),
utility = c("U_1", "U_2", "U_3"))
pd
#> Policy data with n = 2000 observations and maximal K = 2 stages.
#>
#> action
#> stage 0 1 n
#> 1 1017 983 2000
#> 2 819 1181 2000
#>
#> Baseline covariates: B, BB
#> State covariates: L, C
#> Average utility: 0.84

policy_learn() specifies a policy learning algorithm via
the type argument: Q-learning (ql), doubly
robust Q-learning (drql), doubly robust blip learning
(blip), policy tree learning (ptl), and
outcome weighted learning (owl).
Because the control arguments vary between policy learning types,
these are passed as a list via the control argument. To
help the user set the required control arguments and to provide
documentation, each type has a helper function
control_type(), which sets the default control arguments and
overwrites them with any values supplied by the user.
As an example we specify a doubly robust blip learner:
pl_blip <- policy_learn(
type = "blip",
control = control_blip(
blip_models = q_glm(formula = ~ BB + L + C)
)
)

For details on the implementation, see Algorithm 3 in Nordland and Holst (2023). The only required
control argument for blip learning is a model input. The
blip_models argument expects a q_model. In
this case we input a simple linear model as implemented in
q_glm.
The output of policy_learn() is again a function:
pl_blip
#> Policy learner with arguments:
#> policy_data, g_models=NULL, g_functions=NULL,
#> g_full_history=FALSE, q_models, q_full_history=FALSE

In order to apply the policy learner we need to input a
policy_data object and nuisance models
g_models and q_models for computing the doubly
robust score.
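For example (a minimal sketch, mirroring the nuisance models used in the calls below), the blip learner can be applied with stage-specific g_glm() and q_glm() models:

po_blip <- pl_blip(
  pd,
  # illustrative nuisance models; one g-model and one q-model per stage
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)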
As for policy_eval(), it is possible to cross-fit the
doubly robust score used as input to the policy model. The number of
folds for the cross-fitting procedure is provided via the L
argument. By default, the cross-fitted nuisance models are not saved,
but they can be saved via the save_cross_fit_models argument:
pl_blip_cross <- policy_learn(
type = "blip",
control = control_blip(
blip_models = q_glm(formula = ~ BB + L + C)
),
L = 2,
save_cross_fit_models = TRUE
)
po_blip_cross <- pl_blip_cross(
pd,
g_models = list(g_glm(), g_glm()),
q_models = list(q_glm(), q_glm())
)

From a user perspective, nothing has changed. However, the policy object now contains each of the cross-fitted nuisance models:
po_blip_cross$g_functions_cf
#> $`1`
#> $stage_1
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> -0.18321 0.15191 0.90737 -0.03865 0.18927 0.15088
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1384
#> Residual Deviance: 1086 AIC: 1098
#>
#>
#> $stage_2
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.24410 0.13150 0.99426 -0.02289 -0.41777 -0.17383
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1349
#> Residual Deviance: 1082 AIC: 1094
#>
#>
#> attr(,"full_history")
#> [1] FALSE
#>
#> $`2`
#> $stage_1
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.113952 -0.240397 1.142507 -0.094362 -0.009235 -0.101783
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1386
#> Residual Deviance: 1065 AIC: 1077
#>
#>
#> $stage_2
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.15426 0.01307 0.96485 -0.08554 -0.33532 -0.12597
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1357
#> Residual Deviance: 1102 AIC: 1114
#>
#>
#> attr(,"full_history")
#> [1] FALSE

Realistic policy learning is implemented for types ql,
drql, blip and ptl (for a binary
action set). The alpha argument sets the probability
threshold for defining the realistic action set. For implementation
details, see Algorithm 5 in Nordland and Holst (2023). Here we set a 5% restriction:
pl_blip_alpha <- policy_learn(
type = "blip",
control = control_blip(
blip_models = q_glm(formula = ~ BB + L + C)
),
alpha = 0.05,
L = 2
)
po_blip_alpha <- pl_blip_alpha(
pd,
g_models = list(g_glm(), g_glm()),
q_models = list(q_glm(), q_glm())
)

The policy object now lists the alpha level as well as
the g-model used to define the realistic action set:
po_blip_alpha$g_functions
#> $stage_1
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> -0.03295 -0.05107 1.02271 -0.06478 0.09582 0.02370
#>
#> Degrees of Freedom: 1999 Total (i.e. Null); 1994 Residual
#> Null Deviance: 2772
#> Residual Deviance: 2161 AIC: 2173
#>
#>
#> $stage_2
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.19814 0.07355 0.97991 -0.05280 -0.37163 -0.14598
#>
#> Degrees of Freedom: 1999 Total (i.e. Null); 1994 Residual
#> Null Deviance: 2707
#> Residual Deviance: 2186 AIC: 2198
#>
#>
#> attr(,"full_history")
#> [1] FALSE

get_policy_functions()

A policy function is great for evaluating a given policy
or even implementing or simulating from a single-stage policy. However,
the function is not useful for implementing or simulating from a learned
multi-stage policy. To access the policy function for each stage we use
get_policy_functions(). In this case we get the second
stage policy function:
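A minimal sketch, assuming get_policy_functions() accepts the fitted policy object and a stage index (the name pf_2 is our own):

# extract the learned stage-2 policy function from the policy object
pf_2 <- get_policy_functions(po_blip, stage = 2)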
The stage-specific policy function requires a data.table with
named columns as input and returns a character vector with the
recommended actions:
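As a sketch, assuming the required column names match the covariates used by the stage-2 blip model (BB, L and C) and that the new data are passed via an argument H:

# hypothetical new stage-2 histories; the column names are assumed to match
# the covariates required by the stage-2 policy function
H_new <- data.table(
  BB = c("group1", "group2"),
  L = c(0, 1),
  C = c(1, 2)
)
pf_2(H = H_new) # character vector of recommended actions, e.g. "0" or "1"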
get_policy()

Applying the policy learner returns a policy_object
containing all of the components needed to specify the learned policy.
In this case, the only component of the policy is a model for the blip
function:
po_blip$blip_functions$stage_1$blip_model
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) BBgroup2 BBgroup3 L C
#> 0.4076 0.2585 0.2231 0.1765 0.8624
#>
#> Degrees of Freedom: 1999 Total (i.e. Null); 1995 Residual
#> Null Deviance: 56820
#> Residual Deviance: 53220 AIC: 12250
#>
#> attr(,"class")
#> [1] "q_glm"To access and apply the policy itself use get_policy(),
which behaves as a policy meaning that we can apply to any
(suitable) policy_data object to get the policy
actions:
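For example (a sketch, assuming the returned policy can be applied directly to a policy_data object):

# get_policy() returns the learned policy as a function of a policy_data object
head(get_policy(po_blip)(pd))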
sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: aarch64-apple-darwin23.5.0
#> Running under: macOS Sonoma 14.6.1
#>
#> Matrix products: default
#> BLAS: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib
#> LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib; LAPACK version 3.12.0
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> time zone: Europe/Copenhagen
#> tzcode source: internal
#>
#> attached base packages:
#> [1] splines stats graphics grDevices utils datasets methods
#> [8] base
#>
#> other attached packages:
#> [1] ggplot2_3.5.1 data.table_1.15.4 polle_1.5
#> [4] SuperLearner_2.0-29 gam_1.22-4 foreach_1.5.2
#> [7] nnls_1.5
#>
#> loaded via a namespace (and not attached):
#> [1] sass_0.4.9 utf8_1.2.4 future_1.33.2
#> [4] lattice_0.22-6 listenv_0.9.1 digest_0.6.36
#> [7] magrittr_2.0.3 evaluate_0.24.0 grid_4.4.1
#> [10] iterators_1.0.14 mvtnorm_1.2-5 policytree_1.2.3
#> [13] fastmap_1.2.0 jsonlite_1.8.8 Matrix_1.7-0
#> [16] survival_3.6-4 fansi_1.0.6 scales_1.3.0
#> [19] numDeriv_2016.8-1.1 codetools_0.2-20 jquerylib_0.1.4
#> [22] lava_1.8.0 cli_3.6.3 rlang_1.1.4
#> [25] mets_1.3.4 parallelly_1.37.1 future.apply_1.11.2
#> [28] munsell_0.5.1 withr_3.0.0 cachem_1.1.0
#> [31] yaml_2.3.8 tools_4.4.1 parallel_4.4.1
#> [34] colorspace_2.1-0 ranger_0.16.0 globals_0.16.3
#> [37] vctrs_0.6.5 R6_2.5.1 lifecycle_1.0.4
#> [40] pkgconfig_2.0.3 timereg_2.0.5 progressr_0.14.0
#> [43] bslib_0.7.0 pillar_1.9.0 gtable_0.3.5
#> [46] Rcpp_1.0.13 glue_1.7.0 xfun_0.45
#> [49] tibble_3.2.1 highr_0.11 knitr_1.47
#> [52] farver_2.1.2 htmltools_0.5.8.1 rmarkdown_2.27
#> [55] labeling_0.4.3 compiler_4.4.1