Validates machine learning predictions of conditional average treatment effects (CATEs) by estimating Centered-Weighted Average Treatment Effects (CWATEs) and their normalized counterparts (NCWATEs).

valiCATE(
  Y,
  D,
  X,
  cates,
  weights = c("AUTOC", "AUC-HVL", "BLP", "QINI"),
  scores = NULL,
  n_folds = 5,
  alpha = 0.05,
  verbose = TRUE
)

Arguments

Y

Outcome vector (validation sample).

D

Treatment indicator vector, binary 0/1 with 1 for treated (validation sample).

X

Covariate matrix or data frame, without intercept (validation sample).

cates

Named list of CATE prediction vectors on the validation sample produced by different models. Each element must be a numeric vector of the same length as Y. CATE models must be estimated using only the training sample for valid inference.

weights

Character vector controlling which weight functions to use. Admissible values are "AUTOC", "AUC-HVL", "BLP", and "QINI".

scores

Optional, pre-computed AIPW pseudo-outcomes. If not provided by the user, scores are estimated internally via cross-fitting. Supplying them saves computation time when scores have already been estimated.

n_folds

Optional, number of cross-fitting folds for nuisance estimation. Default is 5. Ignored if scores is provided.

alpha

Optional, significance level for confidence intervals and hypothesis tests. Default is 0.05.

verbose

Optional, set to FALSE to prevent the function from printing progress messages.

Value

Object of class valiCATE.

Details

The user must provide the outcomes, treatment status, and covariates of the units in the validation sample using the first three arguments. The user must also provide CATE predictions on the validation sample as a named list, with one element per model. Note that CATE models must be estimated using only the training sample for inference to be valid.
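The input requirements described above can be checked before calling valiCATE. The following is a minimal sketch in base R; check_valicate_inputs is a hypothetical helper written for illustration and is not part of the package, whose own input checks may differ.

```r
## Hypothetical helper (not part of the package): checks the documented input requirements.
check_valicate_inputs <- function(Y, D, X, cates) {
  n <- length(Y)
  stopifnot(
    "D must have the same length as Y" = length(D) == n,
    "D must be binary 0/1" = all(D %in% c(0, 1)),
    "X must have one row per unit" = NROW(X) == n,
    "cates must be a named list" =
      is.list(cates) && !is.null(names(cates)) && all(nzchar(names(cates)))
  )

  ## Each model's predictions must be a numeric vector aligned with Y.
  for (model in names(cates)) {
    if (!is.numeric(cates[[model]]) || length(cates[[model]]) != n) {
      stop("cates[['", model, "']] must be a numeric vector of length ", n,
           call. = FALSE)
    }
  }
  invisible(TRUE)
}

## Toy inputs, for illustration only.
set.seed(1986)
Y_toy <- rnorm(10)
D_toy <- rep(0:1, times = 5)
X_toy <- matrix(rnorm(20), ncol = 2)
check_valicate_inputs(Y_toy, D_toy, X_toy, list(model_a = rnorm(10)))
```

An unnamed list, or a prediction vector of the wrong length, makes the helper stop with an error message.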

Estimation is based on AIPW pseudo-outcomes, which require estimates of the nuisance functions (the propensity score and the conditional means of the outcome for treated and control units). By default, nuisance functions are estimated internally via honest regression_forest models using K-fold cross-fitting in the validation sample. Alternatively, the user can supply pre-computed AIPW scores via the scores argument; in this case, the scores should be cross-fitted in the validation sample to achieve valid inference.
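To illustrate what cross-fitted AIPW pseudo-outcomes look like, the sketch below computes them in base R. It is not the package's internal routine: linear and logistic regressions are swapped in for the regression forests so that the example has no dependencies, and aipw_crossfit is a hypothetical helper name. The simulated data stand in for a validation sample.

```r
## Sketch: cross-fitted AIPW pseudo-outcomes.
## Assumption: parametric nuisance models (lm/glm) replace the regression forests.
set.seed(1986)
n <- 1000
X <- matrix(rnorm(n * 2), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
D <- rbinom(n, size = 1, prob = 0.5)
Y <- 0.5 * X[, 1] + D * X[, 2] + rnorm(n, sd = 0.5)

aipw_crossfit <- function(Y, D, X, n_folds = 5) {
  dta <- data.frame(Y = Y, D = D, X)
  covs <- setdiff(names(dta), c("Y", "D"))
  folds <- sample(rep(seq_len(n_folds), length.out = nrow(dta)))
  scores <- numeric(nrow(dta))

  for (k in seq_len(n_folds)) {
    train <- dta[folds != k, ]
    test <- dta[folds == k, ]

    ## Nuisance functions are fitted without the units in fold k.
    ps_fit <- glm(reformulate(covs, "D"), data = train, family = binomial())
    mu1_fit <- lm(reformulate(covs, "Y"), data = train[train$D == 1, ])
    mu0_fit <- lm(reformulate(covs, "Y"), data = train[train$D == 0, ])

    ## Out-of-fold predictions for the units in fold k.
    e_hat <- predict(ps_fit, newdata = test, type = "response")
    mu1_hat <- predict(mu1_fit, newdata = test)
    mu0_hat <- predict(mu0_fit, newdata = test)

    ## AIPW score: regression contrast plus inverse-probability-weighted residuals.
    scores[folds == k] <- mu1_hat - mu0_hat +
      test$D * (test$Y - mu1_hat) / e_hat -
      (1 - test$D) * (test$Y - mu0_hat) / (1 - e_hat)
  }
  scores
}

scores <- aipw_crossfit(Y, D, X)
mean(scores)  # estimates the average treatment effect, which is 0 in this design
```

Assuming the scores argument accepts a numeric vector aligned with Y, such a vector could then be passed as valiCATE(Y, D, X, cates = ..., scores = scores), in which case n_folds is ignored.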

For the CWATE, a one-sided test of H0: theta <= 0 is reported. Rejection signals that the predicted CATEs capture genuine heterogeneity in the correct direction. For the NCWATE, a two-sided test of H0: gamma = 1 is reported. Non-rejection signals that the predicted CATEs recover the true CATEs.
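The two tests can be illustrated with a short sketch that assumes Wald-type inference under a normal approximation; this is an assumption made for illustration, and the package's internal computations may differ. wald_summary is a hypothetical helper, and the inputs are the rounded AUTOC estimates and standard errors printed by summary() in the Examples, so results should only land close to the reported p-values and confidence intervals.

```r
## Sketch: Wald-type tests and confidence intervals (normal approximation assumed).
wald_summary <- function(estimate, se, null,
                         alternative = c("greater", "two.sided"),
                         alpha = 0.05) {
  alternative <- match.arg(alternative)
  z <- (estimate - null) / se

  p_value <- if (alternative == "greater") {
    pnorm(z, lower.tail = FALSE)           # H0: parameter <= null
  } else {
    2 * pnorm(abs(z), lower.tail = FALSE)  # H0: parameter = null
  }

  half_width <- qnorm(1 - alpha / 2) * se
  list(z = z, p_value = p_value,
       ci = c(estimate - half_width, estimate + half_width))
}

## CWATE: one-sided test of H0: theta <= 0.
cwate <- wald_summary(0.9017, 0.0538, null = 0, alternative = "greater")

## NCWATE: two-sided test of H0: gamma = 1.
ncwate <- wald_summary(1.1646, 0.0642, null = 1, alternative = "two.sided")

cwate$p_value
ncwate$p_value
ncwate$ci
```

In this illustration the CWATE null is rejected, while the NCWATE interval lies above 1, so H0: gamma = 1 is also rejected at the 5% level.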

References

  • Di Francesco, R., & Knaus, M. C. (2025). Validating ML Predictions of Heterogeneous Treatment Effects via CWATE.

Author

Riccardo Di Francesco

Examples

## Generate data.
set.seed(1986)

n <- 1000
k <- 2

X <- matrix(rnorm(n * k), ncol = k)
colnames(X) <- paste0("x", seq_len(k))
D <- rbinom(n, size = 1, prob = 0.5)
mu0 <- 0.5 * X[, 1]
mu1 <- 0.5 * X[, 1] + X[, 2]
Y <- mu0 + D * (mu1 - mu0) + rnorm(n, sd = 0.5)

## Split into training and validation samples.
train_idx <- sample(1:n, n / 2)
val_idx <- setdiff(1:n, train_idx)

## Estimate CATEs on the training sample, predict on the validation sample.
library(grf)
cf <- causal_forest(X[train_idx, ], Y[train_idx], D[train_idx])
cates <- predict(cf, X[val_idx, ])$predictions

## Validate using the validation sample.
result <- valiCATE(Y[val_idx], D[val_idx], X[val_idx, ],
                   cates = list("causal_forest" = cates))
#> Estimating nuisance functions via 5 -fold cross-fitting; 
#> CWATE/NCWATE estimation for model: causal_forest ; 
#> Output. 
#> 

## We can also compare multiple models.
cf2 <- causal_forest(X[train_idx, ], Y[train_idx], D[train_idx],
                     num.trees = 500)
cates2 <- predict(cf2, X[val_idx, ])$predictions

result_multi <- valiCATE(Y[val_idx], D[val_idx], X[val_idx, ],
                         cates = list("cf_default" = cates,
                                      "cf_500" = cates2))
#> Estimating nuisance functions via 5 -fold cross-fitting; 
#> CWATE/NCWATE estimation for model: cf_default ; 
#> CWATE/NCWATE estimation for model: cf_500 ; 
#> Output. 
#> 

## The returned object is compatible with generic S3 methods.
summary(result)
#> ====================================================================== 
#> Model: causal_forest 
#> ====================================================================== 
#> 
#> CWATE (H0: theta <= 0, one-sided test)
#> ---------------------------------------------------------------------- 
#>   Weight       Estimate         SE           95% CI    p-value
#> ---------------------------------------------------------------------- 
#>   AUTOC          0.9017     0.0538 [0.7963, 1.0071]     0.0000
#>   AUC-HVL        1.8342     0.0983 [1.6416, 2.0268]     0.0000
#>   BLP            0.8683     0.0606 [0.7496, 0.9871]     0.0000
#>   QINI           0.2952     0.0166 [0.2626, 0.3278]     0.0000
#> 
#> NCWATE (H0: gamma = 1, two-sided test)
#> ---------------------------------------------------------------------- 
#>   Weight       Estimate         SE           95% CI    p-value
#> ---------------------------------------------------------------------- 
#>   AUTOC          1.1646     0.0642 [1.0388, 1.2905]     0.0103
#>   AUC-HVL        1.2517     0.0627 [1.1289, 1.3745]     0.0001
#>   BLP            1.1983     0.0600 [1.0806, 1.3160]     0.0010
#>   QINI           1.2204     0.0612 [1.1005, 1.3404]     0.0003
#> 
summary(result, latex = TRUE)
#> \begingroup
#>     \setlength{\tabcolsep}{8pt}
#>     \renewcommand{\arraystretch}{1.1}
#>     \begin{table}[H]
#>       \centering
#>       \caption{CWATE and NCWATE results.}
#>       \vspace{-0.3cm}
#>       \label{table_cwate_ncwate}
#>       \begin{adjustbox}{width = 1\textwidth}
#>       \begin{tabular}{@{\extracolsep{5pt}} l c c }
#>       \\[-1.8ex]\hline
#>       \hline \\[-1.8ex]
#>       &  \multicolumn{2}{c}{\textit{causal\_forest}}  \\  \cmidrule{2-3} 
#>       &  CWATE & NCWATE  \\
#>       \addlinespace[2pt]
#>       \hline \\[-1.8ex] 
#> 
#>        AUTOC  &  0.9017 & 1.1646  \\
#>       &  [0.796, 1.007] & [1.039, 1.290]  \\
#>       &  (0.0000) & (0.0103)  \\
#>       \addlinespace[3pt]
#>        AUC--HVL  &  1.8342 & 1.2517  \\
#>       &  [1.642, 2.027] & [1.129, 1.375]  \\
#>       &  (0.0000) & (0.0001)  \\
#>       \addlinespace[3pt]
#>        BLP  &  0.8683 & 1.1983  \\
#>       &  [0.750, 0.987] & [1.081, 1.316]  \\
#>       &  (0.0000) & (0.0010)  \\
#>       \addlinespace[3pt]
#>        QINI  &  0.2952 & 1.2204  \\
#>       &  [0.263, 0.328] & [1.101, 1.340]  \\
#>       &  (0.0000) & (0.0003)  \\
#>       \addlinespace[3pt]
#> 
#>       \\[-1.8ex]\hline
#>       \hline \\[-1.8ex]
#>       \end{tabular}
#>       \end{adjustbox}
#>       \begin{minipage}{1\textwidth}
#>       \scriptsize
#>       \renewcommand{\baselineskip}{11pt}
#>       \vspace{0.05cm}
#>       \textit{Notes.} 95\% confidence intervals in brackets, $p$-values in parentheses. CWATE: $\mathcal{H}_0\!: \theta \leq 0$ (one-sided). NCWATE: $\mathcal{H}_0\!: \gamma = 1$ (two-sided).
#>       \end{minipage}
#>     \end{table}
#> \endgroup