Prediction method for class ocf.
# S3 method for class 'ocf'
predict(object, data = NULL, type = "response", ...)
object: An ocf object.
data: Data set of class data.frame. It must contain at least the same covariates used to train the forests. If data is NULL, then object$full_data is used.
type: Type of prediction. Either "response" or "terminalNodes".
...: Further arguments passed to or from other methods.
Desired predictions.
If type == "response", the routine returns the predicted conditional class probabilities and the predicted class labels. If forests are honest, the predicted probabilities are honest.
If type == "terminalNodes", the IDs of the terminal node in each tree for each observation in data are returned.
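To make the "response" output concrete, the sketch below rebuilds the first six rows of predictions$probabilities as printed in the examples on this page and inspects them with base R only. The names probabilities, row_sums and most_likely are illustrative and not part of the package. For these six rows the most probable class coincides with the printed class labels; whether predict always assigns labels by this rule is an assumption, not something this page states.

```r
## First six rows of predictions$probabilities, copied from the example output.
probabilities <- matrix(
  c(0.3543743, 0.5375517, 0.10807401,
    0.4571570, 0.4057421, 0.13710091,
    0.1173914, 0.4803692, 0.40223943,
    0.6342519, 0.3132168, 0.05253126,
    0.3482501, 0.3873043, 0.26444560,
    0.6284292, 0.2999011, 0.07166965),
  ncol = 3, byrow = TRUE,
  dimnames = list(NULL, c("P(Y=1)", "P(Y=2)", "P(Y=3)"))
)

## Each row is a distribution over the classes, so it should sum to one
## (up to the rounding of the printed values).
row_sums <- rowSums(probabilities)

## Index of the most probable class for each observation.
## Assumption: this mirrors how the class labels are chosen.
most_likely <- max.col(probabilities, ties.method = "first")
most_likely
```

The same two checks can be run on the full matrix returned by predict in place of the copied values.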
Di Francesco, R. (2025). Ordered Correlation Forest. Econometric Reviews, 1–17. https://doi.org/10.1080/07474938.2024.2429596.
## Generate synthetic data.
set.seed(1986)
data <- generate_ordered_data(100)
sample <- data$sample
Y <- sample$Y
X <- sample[, -1]
## Training-test split.
train_idx <- sample(seq_len(length(Y)), floor(length(Y) * 0.5))
Y_tr <- Y[train_idx]
X_tr <- X[train_idx, ]
Y_test <- Y[-train_idx]
X_test <- X[-train_idx, ]
## Fit ocf on training sample.
forests <- ocf(Y_tr, X_tr)
## Predict on test sample.
predictions <- predict(forests, X_test)
head(predictions$probabilities)
#>         P(Y=1)    P(Y=2)     P(Y=3)
#> [1,] 0.3543743 0.5375517 0.10807401
#> [2,] 0.4571570 0.4057421 0.13710091
#> [3,] 0.1173914 0.4803692 0.40223943
#> [4,] 0.6342519 0.3132168 0.05253126
#> [5,] 0.3482501 0.3873043 0.26444560
#> [6,] 0.6284292 0.2999011 0.07166965
predictions$classification
#>  [1] 2 1 2 1 2 1 3 1 2 1 3 1 1 3 1 1 2 3 3 1 3 3 3 3 2 1 1 3 1 2 1 2 1 3 3 3 1 2
#> [39] 1 2 1 3 3 3 2 3 3 1 1 1
## Get terminal nodes.
predictions <- predict(forests, X_test, type = "terminalNodes")
predictions$forest.1[1:10, 1:20] # Rows are observations, columns are trees.
#>       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
#>  [1,]    8    6    6    5    7    5    3   10    3    11     5     5     9
#>  [2,]    9    9    5    5    6    3    4    9    3    11     5     5     3
#>  [3,]   10    6    6    6    6    6    7   10    6    12     7     9     6
#>  [4,]   10    3    7    3    3    3    5    7    9     7     3     3     7
#>  [5,]    9    9    5    5    6    3    4    6    4    10     5     5     3
#>  [6,]    9    4    7    3    4    5    3    8    3     7     3     3     9
#>  [7,]    6   10    5    5    6   10    7    6    6    10     8    10     7
#>  [8,]    6    4    8    4    4    5    3    8    4     7     3     3     6
#>  [9,]    6    6    8    4    4    6    7    4    6     8     4    10     6
#> [10,]    4    4    7    4    4    5    3    8    3     4     3     3     6
#>       [,14] [,15] [,16] [,17] [,18] [,19] [,20]
#>  [1,]     3     3     6     3     4     5     3
#>  [2,]     5     3     7     3     6    10     3
#>  [3,]     7     6     6     9     2     6     5
#>  [4,]     9     5     7     5     2    12     1
#>  [5,]     5     3     7     3     6    10     3
#>  [6,]     7     3     5     5     4     5     3
#>  [7,]    10     2     4     7     2    10     5
#>  [8,]     9     2     5     5     4     5     3
#>  [9,]     9     2     6    10     2     5     5
#> [10,]     5     2     5     5     4     5     3
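
One possible use of the terminal node IDs is a proximity measure: the share of trees in which two observations land in the same leaf. This is a generic random-forest idea, not a feature this page attributes to the package, and the names nodes and proximity are illustrative. The sketch below is base R only and copies the first five rows and six columns of the matrix printed above, so the resulting values describe only those six trees.

```r
## Terminal node IDs for five observations (rows) in six trees (columns),
## copied from the printed output of predictions$forest.1.
nodes <- matrix(
  c( 8, 6, 6, 5, 7, 5,
     9, 9, 5, 5, 6, 3,
    10, 6, 6, 6, 6, 6,
    10, 3, 7, 3, 3, 3,
     9, 9, 5, 5, 6, 3),
  nrow = 5, byrow = TRUE
)

## Node IDs are compared tree by tree (column by column), since an ID is
## only meaningful within its own tree.
n <- nrow(nodes)
proximity <- matrix(0, n, n)
for (i in seq_len(n)) {
  for (j in seq_len(n)) {
    proximity[i, j] <- mean(nodes[i, ] == nodes[j, ])
  }
}

## Observations 2 and 5 share a leaf in all six of these trees.
proximity[2, 5]
```

With the full matrix returned by predict, the same loop would use every tree of the forest instead of the first six.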