Estimating and Evaluating the Optimal Subgroup


This vignette showcases how policy_learn() and policy_eval() can be combined to estimate and evaluate the optimal subgroup in the single-stage case. We refer to (Nordland and Holst 2023) for the syntax and methodological context.


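From here on we consider the single-stage case with a binary action set {0, 1}. For a given threshold η > 0 we can formulate the optimal subgroup function via the conditional average treatment effect (CATE/blip) as

$$d^*_\eta(h) = I\{B_0(h) \geq \eta\},$$

where H denotes the covariates available at the decision and B0 is the CATE defined as

$$B_0(h) = E[U^{(1)} - U^{(0)} \mid H = h].$$

The average treatment effect in the optimal subgroup is now defined as

$$E[U^{(1)} - U^{(0)} \mid d^*_\eta(H) = 1],$$

which under consistency, positivity and randomization is identified as

$$E[Z(1, g, Q)(O) - Z(0, g, Q)(O) \mid d^*_\eta(H) = 1],$$

where Z(a, g, Q)(O) is the doubly robust score for treatment a,

$$Z(a, g, Q)(O) = Q(H, a) + \frac{I\{A = a\}}{g(H, a)}\{U - Q(H, a)\},$$

with outcome regression Q(h, a) = E[U | H = h, A = a] and propensity g(h, a) = P(A = a | H = h).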
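The identification formula can be checked numerically. The following minimal base-R sketch (not polle code; dr_score() and all other names are hypothetical) simulates data from the same model as the sim_d() example with a = 1, b = 0, c = 3, computes the doubly robust scores with the known propensity g = 0.5 and a deliberately misspecified outcome model that ignores W and L, and averages Z(1, g, Q)(O) − Z(0, g, Q)(O) over the true subgroup {3W + L ≥ 1} for η = 1. Because g is correct, the result should be close to the true value 25/12 ≈ 2.083 despite the wrong Q.

```r
# Doubly robust (AIPW) score for action a (hypothetical helper):
# Z(a, g, Q)(O) = Q(H, a) + I{A = a} / g(H, a) * (U - Q(H, a))
dr_score <- function(a, A, U, g_a, Q_a) {
  Q_a + (A == a) / g_a * (U - Q_a)
}

set.seed(1)
n <- 1e5
W <- runif(n, -1, 1)
L <- runif(n, -1, 1)
A <- rbinom(n, 1, 0.5)
U1 <- W + L + (3 * W + L)              # potential outcome U^(1), a = 1, b = 0, c = 3
U0 <- W + L                            # potential outcome U^(0)
U <- A * U1 + (1 - A) * U0 + rnorm(n)  # observed utility

# Misspecified outcome model: arm-specific means that ignore W and L
Q1 <- mean(U[A == 1])
Q0 <- mean(U[A == 0])
g1 <- 0.5                              # correct propensity P(A = 1 | H)

Z1 <- dr_score(1, A, U, g1, Q1)
Z0 <- dr_score(0, A, U, 1 - g1, Q0)

# True optimal subgroup for eta = 1 and the subgroup average of the DR blip scores
in_subgroup <- (3 * W + L) >= 1
est <- mean((Z1 - Z0)[in_subgroup])
est
```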

Threshold policy learning

In polle, the threshold policy dη can be estimated using policy_learn() via the threshold argument, and the average treatment effect in the subgroup can be estimated using policy_eval() with target = "subgroup".

Here we consider an example using simulated data:

library("polle")
library("data.table")

par0 <- list(a = 1, b = 0, c = 3)
sim_d <- function(n, par = par0, potential_outcomes = FALSE) {
  W <- runif(n = n, min = -1, max = 1)
  L <- runif(n = n, min = -1, max = 1)
  A <- rbinom(n = n, size = 1, prob = 0.5)       # randomized treatment
  U1 <- W + L + (par$c * W + par$a * L + par$b)  # potential outcome U^1
  U0 <- W + L                                    # potential outcome U^0
  U <- A * U1 + (1 - A) * U0 + rnorm(n = n)      # observed utility
  out <- data.table(W = W, L = L, A = A, U = U)
  if (potential_outcomes == TRUE) {
    out$U0 <- U0
    out$U1 <- U1
  }
  return(out)
}

Note that in this simple case U(1) − U(0) = cW + aL + b, which equals 3W + L for par0.
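Because the blip is known in closed form, the true optimal subgroup can be computed directly. In the minimal base-R sketch below, true_blip() is a hypothetical helper and not part of polle. For par0 and η = 1 the subgroup {3W + L ≥ 1} covers exactly one third of the population (integrating over W, L ~ Uniform(−1, 1) gives 1/3), which a large simulated sample should reproduce.

```r
par0 <- list(a = 1, b = 0, c = 3)

# True blip / CATE: B_0(w, l) = E[U^(1) - U^(0) | W = w, L = l] = c*w + a*l + b
true_blip <- function(w, l, par = par0) par$c * w + par$a * l + par$b

set.seed(1)
W <- runif(1e6, -1, 1)
L <- runif(1e6, -1, 1)

# True threshold policy with eta = 1 and the share of the population it treats
d_true <- as.integer(true_blip(W, L) >= 1)
share <- mean(d_true)
share
```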

d <- sim_d(n = 200)
pd <- policy_data(
    data = d,
    action = "A",
    covariates = list("W", "L"),
    utility = "U"
)

We define a correctly specified policy learner using policy_learn() with type = "blip" and a threshold of η = 1:

pl1 <- policy_learn(
  type = "blip",
  control = control_blip(blip_models = q_glm(~ W + L)),
  threshold = 1
)

We then apply the policy learner based on the correctly specified nuisance models. Furthermore, we extract the corresponding policy actions, where dN(W, L) = 1 identifies the optimal subgroup for η = 1:

po1 <- pl1(
  policy_data = pd,
  g_models = g_glm(~ 1),
  q_models = q_glm(~ A * (W + L))
)
pf1 <- get_policy(po1)
pa <- pf1(pd)
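Conceptually, a blip learner with a threshold fits a model for the CATE and thresholds its predictions. The sketch below illustrates this in base R under the assumption that the blip model is a linear regression of doubly robust blip scores on W and L (mirroring blip_models = q_glm(~ W + L)), with nuisance models mirroring g_glm(~ 1) and q_glm(~ A * (W + L)). It is not polle's implementation, and the variable names are hypothetical. It reports how often the estimated policy agrees with the true subgroup 3W + L ≥ 1.

```r
set.seed(2)
n <- 2000
W <- runif(n, -1, 1)
L <- runif(n, -1, 1)
A <- rbinom(n, 1, 0.5)
U <- W + L + A * (3 * W + L) + rnorm(n)  # same model as sim_d() with par0

# Outcome regression Q(H, A) with treatment interactions
q_fit <- lm(U ~ A * (W + L))
Q1 <- predict(q_fit, newdata = data.frame(A = 1, W = W, L = L))
Q0 <- predict(q_fit, newdata = data.frame(A = 0, W = W, L = L))
g1 <- mean(A)  # intercept-only propensity model

# Doubly robust blip scores Z(1, g, Q)(O) - Z(0, g, Q)(O)
blip_score <- (Q1 - Q0) + A / g1 * (U - Q1) - (1 - A) / (1 - g1) * (U - Q0)

# Blip model: regress the scores on W and L, then threshold the fit at eta
eta <- 1
blip_fit <- lm(blip_score ~ W + L)
d_hat <- as.integer(fitted(blip_fit) >= eta)

# Agreement between the estimated and the true threshold policy
d_true <- as.integer(3 * W + L >= eta)
agreement <- mean(d_hat == d_true)
agreement
```

Agreement close to 1 is expected here because the nuisance models and the blip model are all correctly specified.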

In the following plot, the black line indicates the boundary for the true optimal subgroup. The dots represent the estimated threshold policy:

Similarly, we can use type = "ptl" to fit a policy tree with a given threshold for not choosing the reference action (the first action in the action set, in alphabetical order):

get_action_set(pd)[1] # reference action
## [1] "0"
pl1_ptl <- policy_learn(
    type = "ptl",
    control = control_ptl(policy_var = c("W", "L")),
    threshold = 1
)
## Loading required namespace: policytree
po1_ptl <- pl1_ptl(
  policy_data = pd,
  g_models = g_glm(~ 1),
  q_models = q_glm(~ A * (W + L))
)
## $stage_1
## $stage_1$threshold_1
## policy_tree object 
## Tree depth:  2 
## Actions:  1: 0 2: 1 
## Variable splits: 
## (1) split_variable: W  split_value: -0.0948583 
##   (2) split_variable: W  split_value: -0.107529 
##     (4) * action: 1 
##     (5) * action: 2 
##   (3) split_variable: W  split_value: 0.197522 
##     (6) * action: 1 
##     (7) * action: 2
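The threshold policy tree can be illustrated with an exhaustive depth-1 search over splits on W and L. This sketch assumes that the threshold enters as a penalty of η on the doubly robust score of the non-reference action 1, so each leaf chooses action 1 only if its summed scores Z(1) − Z(0) − η are positive. best_split() is a hypothetical helper, and this is neither the policytree nor the polle implementation.

```r
# Exhaustive depth-1 split on x: each leaf takes action 1 only if its summed
# penalized scores s are positive (hypothetical helper)
best_split <- function(x, s) {
  o <- order(x)
  xs <- x[o]
  left <- cumsum(s[o])     # summed scores for x <= xs[k]
  right <- sum(s) - left   # summed scores for x > xs[k]
  gain <- pmax(left, 0) + pmax(right, 0)
  k <- which.max(gain[-length(xs)])  # exclude the split with an empty right leaf
  list(split = xs[k], gain = unname(gain[k]),
       left_action = as.integer(left[k] > 0),
       right_action = as.integer(right[k] > 0))
}

set.seed(3)
n <- 20000
W <- runif(n, -1, 1)
L <- runif(n, -1, 1)
A <- rbinom(n, 1, 0.5)
U <- W + L + A * (3 * W + L) + rnorm(n)

# Doubly robust blip scores with a correctly specified outcome regression
q_fit <- lm(U ~ A * (W + L))
Q1 <- predict(q_fit, newdata = data.frame(A = 1, W = W, L = L))
Q0 <- predict(q_fit, newdata = data.frame(A = 0, W = W, L = L))
g1 <- mean(A)
blip_score <- (Q1 - Q0) + A / g1 * (U - Q1) - (1 - A) / (1 - g1) * (U - Q0)

# Penalized advantage of action 1 over the reference action 0
eta <- 1
s <- unname(blip_score) - eta

splits <- list(W = best_split(W, s), L = best_split(L, s))
gains <- sapply(splits, function(sp) sp$gain)
best_var <- names(which.max(gains))
best_var
splits[[best_var]]
```

For a single split on W, the summed gain is largest where E[3W + L − 1 | W = w] = 3w − 1 changes sign, so the chosen split should be close to W = 1/3, with action 0 on the left and action 1 on the right.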

Subgroup average treatment effect

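The true subgroup average treatment effect is given by

$$E[U^{(1)} - U^{(0)} \mid U^{(1)} - U^{(0)} \geq 1] = E[cW + aL + b \mid cW + aL + b \geq 1],$$

which we can easily approximate by simulation: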

approx <- sim_d(n = 1e7, potential_outcomes = TRUE)
(sate <- with(approx, mean((U1 - U0)[(U1 - U0 >= 1)])))
## [1] 2.082982
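As a cross-check on the Monte Carlo approximation, the subgroup average treatment effect can also be computed by integration. The sketch assumes W, L ~ Uniform(−1, 1) as in sim_d() and c > 0; true_sate() is a hypothetical helper. For fixed L, the integral over W is done analytically, and the remaining one-dimensional integral uses stats::integrate(). For par0 this gives 25/12 ≈ 2.0833 for η = 1, in line with the approximation above, and 14/9 ≈ 1.5556 for η = 0.

```r
# E[B | B >= eta] for B = c*W + a*L + b with W, L ~ Uniform(-1, 1), assuming c > 0
true_sate <- function(par, eta) {
  # lower W-limit of the subgroup {c*W + a*L + b >= eta} for fixed L, clipped to [-1, 1]
  w0 <- function(l) pmin(pmax((eta - par$b - par$a * l) / par$c, -1), 1)
  # integral over W from w0 to 1 of the blip c*w + a*l + b
  num <- function(l) {
    w <- w0(l)
    par$c * (1 - w^2) / 2 + (par$a * l + par$b) * (1 - w)
  }
  # length of the W-interval inside the subgroup for fixed L
  len <- function(l) 1 - w0(l)
  # the uniform density cancels in the ratio
  integrate(num, -1, 1)$value / integrate(len, -1, 1)$value
}

par0 <- list(a = 1, b = 0, c = 3)
true_sate(par0, eta = 1)  # 25/12
true_sate(par0, eta = 0)  # 14/9
```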

The subgroup average treatment effect associated with the learned optimal threshold policy can be directly estimated using policy_eval() via the target argument:

(pe <- policy_eval(
  policy_data = pd,
  policy_learn = pl1,
  target = "subgroup"
))
##                                 Estimate Std.Err  2.5% 97.5%   P-value
## E[U(1)-U(0)|d=1]: d=blip(eta=1)    1.941  0.2614 1.428 2.453 1.136e-13

We can also estimate the subgroup average treatment effect for a set of thresholds at once:

pl_set <- policy_learn(
  type = "blip",
  control = control_blip(blip_models = q_glm(~ W + L)),
  threshold = c(0, 1)
)
policy_eval(
  policy_data = pd,
  g_models = g_glm(~ 1),
  q_models = q_glm(~ A * (W + L)),
  policy_learn = pl_set,
  target = "subgroup"
)
##                                 Estimate Std.Err  2.5% 97.5%   P-value
## E[U(1)-U(0)|d=1]: d=blip(eta=0)    1.641  0.2161 1.217 2.064 3.118e-14
## E[U(1)-U(0)|d=1]: d=blip(eta=1)    1.935  0.2612 1.423 2.447 1.268e-13


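The estimator of the data adaptive target parameter

$$\Psi(d_N) = E[Z(1, g, Q)(O) - Z(0, g, Q)(O) \mid d_N(H) = 1]$$

is asymptotically normal with influence function

$$\frac{I\{d(H) = 1\}}{P(d(H) = 1)}\left\{Z(1, g, Q)(O) - Z(0, g, Q)(O) - \Psi(d)\right\},$$

where d is the limiting policy of dN. The fitted influence curve can be extracted using IC():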

IC(pe) |> head()
##           [,1]
## [1,]  0.000000
## [2,]  0.000000
## [3,]  0.000000
## [4,] -1.780405
## [5,]  0.000000
## [6,]  9.520253


Nordland, Andreas, and Klaus K. Holst. 2023. “Policy Learning with the Polle Package.”