Description:

Replaces node predictions of an rpart object using external data to estimate the group average treatment effects (GATEs).
Usage:

estimate_rpart(tree, Y, D, X, method = "aipw", scores = NULL)
Arguments:

tree: An rpart object.

Y: Outcome vector.

D: Treatment assignment vector.

X: Covariate matrix (no intercept).

method: Either "raw" or "aipw"; controls how node predictions are replaced.

scores: Optional, vector of scores to be used in replacing node predictions. Useful to save computational time if scores have already been estimated. Ignored if method == "raw".
Value:

A tree with node predictions replaced, as an rpart object, and the scores (NULL if method == "raw").
Details:

If method == "raw", estimate_rpart replaces node predictions with the difference between the sample average of the observed outcomes of treated units and the sample average of the observed outcomes of control units in each node, which is an unbiased estimator of the GATEs if the assignment to treatment is randomized.
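For a single node, the "raw" estimate is just a difference in means. A minimal sketch, assuming a hypothetical logical vector node_idx that selects the honest-sample observations falling in that node:

```r
## Sketch: "raw" GATE estimate in one node. `node_idx` is a hypothetical
## logical index marking the observations that fall in the node.
gate_raw <- mean(Y[node_idx & D == 1]) - mean(Y[node_idx & D == 0])
```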
If method == "aipw", estimate_rpart replaces node predictions with sample averages of doubly-robust scores in each node. This is a valid estimator of the GATEs in observational studies. Honest regression forests and 5-fold cross-fitting are used to estimate the propensity score and the conditional mean function of the outcome (unless the user specifies the argument scores).
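The doubly-robust scores take the standard AIPW form. A sketch under the assumption that cross-fitted estimates mu0_hat, mu1_hat (conditional outcome means under control and treatment) and e_hat (propensity score) are available for each observation; these names are illustrative, not part of the package's API:

```r
## Sketch: doubly-robust (AIPW) scores. `mu0_hat`, `mu1_hat`, and `e_hat`
## are hypothetical vectors of cross-fitted nuisance estimates.
gamma <- mu1_hat - mu0_hat +
  D * (Y - mu1_hat) / e_hat -
  (1 - D) * (Y - mu0_hat) / (1 - e_hat)
## The "aipw" prediction for a node is then the sample average of
## `gamma` over the observations falling in that node.
```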
estimate_rpart allows the user to implement "honest" estimation. If the observations in Y, D, and X have not been used to construct the tree, then the new predictions are honest in the sense of Athey and Imbens (2016).

To get standard errors for the tree's estimates, please use causal_ols_rpart.
References:

Athey, S., & Imbens, G. (2016). Recursive Partitioning for Heterogeneous Causal Effects. Proceedings of the National Academy of Sciences, 113(27), 7353-7360. doi:10.1073/pnas.1510651113.

Di Francesco, R. (2022). Aggregation Trees. CEIS Research Paper, 546. doi:10.2139/ssrn.4304256.
Examples:

## Generate data.
set.seed(1986)
n <- 1000
k <- 3
X <- matrix(rnorm(n * k), ncol = k)
colnames(X) <- paste0("x", seq_len(k))
D <- rbinom(n, size = 1, prob = 0.5)
mu0 <- 0.5 * X[, 1]
mu1 <- 0.5 * X[, 1] + X[, 2]
Y <- mu0 + D * (mu1 - mu0) + rnorm(n)
## Split the sample (sample_split() comes from the aggTrees package).
library(aggTrees)
splits <- sample_split(length(Y), training_frac = 0.5)
training_idx <- splits$training_idx
honest_idx <- splits$honest_idx
Y_tr <- Y[training_idx]
D_tr <- D[training_idx]
X_tr <- X[training_idx, ]
Y_hon <- Y[honest_idx]
D_hon <- D[honest_idx]
X_hon <- X[honest_idx, ]
## Construct a tree using training sample.
library(rpart)
tree <- rpart(Y ~ ., data = data.frame("Y" = Y_tr, X_tr), maxdepth = 2)
## Estimate GATEs in each node (internal and terminal) using honest sample.
new_tree <- estimate_rpart(tree, Y_hon, D_hon, X_hon, method = "raw")
new_tree$tree
#> n= 500
#>
#> node), split, n, deviance, yval
#> * denotes terminal node
#>
#> 1) root 500 848.77870 -0.1196941
#> 2) x2< -0.2670246 208 294.70990 -1.3627760
#> 4) x1< 0.111117 116 149.74660 -1.5080320 *
#> 5) x1>=0.111117 92 108.62540 -1.1109040 *
#> 3) x2>=-0.2670246 292 425.26540 0.5981272
#> 6) x1< 1.135506 231 289.47460 0.5557710 *
#> 7) x1>=1.135506 61 70.06308 1.0030310 *