This vignette is a guide to policy_eval() and some of the associated S3 methods. The purpose of policy_eval() is to estimate (evaluate) the value of a user-defined policy or a policy learning algorithm. For details on the methodology, see the associated paper (Nordland and Holst 2023).

We consider a fixed two-stage problem as a general setup, simulate data using sim_two_stage(), and create a policy_data object using policy_data():
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
## Policy data with n = 2000 observations and maximal K = 2 stages.
##
## action
## stage 0 1 n
## 1 1017 983 2000
## 2 819 1181 2000
##
## Baseline covariates: B, BB
## State covariates: L, C
## Average utility: 0.84
User-defined policies are created using policy_def(). In this case we define a simple static policy always selecting action '1'. As we want to apply the same policy function at both stages, we set reuse = TRUE:
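A minimal sketch of such a static policy definition, mirroring the p0 definition used further down; the object name p1 and the label "(A=1)" are assumptions based on the printed output:

p1 <- policy_def(policy_functions = 1, # always select action '1'
                 reuse = TRUE,         # apply the same policy at both stages
                 name = "(A=1)")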
policy_eval() implements three types of policy evaluation: inverse probability weighting (IPW) estimation, outcome regression (OR) estimation, and doubly robust (DR) estimation. As doubly robust estimation is a combination of the other two types, we focus on this approach. For details on the implementation, see Algorithm 1 in (Nordland and Holst 2023).
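The estimate printed below is the doubly robust value of the static policy p1; a sketch of the call, assuming the result is stored as pe1 (which is merged with pe0 further down):

(pe1 <- policy_eval(policy_data = pd,
                    policy = p1,
                    type = "dr"))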
## Estimate Std.Err 2.5% 97.5% P-value
## E[U(d)]: d=(A=1) 0.8213 0.1115 0.6027 1.04 1.796e-13
policy_eval() returns an object of type policy_eval which prints like a lava::estimate object. The policy value estimate and variance are available via coef() and vcov():
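A sketch of the corresponding calls on the evaluation object:

coef(pe1) # policy value estimate
vcov(pe1) # variance of the estimate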
## [1] 0.8213233
## [,1]
## [1,] 0.01244225
policy_eval objects

The policy_eval object behaves like a lava::estimate object, which can also be accessed directly using estimate(). estimate objects make it easy to work with estimates that have an iid decomposition given by the influence curve/function; see the estimate vignette. The influence curve is available via IC():
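A sketch of the call producing the first rows shown below; the use of head() is an assumption (IC() returns one row per observation):

head(IC(pe1))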
## [,1]
## [1,] 2.5515875
## [2,] -5.6787782
## [3,] 4.9506000
## [4,] 2.0661524
## [5,] 0.7939672
## [6,] -2.2932160
Merging estimate objects allows the user to get inference for transformations of the estimates via the Delta method. Here we get inference for the average treatment effect, both as a difference and as a ratio:
p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
pe0 <- policy_eval(policy_data = pd,
                   policy = p0,
                   type = "dr")
(est <- merge(pe0, pe1))
## Estimate Std.Err 2.5% 97.5% P-value
## E[U(d)]: d=(A=0) -0.06123 0.0881 -0.2339 0.1114 4.871e-01
## ────────────────
## E[U(d)]: d=(A=1) 0.82132 0.1115 0.6027 1.0399 1.796e-13
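The transformed estimates below can be obtained by applying estimate() to the merged object with a transformation function (Delta method); a sketch, where the labels are assumptions matching the printed output:

estimate(est, function(x) x[2] - x[1], labels = "ATE-difference")
estimate(est, function(x) x[2] / x[1], labels = "ATE-ratio")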
## Estimate Std.Err 2.5% 97.5% P-value
## ATE-difference 0.8825 0.1338 0.6203 1.145 4.25e-11
## Estimate Std.Err 2.5% 97.5% P-value
## ATE-ratio -13.41 19.6 -51.83 25 0.4937
So far we have relied on the default generalized linear models for the nuisance g-models and Q-models. By default, a single g-model is trained across all stages using the state/Markov-type history; see the policy_data vignette. Use get_g_functions() to access the fitted model:
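A sketch of the call; applying it to pe0 is an assumption, as any of the doubly robust evaluation objects stores the fitted g-functions:

get_g_functions(pe0)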
## $all_stages
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) L C B BBgroup2 BBgroup3
## 0.08285 0.03094 0.97993 -0.05753 -0.13970 -0.06122
##
## Degrees of Freedom: 3999 Total (i.e. Null); 3994 Residual
## Null Deviance: 5518
## Residual Deviance: 4356 AIC: 4368
##
##
## attr(,"full_history")
## [1] FALSE
The g-functions can be used as input to a new policy evaluation:
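A sketch, assuming policy_eval() accepts pre-fitted g-functions via a g_functions argument:

policy_eval(pd,
            policy = p0,
            g_functions = get_g_functions(pe0))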
## Estimate Std.Err 2.5% 97.5% P-value
## E[U(d)]: d=(A=0) -0.06123 0.0881 -0.2339 0.1114 0.4871
Or we can get the associated predicted values using predict():
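A sketch of the prediction step; passing the policy_data object to predict() and showing only the first rows are assumptions:

head(predict(get_g_functions(pe0), pd))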
## Key: <id, stage>
## id stage g_0 g_1
## <int> <int> <num> <num>
## 1: 1 1 0.15628741 0.84371259
## 2: 1 2 0.08850558 0.91149442
## 3: 2 1 0.92994454 0.07005546
## 4: 2 2 0.92580890 0.07419110
## 5: 3 1 0.11184451 0.88815549
## 6: 3 2 0.08082666 0.91917334
Similarly, we can inspect the Q-functions using get_q_functions():
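A sketch of the call, again applied to pe0 as an assumption:

get_q_functions(pe0)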
## $stage_1
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) A1 L C B BBgroup2
## 0.232506 0.682422 0.454642 0.039021 -0.070152 -0.184704
## BBgroup3 A1:L A1:C A1:B A1:BBgroup2 A1:BBgroup3
## -0.171734 -0.010746 0.938791 0.003772 0.157200 0.270711
##
## Degrees of Freedom: 1999 Total (i.e. Null); 1988 Residual
## Null Deviance: 7689
## Residual Deviance: 3599 AIC: 6877
##
##
## $stage_2
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) A1 L C B BBgroup2
## -0.043324 0.147356 0.002376 -0.042036 0.005331 -0.001128
## BBgroup3 A1:L A1:C A1:B A1:BBgroup2 A1:BBgroup3
## -0.108404 0.024424 0.962591 -0.059177 -0.102084 0.094688
##
## Degrees of Freedom: 1999 Total (i.e. Null); 1988 Residual
## Null Deviance: 3580
## Residual Deviance: 1890 AIC: 5588
##
##
## attr(,"full_history")
## [1] FALSE
Note that a model is trained for each stage. Again, we can predict from the Q-models using predict().

Usually, we want to specify the nuisance models ourselves using the g_models and q_models arguments:
pe1 <- policy_eval(pd,
                   policy = p1,
                   g_models = list(
                     g_sl(formula = ~ BB + L_1, SL.library = c("SL.glm", "SL.ranger")),
                     g_sl(formula = ~ BB + L_1 + C_2, SL.library = c("SL.glm", "SL.ranger"))
                   ),
                   g_full_history = TRUE,
                   q_models = list(
                     q_glm(formula = ~ A * (B + C_1)),       # including action interactions
                     q_glm(formula = ~ A * (B + C_1 + C_2))  # including action interactions
                   ),
                   q_full_history = TRUE)
## Loading required namespace: ranger
Here we train a super learner g-model for each stage using the full available history and a generalized linear model for each Q-model. The formula argument is used to construct the model frame passed to the model for training (and prediction). The valid formula terms, which depend on g_full_history and q_full_history, are available via get_history_names():
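A sketch of the calls producing the three sets of names below; the stage argument is an assumption (the first call gives the state/Markov-type history names, the latter two the full-history names for stages 1 and 2):

get_history_names(pd)            # state/Markov-type history
get_history_names(pd, stage = 1) # full history, stage 1
get_history_names(pd, stage = 2) # full history, stage 2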
## [1] "L" "C" "B" "BB"
## [1] "L_1" "C_1" "B" "BB"
## [1] "A_1" "L_1" "L_2" "C_1" "C_2" "B" "BB"
Remember that the action variable at the current stage is always named A. Some models, like glm, require interactions to be specified via the model frame. Thus, for such models it is important to include action interaction terms in the Q-model formulas.
The value of a learned policy is an important performance measure, and policy_eval() allows for direct evaluation of a given policy learning algorithm. For details, see Algorithm 4 in (Nordland and Holst 2023). In polle, policy learning algorithms are specified using policy_learn(); see the associated vignette. These functions can be directly evaluated in policy_eval():
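A sketch of the call, assuming the learning algorithm is passed via a policy_learn argument:

policy_eval(pd,
            policy_learn = policy_learn(type = "ql"))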
## Estimate Std.Err 2.5% 97.5% P-value
## E[U(d)]: d=ql 1.306 0.06641 1.176 1.437 3.783e-86
In the above example we evaluate the policy estimated via Q-learning. Alternatively, we can first learn the policy and then pass it to policy_eval():
p_ql <- policy_learn(type = "ql")(pd, q_models = q_glm())
policy_eval(pd,
            policy = get_policy(p_ql))
## Estimate Std.Err 2.5% 97.5% P-value
## E[U(d)]: d=ql 1.306 0.06641 1.176 1.437 3.783e-86
A key feature of policy_eval() is that it allows for easy cross-fitting of the nuisance models as well as the learned policy. Here we specify two-fold cross-fitting via the M argument:
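A sketch of the cross-fitted evaluation; storing the result as pe_cf is an assumption:

pe_cf <- policy_eval(pd,
                     policy_learn = policy_learn(type = "ql"),
                     M = 2)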
Specifically, both the nuisance models and the optimal policy are fitted on each training fold. Subsequently, the doubly robust value score is calculated on the validation folds.
The policy_eval object now consists of a list of policy_eval objects associated with each fold:
## [1] 3 4 5 7 8 10
## Estimate Std.Err 2.5% 97.5% P-value
## E[U(d)]: d=ql 1.261 0.09456 1.075 1.446 1.538e-40
In order to save memory, particularly when cross-fitting, it is possible not to save the nuisance models via the save_g_functions and save_q_functions arguments.
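A sketch, assuming both are logical arguments of policy_eval():

policy_eval(pd,
            policy_learn = policy_learn(type = "ql"),
            M = 2,
            save_g_functions = FALSE,
            save_q_functions = FALSE)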
future.apply

It is easy to parallelize the cross-fitting procedure via the future.apply package:
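A sketch of a parallel setup using a multisession plan from the future package; the specific evaluation call is an assumption:

library("future.apply")
plan("multisession") # run the cross-fitting folds in parallel

policy_eval(pd,
            policy_learn = policy_learn(type = "ql"),
            M = 2)

plan("sequential") # reset to sequential processing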
## R version 4.4.3 (2025-02-28)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.2 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## time zone: Etc/UTC
## tzcode source: system (glibc)
##
## attached base packages:
## [1] splines stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] ggplot2_3.5.1 data.table_1.17.0 polle_1.5.1
## [4] SuperLearner_2.0-29 gam_1.22-5 foreach_1.5.2
## [7] nnls_1.6 rmarkdown_2.29
##
## loaded via a namespace (and not attached):
## [1] sass_0.4.9 future_1.34.0 lattice_0.22-6
## [4] listenv_0.9.1 digest_0.6.37 magrittr_2.0.3
## [7] evaluate_1.0.3 grid_4.4.3 iterators_1.0.14
## [10] mvtnorm_1.3-3 policytree_1.2.3 fastmap_1.2.0
## [13] jsonlite_1.9.1 Matrix_1.7-2 survival_3.8-3
## [16] scales_1.3.0 numDeriv_2016.8-1.1 codetools_0.2-20
## [19] jquerylib_0.1.4 lava_1.8.1 cli_3.6.4
## [22] rlang_1.1.5 mets_1.3.5 parallelly_1.42.0
## [25] future.apply_1.11.3 munsell_0.5.1 withr_3.0.2
## [28] cachem_1.1.0 yaml_2.3.10 tools_4.4.3
## [31] parallel_4.4.3 colorspace_2.1-1 ranger_0.17.0
## [34] globals_0.16.3 buildtools_1.0.0 vctrs_0.6.5
## [37] R6_2.6.1 lifecycle_1.0.4 pkgconfig_2.0.3
## [40] timereg_2.0.6 progressr_0.15.1 bslib_0.9.0
## [43] pillar_1.10.1 gtable_0.3.6 Rcpp_1.0.14
## [46] glue_1.8.0 xfun_0.51 tibble_3.2.1
## [49] sys_3.4.3 knitr_1.49 farver_2.1.2
## [52] htmltools_0.5.8.1 maketools_1.3.2 labeling_0.4.3
## [55] compiler_4.4.3