Prediction method for class ocf.

# S3 method for class 'ocf'
predict(object, data = NULL, type = "response", ...)

Arguments

object

An ocf object.

data

Data set of class data.frame. It must contain at least the same covariates used to train the forests. If data is NULL, then object$full_data is used.

type

Type of prediction. Either "response" or "terminalNodes".

...

Further arguments passed to or from other methods.
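
The handling of data described above can be pictured with a small stand-alone sketch. It is only an illustration of the documented contract, not the package's internal code: resolve_data() and the covariates field are hypothetical names introduced here, and only full_data comes from the description of the data argument.

```r
# Hypothetical helper, NOT part of ocf: mimics the documented rules for `data`.
resolve_data <- function(object, data = NULL) {
  # Documented fallback: with data = NULL, object$full_data is used.
  if (is.null(data)) {
    return(object$full_data)
  }
  stopifnot(is.data.frame(data))
  # `covariates` is an assumed field holding the training covariate names.
  missing_covariates <- setdiff(object$covariates, colnames(data))
  if (length(missing_covariates) > 0) {
    stop("Missing covariates: ", paste(missing_covariates, collapse = ", "))
  }
  # Extra columns are tolerated: data needs "at least" the training covariates.
  data
}

# Mock object standing in for a fitted ocf object (assumption for illustration).
mock <- list(
  full_data = data.frame(x1 = c(0.1, 0.2), x2 = c(1, 0)),
  covariates = c("x1", "x2")
)

resolve_data(mock)                                      # falls back to full_data
resolve_data(mock, data.frame(x1 = 1, x2 = 0, x3 = 5))  # extra column is accepted
```

Passing a data frame that lacks x2 to this sketch stops with an error naming the missing covariate.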

Value

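The requested predictions. As the examples below show, type = "response" yields a list with elements probabilities and classification, while type = "terminalNodes" yields a list whose elements are named by forest (for example, forest.1).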

Details

If type == "response", the routine returns the predicted conditional class probabilities and the predicted class labels. If forests are honest, the predicted probabilities are honest.
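
To make the link between the two outputs concrete, the sketch below recomputes class labels from the first six rows of probabilities printed in the Examples section. It assumes that the predicted label is the class with the largest predicted probability. That rule is consistent with the printed output, but it is an assumption here rather than a documented guarantee.

```r
# Probabilities copied from the first six rows printed in the Examples section.
probabilities <- rbind(
  c(0.4224274, 0.4548215, 0.12275111),
  c(0.4786262, 0.4133636, 0.10801015),
  c(0.1446138, 0.4064470, 0.44893918),
  c(0.6215123, 0.3249310, 0.05355674),
  c(0.4359897, 0.3503095, 0.21370084),
  c(0.6224514, 0.3216924, 0.05585619)
)
colnames(probabilities) <- c("P(Y=1)", "P(Y=2)", "P(Y=3)")

# Class probabilities in each row should sum to one, up to rounding.
row_sums <- rowSums(probabilities)

# Assumption: the label is the column index of the largest probability.
labels <- max.col(probabilities, ties.method = "first")
labels
#> [1] 2 1 3 1 1 1
```

These labels agree with the first six entries of predictions$classification in the Examples section.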

If type == "terminalNodes", the routine returns, for each tree, the ID of the terminal node into which each observation in data falls.
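
One way to read such a matrix of terminal node IDs is through pairwise proximities, that is, the share of trees in which two observations land in the same terminal node. The sketch below is a generic random forest illustration in base R, not a feature of ocf. The proximity() helper is a hypothetical name, and the input reuses rows 1, 2, and 5 and the first five columns of the forest.1 output printed in the Examples section.

```r
# Terminal node IDs for three observations (rows) in five trees (columns),
# copied from rows 1, 2, and 5 of the forest.1 output in the Examples section.
nodes <- rbind(
  c(12, 6, 6, 5, 3),
  c(14, 9, 5, 5, 3),
  c(14, 9, 5, 5, 3)
)

# Hypothetical helper, NOT part of ocf: share of trees in which two
# observations fall into the same terminal node.
proximity <- function(nodes) {
  n <- nrow(nodes)
  prox <- matrix(0, nrow = n, ncol = n)
  for (i in seq_len(n)) {
    for (j in seq_len(n)) {
      prox[i, j] <- mean(nodes[i, ] == nodes[j, ])
    }
  }
  prox
}

prox <- proximity(nodes)
prox
```

In this small input the second and third observations share a terminal node in every tree, so their proximity is 1, while the first and second agree in two of five trees, giving 0.4.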

References

  • Di Francesco, R. (2023). Ordered Correlation Forest. arXiv preprint arXiv:2309.08755.

Author

Riccardo Di Francesco

Examples

## Generate synthetic data.
set.seed(1986)

data <- generate_ordered_data(100)
sample <- data$sample
Y <- sample$Y
X <- sample[, -1]

## Training-test split.
train_idx <- sample(seq_len(length(Y)), floor(length(Y) * 0.5))

Y_tr <- Y[train_idx]
X_tr <- X[train_idx, ]

Y_test <- Y[-train_idx]
X_test <- X[-train_idx, ]

## Fit ocf on training sample.
forests <- ocf(Y_tr, X_tr)

## Predict on test sample.
predictions <- predict(forests, X_test)
head(predictions$probabilities)
#>         P(Y=1)    P(Y=2)     P(Y=3)
#> [1,] 0.4224274 0.4548215 0.12275111
#> [2,] 0.4786262 0.4133636 0.10801015
#> [3,] 0.1446138 0.4064470 0.44893918
#> [4,] 0.6215123 0.3249310 0.05355674
#> [5,] 0.4359897 0.3503095 0.21370084
#> [6,] 0.6224514 0.3216924 0.05585619
predictions$classification
#>  [1] 2 1 3 1 1 1 3 1 2 1 3 1 1 3 1 1 2 3 3 1 3 3 3 3 2 1 1 3 1 2 1 2 1 3 3 3 2 3
#> [39] 1 2 1 3 3 3 2 3 3 1 1 1

## Get terminal nodes.
predictions <- predict(forests, X_test, type = "terminalNodes")
predictions$forest.1[1:10, 1:20] # Rows are observations, columns are trees of the first forest.
#>       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
#>  [1,]   12    6    6    5    3    5    1    9    7    11     5     5    10
#>  [2,]   14    9    5    5    3    3    6    9    7    11     5     5     3
#>  [3,]   10    6    6    6    8    6    6   12    6    12     7     9     6
#>  [4,]   13    3    7    3    1    3    6    7   12     7     3     3     7
#>  [5,]   14    9    5    5    3    3    1    6    7    11     5     5     3
#>  [6,]   13    4    7    3    3    5    6    8    7     7     3     3     9
#>  [7,]    6   10    5    5    6   12    6    6    6    12     8    12     7
#>  [8,]   13    4    8    3    8    5    6    8    7     7     3     3     6
#>  [9,]   12    6    8    3    8    6    6    4    6     7     4    11     6
#> [10,]    4    4    7    4    8    5    6    8    7     4     3     3     6
#>       [,14] [,15] [,16] [,17] [,18] [,19] [,20]
#>  [1,]     3     3     6     3     4     6     3
#>  [2,]     5     3     3     4     5     7     3
#>  [3,]     7     6     6     9     2     6     5
#>  [4,]     9     5     9     5     2     9     1
#>  [5,]     5     3     3     4     5     7     3
#>  [6,]     7     3     5     5     4     6     3
#>  [7,]    10     2    10     7     2     9     5
#>  [8,]     9     2     5     5     4     6     3
#>  [9,]     9     2     6    10     2     6     5
#> [10,]     5     2     5     5     4     6     3