Prediction method for class `mml`.
# S3 method for class 'mml'
predict(object, data = NULL, ...)
Matrix of predictions.
If `object$learner == "l1"`, then `model.matrix` is used to handle non-numeric covariates. If we also have `object$scaling == TRUE`, then `data` is scaled to have zero mean and unit variance.
Di Francesco, R. (2023). Ordered Correlation Forest. arXiv preprint arXiv:2309.08755.
## Simulate a data set; the seed is fixed so the example is reproducible.
set.seed(1986)
sim <- generate_ordered_data(1000)
obs <- sim$sample

## Outcome is the `Y` column; covariates are every column but the first.
Y <- obs$Y
X <- obs[, -1]

## Draw half of the observations (rounded down) at random for training;
## the remaining ones form the test sample.
n_obs <- length(Y)
train_idx <- sample(seq_len(n_obs), floor(0.5 * n_obs))

X_tr <- X[train_idx, ]
X_test <- X[-train_idx, ]
Y_tr <- Y[train_idx]
Y_test <- Y[-train_idx]

## Train the multinomial machine learning estimator twice on the same
## training sample, once per learner. The fitting order is kept as
## forest then l1, in case the learners draw from the RNG.
multinomial_forest <- multinomial_ml(Y_tr, X_tr, learner = "forest")
multinomial_l1 <- multinomial_ml(Y_tr, X_tr, learner = "l1")

## Out-of-sample predictions on the held-out covariates.
predictions_forest <- predict(multinomial_forest, X_test)
predictions_l1 <- predict(multinomial_l1, X_test)

## Show the first six rows of each prediction set side by side
## (forest columns first, then l1 columns).
cbind(head(predictions_forest, n = 6L), head(predictions_l1, n = 6L))
#> P(Y=1) P(Y=2) P(Y=3) P(Y=1) P(Y=2) P(Y=3)
#> [1,] 0.2750897 0.5810172 0.1438932 0.4743307 0.3381195 0.18754984
#> [2,] 0.3326762 0.4868678 0.1804560 0.4722287 0.3206559 0.20711540
#> [3,] 0.2855366 0.3379332 0.3765303 0.2912424 0.3013192 0.40743847
#> [4,] 0.4088999 0.4011857 0.1899144 0.4187626 0.3195804 0.26165697
#> [5,] 0.6617388 0.2228519 0.1154092 0.6312130 0.2924961 0.07629089
#> [6,] 0.4125682 0.3427655 0.2446663 0.4852413 0.3126977 0.20206096