Prediction method for class ocf.
# S3 method for class 'ocf'
predict(object, data = NULL, type = "response", ...)
An ocf object.
Data set of class data.frame. It must contain at least the same covariates used to train the forests. If data is NULL, then object$full_data is used.
Type of prediction. Either "response" or "terminalNodes".
Further arguments passed to or from other methods.
Desired predictions.
If type == "response", the routine returns the predicted conditional class probabilities and the predicted class labels. If forests are honest, the predicted probabilities are honest.
If type == "terminalNodes", the routine returns, for each observation in data, the ID of the terminal node it falls into in each tree.
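To make the link between the two parts of the "response" output concrete, the sketch below derives class labels from a probability matrix in base R. It assumes that the predicted label is simply the class with the largest predicted probability; the documentation does not state this rule, but it is consistent with the first six rows of the example output for this function. The object names `probs` and `labels` are illustrative only.

```r
# First six rows of predicted probabilities, copied from the example output.
probs <- rbind(
  c(0.4224274, 0.4548215, 0.12275111),
  c(0.4786262, 0.4133636, 0.10801015),
  c(0.1446138, 0.4064470, 0.44893918),
  c(0.6215123, 0.3249310, 0.05355674),
  c(0.4359897, 0.3503095, 0.21370084),
  c(0.6224514, 0.3216924, 0.05585619)
)
colnames(probs) <- c("P(Y=1)", "P(Y=2)", "P(Y=3)")

# Assumption: the label is the column index of the largest probability.
# ties.method = "first" makes the choice deterministic if two classes tie.
labels <- max.col(probs, ties.method = "first")
labels
```

With a fitted forest, the same check could be run on `predictions$probabilities` and compared against `predictions$classification`.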
Di Francesco, R. (2023). Ordered Correlation Forest. arXiv preprint arXiv:2309.08755.
## Generate synthetic data.
set.seed(1986)
data <- generate_ordered_data(100)
sample <- data$sample
Y <- sample$Y
X <- sample[, -1]
## Training-test split.
train_idx <- sample(seq_len(length(Y)), floor(length(Y) * 0.5))
Y_tr <- Y[train_idx]
X_tr <- X[train_idx, ]
Y_test <- Y[-train_idx]
X_test <- X[-train_idx, ]
## Fit ocf on training sample.
forests <- ocf(Y_tr, X_tr)
## Predict on test sample.
predictions <- predict(forests, X_test)
head(predictions$probabilities)
#> P(Y=1) P(Y=2) P(Y=3)
#> [1,] 0.4224274 0.4548215 0.12275111
#> [2,] 0.4786262 0.4133636 0.10801015
#> [3,] 0.1446138 0.4064470 0.44893918
#> [4,] 0.6215123 0.3249310 0.05355674
#> [5,] 0.4359897 0.3503095 0.21370084
#> [6,] 0.6224514 0.3216924 0.05585619
predictions$classification
#> [1] 2 1 3 1 1 1 3 1 2 1 3 1 1 3 1 1 2 3 3 1 3 3 3 3 2 1 1 3 1 2 1 2 1 3 3 3 2 3
#> [39] 1 2 1 3 3 3 2 3 3 1 1 1
## Get terminal nodes.
predictions <- predict(forests, X_test, type = "terminalNodes")
predictions$forest.1[1:10, 1:20] # Rows are observations, columns are trees of the first forest.
#> [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
#> [1,] 12 6 6 5 3 5 1 9 7 11 5 5 10
#> [2,] 14 9 5 5 3 3 6 9 7 11 5 5 3
#> [3,] 10 6 6 6 8 6 6 12 6 12 7 9 6
#> [4,] 13 3 7 3 1 3 6 7 12 7 3 3 7
#> [5,] 14 9 5 5 3 3 1 6 7 11 5 5 3
#> [6,] 13 4 7 3 3 5 6 8 7 7 3 3 9
#> [7,] 6 10 5 5 6 12 6 6 6 12 8 12 7
#> [8,] 13 4 8 3 8 5 6 8 7 7 3 3 6
#> [9,] 12 6 8 3 8 6 6 4 6 7 4 11 6
#> [10,] 4 4 7 4 8 5 6 8 7 4 3 3 6
#> [,14] [,15] [,16] [,17] [,18] [,19] [,20]
#> [1,] 3 3 6 3 4 6 3
#> [2,] 5 3 3 4 5 7 3
#> [3,] 7 6 6 9 2 6 5
#> [4,] 9 5 9 5 2 9 1
#> [5,] 5 3 3 4 5 7 3
#> [6,] 7 3 5 5 4 6 3
#> [7,] 10 2 10 7 2 9 5
#> [8,] 9 2 5 5 4 6 3
#> [9,] 9 2 6 10 2 6 5
#> [10,] 5 2 5 5 4 6 3
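One possible use of the terminal-node IDs is to measure how often two observations fall into the same leaf. The sketch below is not part of the package: `proximity()` is a hypothetical base R helper, and the small `nodes` matrix copies rows 1, 2, 3, and 5 and the first five columns from the output above. It assumes rows are observations and columns are trees.

```r
# Terminal-node IDs for four observations (rows) in five trees (columns),
# copied from rows 1, 2, 3, and 5 of the output above.
nodes <- rbind(
  c(12, 6, 6, 5, 3),
  c(14, 9, 5, 5, 3),
  c(10, 6, 6, 6, 8),
  c(14, 9, 5, 5, 3)
)

# Hypothetical helper: share of trees in which each pair of observations
# lands in the same terminal node.
proximity <- function(nodes) {
  n <- nrow(nodes)
  prox <- matrix(0, n, n)
  for (b in seq_len(ncol(nodes))) {
    # outer() gives an n x n logical matrix: TRUE where IDs match in tree b.
    prox <- prox + outer(nodes[, b], nodes[, b], "==")
  }
  prox / ncol(nodes)
}

prox <- proximity(nodes)
prox
```

Values close to 1 indicate observations that the trees treat as similar. On a real prediction object the helper would be applied to one forest at a time, for example `proximity(predictions$forest.1)`.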