Nonparametric data-driven approach to discovering heterogeneous subgroups in a selection-on-observables framework. The approach constructs a sequence of groupings, one for each level of granularity. Groupings are nested and feature an optimality property. For each grouping, we obtain point estimates and standard errors for the group average treatment effects (GATEs). Additionally, we assess whether systematic heterogeneity is found by testing the hypotheses that the differences in the GATEs across all pairs of groups are zero. Finally, we investigate the driving mechanisms of effect heterogeneity by computing the average characteristics of units in each group.
build_aggtree(
Y,
D,
X,
honest_frac = 0.5,
method = "aipw",
scores = NULL,
cates = NULL,
is_honest = NULL,
...
)
inference_aggtree(object, n_groups, boot_ci = FALSE, boot_R = 2000)
Y: Outcome vector.
D: Treatment vector.
X: Covariate matrix (no intercept).
honest_frac: Fraction of observations to be allocated to the honest sample.
method: Either "raw" or "aipw"; controls how node predictions are computed.
scores: Optional vector of scores to be used in computing node predictions. Useful to save computational time if the scores have already been estimated. Ignored if method == "raw".
cates: Optional vector of estimated CATEs. If not provided by the user, the CATEs are estimated internally via a causal_forest.
is_honest: Logical vector denoting which observations belong to the honest sample. Required only if the cates argument is used.
...: Further arguments passed to rpart.control.
object: An aggTrees object.
n_groups: Number of desired groups.
boot_ci: Logical, whether to compute bootstrap confidence intervals.
boot_R: Number of bootstrap replications. Ignored if boot_ci == FALSE.
build_aggtree returns an aggTrees object.
inference_aggtree returns an aggTrees.inference object, which in turn contains the aggTrees object used in the call.
Aggregation trees are a three-step procedure. First, the conditional average treatment effects (CATEs) are estimated using any estimator. Second, a tree is grown to approximate the CATEs. Third, the tree is pruned to derive a nested sequence of optimal groupings, one for each granularity level. For each level of granularity, we can obtain point estimates and inference for the GATEs.
To implement this methodology, the user can rely on two core functions that handle the various steps.
build_aggtree constructs the sequence of groupings (i.e., the tree) and estimates the GATEs in each node. The GATEs can be estimated in several ways, controlled by the method argument. If method == "raw", we compute the difference in mean outcomes between treated and control observations in each node. This is an unbiased estimator in randomized experiments. If method == "aipw", we construct doubly-robust scores and average them in each node. This estimator is unbiased in observational studies as well. Honest regression forests and 5-fold cross-fitting are used to estimate the propensity score and the conditional mean function of the outcome (unless the user specifies the scores argument).
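For reference, the doubly-robust (AIPW) scores mentioned above take the standard form

$$\hat{\Gamma}_i = \hat{\mu}_1(X_i) - \hat{\mu}_0(X_i) + \frac{D_i \, [Y_i - \hat{\mu}_1(X_i)]}{\hat{p}(X_i)} - \frac{(1 - D_i) \, [Y_i - \hat{\mu}_0(X_i)]}{1 - \hat{p}(X_i)}$$

where \hat{p}(X_i) is the estimated propensity score and \hat{\mu}_d(X_i) the estimated conditional mean of the outcome under treatment status d. This notation is not used elsewhere in the package documentation and is shown only for intuition; it is the textbook AIPW construction.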
The user can provide a vector of estimated CATEs via the cates argument. If so, the user needs to specify a logical vector to denote which observations belong to the honest sample. If honesty is not desired, is_honest must be a vector of FALSEs. If no vector of CATEs is provided, these are estimated internally via a causal_forest.
inference_aggtree takes as input an aggTrees object constructed by build_aggtree. Then, for the desired granularity level, chosen via the n_groups argument, it provides point estimates and standard errors for the GATEs. Additionally, it performs hypothesis testing to assess whether systematic heterogeneity is found and computes the average characteristics of the units in each group to investigate the driving mechanisms.
GATEs and their standard errors are obtained by fitting an appropriate linear model. If method == "raw", we estimate via OLS the following:

$$Y_i = \sum_{l = 1}^{|T|} L_{i, l} \gamma_l + \sum_{l = 1}^{|T|} L_{i, l} D_i \beta_l + \epsilon_i$$

with L_{i, l} a dummy variable equal to one if the i-th unit falls in the l-th group, and |T| the number of groups. If the treatment is randomly assigned, one can show that the betas identify the GATEs. However, this is not true in observational studies due to selection into treatment. In that case, the user is expected to use method == "aipw" when calling build_aggtree, and inference_aggtree uses the scores in the following regression:

$$score_i = \sum_{l = 1}^{|T|} L_{i, l} \beta_l + \epsilon_i$$

This way, the betas again identify the GATEs.
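To see why, note that under unconfoundedness the doubly-robust scores are unbiased for the CATEs, so the population coefficient on each group dummy equals the group mean effect:

$$\beta_l = E[score_i \mid L_{i, l} = 1] = E[\tau(X_i) \mid L_{i, l} = 1]$$

where \tau(X_i) denotes the CATE. This notation is introduced here only to sketch the identification argument; it does not appear elsewhere in the package documentation.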
Regardless of method, standard errors are estimated via the Eicker-Huber-White estimator.

If boot_ci == TRUE, the routine also computes asymmetric bias-corrected and accelerated 95% confidence intervals using boot_R bootstrap samples. This is particularly useful when the honest sample is relatively small.
inference_aggtree uses the standard errors obtained by fitting the linear models above to test the hypotheses that the GATEs are different across all pairs of leaves. Here, we adjust p-values to account for multiple hypothesis testing using Holm's procedure.
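As a reminder of the standard procedure (notation not from the package): with m pairwise tests and ordered raw p-values $p_{(1)} \le \dots \le p_{(m)}$, Holm's adjusted p-values are

$$\tilde{p}_{(k)} = \min\left\{1, \; \max_{j \le k} \, (m - j + 1) \, p_{(j)}\right\}$$

which controls the family-wise error rate without assuming independence across tests.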
inference_aggtree regresses each covariate on a set of dummies denoting group membership. This way, we get the average characteristics of units in each leaf, together with a standard error. Leaves are ordered in increasing order of their predictions (from most negative to most positive). Standard errors are estimated via the Eicker-Huber-White estimator.
Regardless of the chosen method, both functions estimate the GATEs, the linear models, and the average characteristics of units in each group using only observations in the honest sample. If the honest sample is empty (this happens when the user either sets honest_frac = 0 or passes a vector of FALSEs as is_honest when calling build_aggtree), the same data used to construct the tree are used to estimate the above quantities. This is fine for prediction but invalidates inference.
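As a minimal sketch of this no-honesty configuration (reusing the Y, D, and X objects defined in the examples below; output not shown):

```r
## All data are used both to grow the tree and to estimate the GATEs:
## predictions are fine, but inference is invalid.
adaptive_groupings <- build_aggtree(Y, D, X, method = "aipw", honest_frac = 0)
```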
Di Francesco, R. (2022). Aggregation Trees. CEIS Research Paper, 546. doi:10.2139/ssrn.4304256.
## Generate data.
set.seed(1986)
n <- 1000
k <- 3
X <- matrix(rnorm(n * k), ncol = k)
colnames(X) <- paste0("x", seq_len(k))
D <- rbinom(n, size = 1, prob = 0.5)
mu0 <- 0.5 * X[, 1]
mu1 <- 0.5 * X[, 1] + X[, 2]
Y <- mu0 + D * (mu1 - mu0) + rnorm(n)
## Construct sequence of groupings. CATEs estimated internally.
groupings <- build_aggtree(Y, D, X, method = "aipw")
## Alternatively, we can estimate the CATEs and pass them.
splits <- sample_split(length(Y), training_frac = 0.5)
training_idx <- splits$training_idx
honest_idx <- splits$honest_idx
Y_tr <- Y[training_idx]
D_tr <- D[training_idx]
X_tr <- X[training_idx, ]
Y_hon <- Y[honest_idx]
D_hon <- D[honest_idx]
X_hon <- X[honest_idx, ]
library(grf)
forest <- causal_forest(X_tr, Y_tr, D_tr) # Use training sample.
cates <- predict(forest, X)$predictions
groupings <- build_aggtree(Y, D, X, method = "aipw", cates = cates,
                           is_honest = seq_along(Y) %in% honest_idx)
## We have compatibility with generic S3-methods.
summary(groupings)
#> Honest estimates: TRUE
#> Call:
#> rpart::rpart(formula = cates ~ ., data = data.frame(cates = cates[training_idx],
#> X_tr), method = "anova", model = TRUE, control = rpart::rpart.control(...))
#> n= 500
#>
#> CP nsplit rel error xerror xstd
#> 1 0.83687878 0 1.00000000 1.00778279 0.030441594
#> 2 0.09040027 1 0.16312122 0.16560696 0.007644829
#> 3 0.03818566 2 0.07272095 0.07501281 0.005894329
#> 4 0.01000000 3 0.03453529 0.03662509 0.002002911
#>
#> Variable importance
#> x2 x3 x1
#> 92 4 4
#>
#> Node number 1: 500 observations, complexity param=0.8368788
#> mean=-0.06297775, MSE=1.030046
#> left son=2 (230 obs) right son=3 (270 obs)
#> Primary splits:
#> x2 < -0.07518198 to the left, improve=0.83687880, (0 missing)
#> x1 < 0.3506824 to the left, improve=0.02048703, (0 missing)
#> x3 < -0.9775529 to the left, improve=0.01361843, (0 missing)
#> Surrogate splits:
#> x3 < -0.983981 to the left, agree=0.560, adj=0.043, (0 split)
#> x1 < -0.207075 to the left, agree=0.558, adj=0.039, (0 split)
#>
#> Node number 2: 230 observations, complexity param=0.03818566
#> mean=-0.7844117, MSE=0.1203484
#> left son=4 (169 obs) right son=5 (61 obs)
#> Primary splits:
#> x2 < -0.4126471 to the left, improve=0.71049120, (0 missing)
#> x1 < 0.3089535 to the left, improve=0.09489957, (0 missing)
#> x3 < -0.61261 to the left, improve=0.01026319, (0 missing)
#> Surrogate splits:
#> x1 < -1.889246 to the right, agree=0.743, adj=0.033, (0 split)
#> x3 < 2.095613 to the left, agree=0.739, adj=0.016, (0 split)
#>
#> Node number 3: 270 observations, complexity param=0.09040027
#> mean=0.5819405, MSE=0.2086335
#> left son=6 (122 obs) right son=7 (148 obs)
#> Primary splits:
#> x2 < 0.5478927 to the left, improve=0.82651090, (0 missing)
#> x3 < 0.4544582 to the left, improve=0.03984684, (0 missing)
#> x1 < 1.302729 to the left, improve=0.01919420, (0 missing)
#> Surrogate splits:
#> x3 < -1.130523 to the left, agree=0.593, adj=0.098, (0 split)
#> x1 < 0.09411584 to the left, agree=0.570, adj=0.049, (0 split)
#>
#> Node number 4: 169 observations
#> mean=-1.050757, MSE=0.02599234
#>
#> Node number 5: 61 observations
#> mean=0.1132702, MSE=0.05935966
#>
#> Node number 6: 122 observations
#> mean=-0.08319338, MSE=0.03747296
#>
#> Node number 7: 148 observations
#> mean=1.189237, MSE=0.03514272
#>
print(groupings)
#> Honest estimates: TRUE
#> n= 500
#>
#> node), split, n, deviance, yval
#> * denotes terminal node
#>
#> 1) root 500 515.023000 -0.06297775
#> 2) x2< -0.07518198 230 27.680130 -0.78441170
#> 4) x2< -0.4126471 169 4.392705 -1.05075700 *
#> 5) x2>=-0.4126471 61 3.620939 0.11327020 *
#> 3) x2>=-0.07518198 270 56.331040 0.58194050
#> 6) x2< 0.5478927 122 4.571701 -0.08319338 *
#> 7) x2>=0.5478927 148 5.201123 1.18923700 *
plot(groupings) # Try also setting 'sequence = TRUE'.
## To predict, do the following.
tree <- subtree(groupings$tree, cv = TRUE) # Select by cross-validation.
head(predict(tree, data.frame(X)))
#> 1 2 3 4 5 6
#> 0.1132702 -1.0507568 1.1892366 1.1892366 -1.0507568 -1.0507568
## Inference with 4 groups.
results <- inference_aggtree(groupings, n_groups = 4)
summary(results$model) # Coefficient of leafk is GATE in k-th leaf.
#>
#> Call:
#> estimatr::lm_robust(formula = scores ~ 0 + leaf, data = data.frame(scores = scores,
#> leaf = leaves), se_type = "HC1")
#>
#> Standard error type: HC1
#>
#> Coefficients:
#> Estimate Std. Error t value Pr(>|t|) CI Lower CI Upper DF
#> leaf1 -1.05076 0.1539 -6.8257 2.562e-11 -1.3532 -0.7483 496
#> leaf2 -0.08319 0.1873 -0.4442 6.571e-01 -0.4512 0.2848 496
#> leaf3 0.11327 0.2800 0.4046 6.860e-01 -0.4368 0.6634 496
#> leaf4 1.18924 0.1994 5.9643 4.673e-09 0.7975 1.5810 496
#>
#> Multiple R-squared: 0.1469 , Adjusted R-squared: 0.14
#> F-statistic: 20.63 on 4 and 496 DF, p-value: 9.648e-16
results$gates_diff_pairs$gates_diff # GATEs differences.
#> leaf1 leaf2 leaf3 leaf4
#> leaf1 NA NA NA NA
#> leaf2 0.9675635 NA NA NA
#> leaf3 1.1640270 0.1964636 NA NA
#> leaf4 2.2399934 1.2724300 1.075966 NA
results$gates_diff_pairs$holm_pvalues # leaves 1-2 not statistically different.
#> [,1] [,2] [,3] [,4]
#> [1,] NA NA NA NA
#> [2,] 3.030658e-04 NA NA NA
#> [3,] 8.930757e-04 5.600035e-01 NA NA
#> [4,] 6.695237e-17 2.118317e-05 0.003698325 NA
## LaTeX output.
print(results, table = "diff")
#> \begingroup
#> \setlength{\tabcolsep}{8pt}
#> \renewcommand{\arraystretch}{1.2}
#> \begin{table}[b!]
#> \centering
#> \begin{adjustbox}{width = 1\textwidth}
#> \begin{tabular}{@{\extracolsep{5pt}}l c c c c}
#> \\[-1.8ex]\hline
#> \hline \\[-1.8ex]
#>
#> & \textit{Leaf 1} & \textit{Leaf 2} & \textit{Leaf 3} & \textit{Leaf 4} \\
#> \addlinespace[2pt]
#> \hline \\[-1.8ex]
#>
#> \multirow{3}{*}{GATEs} & -1.051 & -0.083 & 0.113 & 1.189 \\
#> & [-1.353, -0.749] & [-0.450, 0.284] & [-0.436, 0.662] & [ 0.799, 1.579] \\
#> & \{NA, NA\} & \{NA, NA\} & \{NA, NA\} & \{NA, NA\} \\
#>
#> \addlinespace[2pt]
#> \hline \\[-1.8ex]
#>
#> \textit{Leaf 1} & NA & NA & NA & NA \\
#> & (NA) & (NA) & (NA) & (NA) \\
#> \textit{Leaf 2} & 0.97 & NA & NA & NA \\
#> & (0.000) & ( NA) & ( NA) & (NA) \\
#> \textit{Leaf 3} & 1.16 & 0.20 & NA & NA \\
#> & (0.001) & (0.560) & ( NA) & (NA) \\
#> \textit{Leaf 4} & 2.24 & 1.27 & 1.08 & NA \\
#> & (0.000) & (0.000) & (0.004) & (NA) \\
#>
#> \addlinespace[3pt]
#> \\[-1.8ex]\hline
#> \hline \\[-1.8ex]
#> \end{tabular}
#> \end{adjustbox}
#> \caption{Point estimates and $95\%$ confidence intervals for the GATEs based on asymptotic normality (in square brackets) and on the percentiles of the bootstrap distribution (in curly braces). Leaves are sorted in increasing order of the GATEs. Additionally, the GATE differences across all pairs of leaves are displayed. $p$-values testing the null hypothesis that a single difference is zero are adjusted using Holm's procedure and reported in parenthesis under each point estimate.}
#> \label{table:differences.gates}
#> \end{table}
#> \endgroup
#>
print(results, table = "avg_char")
#> \begingroup
#> \setlength{\tabcolsep}{8pt}
#> \renewcommand{\arraystretch}{1.1}
#> \begin{table}[b!]
#> \centering
#> \begin{adjustbox}{width = 1\textwidth}
#> \begin{tabular}{@{\extracolsep{5pt}}l c c c c c c c c }
#> \\[-1.8ex]\hline
#> \hline \\[-1.8ex]
#> & \multicolumn{2}{c}{\textit{Leaf 1}} & \multicolumn{2}{c}{\textit{Leaf 2}} & \multicolumn{2}{c}{\textit{Leaf 3}} & \multicolumn{2}{c}{\textit{Leaf 4}} \\\cmidrule{2-3} \cmidrule{4-5} \cmidrule{6-7} \cmidrule{8-9}
#> & Mean & (S.E.) & Mean & (S.E.) & Mean & (S.E.) & Mean & (S.E.) \\
#> \addlinespace[2pt]
#> \hline \\[-1.8ex]
#>
#> \texttt{x1} & 0.025 & (0.072) & 0.035 & (0.096) & -0.179 & (0.133) & 0.086 & (0.082) \\
#> \texttt{x2} & -1.131 & (0.042) & 0.209 & (0.016) & -0.237 & (0.014) & 1.101 & (0.040) \\
#> \texttt{x3} & -0.009 & (0.071) & 0.088 & (0.092) & -0.066 & (0.130) & -0.047 & (0.092) \\
#>
#> \addlinespace[3pt]
#> \\[-1.8ex]\hline
#> \hline \\[-1.8ex]
#> \end{tabular}
#> \end{adjustbox}
#> \caption{Average characteristics of units in each leaf, obtained by regressing each covariate on a set of dummies denoting leaf membership using only the honest sample. Standard errors are estimated via the Eicker-Huber-White estimator. Leaves are sorted in increasing order of the GATEs.}
#> \label{table:average.characteristics.leaves}
#> \end{table}
#> \endgroup
#>