Prediction method for class ocf.

# S3 method for class 'ocf'
predict(object, data = NULL, type = "response", ...)

Arguments

object

An ocf object.

data

Data set of class data.frame. It must contain at least the same covariates used to train the forests. If data is NULL, then object$full_data is used.

type

Type of prediction. Either "response" or "terminalNodes".

...

Further arguments passed to or from other methods.

Value

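Desired predictions. If type == "response", a list containing the predicted class probabilities (probabilities) and the predicted class labels (classification). If type == "terminalNodes", a list of terminal-node IDs stored by forest (e.g., forest.1).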

Details

If type == "response", the routine returns the predicted conditional class probabilities and the predicted class labels. If forests are honest, the predicted probabilities are honest.

If type == "terminalNodes", the routine returns, for each observation in data, the ID of the terminal node it falls into in each tree.

References

See also

Author

Riccardo Di Francesco

Examples

## Generate synthetic data.
set.seed(1986)

data <- generate_ordered_data(100)
## Named `dta` rather than `sample` so the data frame does not mask
## base::sample(), which is called for the split below.
dta <- data$sample
Y <- dta$Y
## Covariates: every column except the first.
X <- dta[, -1]

## Training-test split: draw half of the observations at random for training.
train_idx <- sample(seq_along(Y), floor(length(Y) * 0.5))

Y_tr <- Y[train_idx]
X_tr <- X[train_idx, ]

## Everything not drawn for training forms the test sample.
Y_test <- Y[-train_idx]
X_test <- X[-train_idx, ]

## Fit ocf on training sample.
forests <- ocf(Y_tr, X_tr)

## Predict on test sample. With the default type = "response", the result
## holds the predicted conditional class probabilities and the predicted
## class labels.
predictions <- predict(forests, X_test)
## Probabilities: one row per test observation, one column per class.
head(predictions$probabilities)
#>         P(Y=1)    P(Y=2)     P(Y=3)
#> [1,] 0.3543743 0.5375517 0.10807401
#> [2,] 0.4571570 0.4057421 0.13710091
#> [3,] 0.1173914 0.4803692 0.40223943
#> [4,] 0.6342519 0.3132168 0.05253126
#> [5,] 0.3482501 0.3873043 0.26444560
#> [6,] 0.6284292 0.2999011 0.07166965
## Predicted class label for each test observation.
predictions$classification
#>  [1] 2 1 2 1 2 1 3 1 2 1 3 1 1 3 1 1 2 3 3 1 3 3 3 3 2 1 1 3 1 2 1 2 1 3 3 3 1 2
#> [39] 1 2 1 3 3 3 2 3 3 1 1 1

## Get terminal nodes.
predictions <- predict(forests, X_test, type = "terminalNodes")
## Terminal-node IDs of the first 10 test observations in the first 20 trees
## of forest.1.
predictions$forest.1[1:10, 1:20] # Rows are observations, columns are trees.
#>       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
#>  [1,]    8    6    6    5    7    5    3   10    3    11     5     5     9
#>  [2,]    9    9    5    5    6    3    4    9    3    11     5     5     3
#>  [3,]   10    6    6    6    6    6    7   10    6    12     7     9     6
#>  [4,]   10    3    7    3    3    3    5    7    9     7     3     3     7
#>  [5,]    9    9    5    5    6    3    4    6    4    10     5     5     3
#>  [6,]    9    4    7    3    4    5    3    8    3     7     3     3     9
#>  [7,]    6   10    5    5    6   10    7    6    6    10     8    10     7
#>  [8,]    6    4    8    4    4    5    3    8    4     7     3     3     6
#>  [9,]    6    6    8    4    4    6    7    4    6     8     4    10     6
#> [10,]    4    4    7    4    4    5    3    8    3     4     3     3     6
#>       [,14] [,15] [,16] [,17] [,18] [,19] [,20]
#>  [1,]     3     3     6     3     4     5     3
#>  [2,]     5     3     7     3     6    10     3
#>  [3,]     7     6     6     9     2     6     5
#>  [4,]     9     5     7     5     2    12     1
#>  [5,]     5     3     7     3     6    10     3
#>  [6,]     7     3     5     5     4     5     3
#>  [7,]    10     2     4     7     2    10     5
#>  [8,]     9     2     5     5     4     5     3
#>  [9,]     9     2     6    10     2     5     5
#> [10,]     5     2     5     5     4     5     3