Prediction method for class mml.

# S3 method for class 'mml'
predict(object, data = NULL, ...)

Arguments

object

An mml object, as returned by multinomial_ml.

data

A data.frame containing the same covariates used to train the base learners. If data is NULL, predictions are computed for object$X, the covariates stored in the fitted object.

...

Further arguments passed to or from other methods.

Value

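Matrix of predicted class probabilities, with one row per observation and one column per class of the outcome (named P(Y=1), P(Y=2), and so on).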
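The R sketch below shows one way to use this matrix. It copies two rows of the forest predictions printed in the Examples section, checks that each row sums to one up to rounding, and recovers the most likely class per observation with base R's max.col(). Treating the most likely class as the predicted label is an assumption of this sketch and is not something predict.mml does itself.

```r
# Two rows copied from the forest predictions printed in the Examples section.
probs <- rbind(c(0.3537709, 0.4865778, 0.15965128),
               c(0.6324324, 0.2491552, 0.11841243))
colnames(probs) <- c("P(Y=1)", "P(Y=2)", "P(Y=3)")

# Each row is a probability distribution over the classes (sums to 1 up to rounding).
rowSums(probs)

# Column index of the largest probability in each row, i.e. the most likely class.
max.col(probs)  # 2 1
```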

Details

If object$learner == "l1", model.matrix is used to expand non-numeric covariates into dummy variables. If, in addition, object$scaling == TRUE, the data are scaled to have zero mean and unit variance.
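The preprocessing described above can be sketched in base R as follows. prepare_l1_covariates() is a hypothetical helper written for illustration and is not a function exported by the package. Dropping the intercept column and scaling with the data's own column means and standard deviations are assumptions of this sketch; the package may handle these details differently.

```r
# Hypothetical helper, NOT the package implementation: mimics the preprocessing
# described in Details for learner == "l1".
prepare_l1_covariates <- function(data, scaling = TRUE) {
  # Expand factor covariates into dummy columns, then drop the intercept column.
  X <- model.matrix(~ ., data = data)[, -1, drop = FALSE]
  # Optionally standardize each column to zero mean and unit variance.
  if (scaling) X <- scale(X)
  X
}

# Toy data with one numeric and one factor covariate.
toy <- data.frame(x1 = c(1, 2, 3, 4), x2 = factor(c("a", "b", "a", "c")))
X_prep <- prepare_l1_covariates(toy)

colnames(X_prep)          # "x1" "x2b" "x2c"
colMeans(X_prep)          # all (numerically) zero
apply(X_prep, 2, sd)      # all one
```

Dropping the intercept matters here: a constant column has zero standard deviation, so scale() would turn it into NaN.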

References

  • Di Francesco, R. (2023). Ordered Correlation Forest. arXiv preprint arXiv:2309.08755.

Author

Riccardo Di Francesco

Examples

library(ocf)

## Generate synthetic data.
set.seed(1986)

data <- generate_ordered_data(100)
sample <- data$sample
Y <- sample$Y
X <- sample[, -1]

## Training-test split.
train_idx <- sample(seq_len(length(Y)), floor(length(Y) * 0.5))

Y_tr <- Y[train_idx]
X_tr <- X[train_idx, ]

Y_test <- Y[-train_idx]
X_test <- X[-train_idx, ]

## Fit multinomial machine learning on training sample using two different learners.
multinomial_forest <- multinomial_ml(Y_tr, X_tr, learner = "forest")
multinomial_l1 <- multinomial_ml(Y_tr, X_tr, learner = "l1")

## Predict out of sample.
predictions_forest <- predict(multinomial_forest, X_test)
predictions_l1 <- predict(multinomial_l1, X_test)

## Compare predictions.
cbind(head(predictions_forest), head(predictions_l1))
#>         P(Y=1)    P(Y=2)     P(Y=3)     P(Y=1)    P(Y=2)     P(Y=3)
#> [1,] 0.3537709 0.4865778 0.15965128 0.37483081 0.4934319 0.13173734
#> [2,] 0.6324324 0.2491552 0.11841243 0.39553675 0.4512993 0.15316400
#> [3,] 0.1416059 0.4282620 0.43013203 0.06737991 0.3712126 0.56140745
#> [4,] 0.6299436 0.2814432 0.08861318 0.59561559 0.3571572 0.04722723
#> [5,] 0.4841232 0.2875252 0.22835164 0.32814211 0.3535764 0.31828152
#> [6,] 0.6875407 0.2334853 0.07897406 0.43330267 0.4505802 0.11611709