Machine learning methods can often learn high-dimensional functions which generalize well but are not human interpretable. The mmpf package marginalizes prediction functions using Monte-Carlo methods, allowing users to investigate the behavior of these learned functions on a lower-dimensional subset of input features: partial dependence and variations thereof. This makes machine learning methods more useful in situations where accurate prediction is not the only goal, such as in the social sciences, where linear models are commonly used because of their interpretability.
Many methods for estimating prediction functions produce estimated functions which are not directly human-interpretable because of their complexity: for example, they may include high-dimensional interactions and/or complex nonlinearities. While a learning method’s capacity to automatically learn interactions and nonlinearities is attractive when the goal is prediction, there are many cases where users want good predictions and the ability to understand how predictions depend on the features. mmpf implements general methods for interpreting prediction functions using Monte-Carlo methods. These methods allow any function which generates predictions to be interpreted. mmpf is currently used in other machine learning packages such as edarf and mlr (Bischl et al. 2016; Jones and Linder 2016).
The core function of mmpf, marginalPrediction, allows marginalization of a prediction function so that it depends on a subset of the features. Say the matrix of features $\mathbf{X}$ is partitioned into two subsets, $\mathbf{X}_u$ and $\mathbf{X}_{-u}$, where the former is of interest. A prediction function $\hat{f}$ generally depends on both subsets, so describing its behavior in terms of $\mathbf{X}_u$ alone requires marginalizing over $\mathbf{X}_{-u}$, for example by integrating $\hat{f}$ against the conditional distribution of $\mathbf{X}_{-u}$ given $\mathbf{X}_u$. This, however, can distort the relationship between $\mathbf{X}_u$ and $\hat{f}$ when $\mathbf{X}_u$ and $\mathbf{X}_{-u}$ are dependent, because the result mixes the direct effect of $\mathbf{X}_u$ with effects transmitted through $\mathbf{X}_{-u}$. To illustrate this point, suppose data are generated from an additive model in which $\mathbf{X}_u$ and $\mathbf{X}_{-u}$ enter separately but are correlated with each other: the conditional expectation of $\hat{f}$ given $\mathbf{X}_u$ then attributes part of the effect of $\mathbf{X}_{-u}$ to $\mathbf{X}_u$, obscuring the additive structure. Integrating against the marginal distribution of $\mathbf{X}_{-u}$ instead yields the partial dependence of $\hat{f}$ on $\mathbf{X}_u$,

$$\hat{f}_u(\mathbf{X}_u) = \int \hat{f}(\mathbf{X}_u, \mathbf{X}_{-u}) \, \mathbb{P}(\mathbf{X}_{-u}) \, d\mathbf{X}_{-u},$$

which preserves the additive relationship in this example. In practical settings we do not know $\mathbb{P}(\mathbf{X}_{-u})$ and cannot evaluate this integral analytically, so mmpf approximates it by Monte-Carlo integration over the training data:

$$\hat{f}_u(\mathbf{X}_u) \approx \frac{1}{N} \sum_{i = 1}^{N} \hat{f}(\mathbf{X}_u, \mathbf{X}_{-u}^{(i)}),$$

where $\mathbf{X}_{-u}^{(i)}$ denotes the $i$th observed value of the excluded features. This estimates the behavior of the prediction function at a vector or matrix of values for $\mathbf{X}_u$, averaged over the observed values of $\mathbf{X}_{-u}$.
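As a minimal sketch of this Monte-Carlo estimator (the helper function below is illustrative and not part of mmpf's API), the marginalized prediction at a single value of one feature can be computed by fixing that feature across the training data and averaging the model's predictions:

# Illustrative sketch, not mmpf code: Monte-Carlo estimate of the
# marginalized prediction at a single value of one feature.
pd.point = function(model, data, var, value, predict.fun = predict) {
  grid = data                               # observed values of the other features
  grid[[var]] = value                       # fix the feature of interest
  mean(predict.fun(model, newdata = grid))  # average over the training rows (assumes numeric predictions)
}

Evaluating such a function over a grid of values of the feature of interest traces out its estimated partial dependence; marginalPrediction automates this and adds options for sampling, weighting, and aggregation.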
The function marginalPrediction allows users to easily compute partial
dependence and many variations thereof. The key arguments of
marginalPrediction are the prediction function (predict.fun), the
training data (data), the names of the columns of the training data
which are of interest (vars), and the number of points to use in the grid
for the variables of interest together with the number of rows of the
training data to sample for the Monte-Carlo integration (n
, an integer vector of length 2). Additional
arguments control how the grid is constructed (e.g., uniform sampling,
user chosen values, non-uniform sampling), indicate the use of weights,
and instruct how aggregation is done (e.g., deviations from partial
dependence). Below is an example using the Iris data
(Anderson 1936):
library(mmpf)
library(randomForest)
data(iris)
iris.features = iris[, -ncol(iris)] # exclude the species column
fit = randomForest(iris.features, iris$Species)
mp = marginalPrediction(data = iris.features,
  vars = "Petal.Width",
  n = c(10, nrow(iris)), model = fit, uniform = TRUE,
  predict.fun = function(object, newdata) predict(object, newdata, type = "prob"))
print(mp)
## Petal.Width setosa versicolor virginica
## 1: 0.1000000 0.6374133 0.2337733 0.1288133
## 2: 0.3666667 0.6374133 0.2337733 0.1288133
## 3: 0.6333333 0.6356267 0.2350533 0.1293200
## 4: 0.9000000 0.1707200 0.5997333 0.2295467
## 5: 1.1666667 0.1688267 0.6016267 0.2295467
## 6: 1.4333333 0.1688133 0.5880800 0.2431067
## 7: 1.7000000 0.1640400 0.4242800 0.4116800
## 8: 1.9666667 0.1619867 0.2066667 0.6313467
## 9: 2.2333333 0.1619867 0.2047867 0.6332267
## 10: 2.5000000 0.1619867 0.2047867 0.6332267
In this case, mp contains the marginalized predictions (the predicted
probability of each species) at ten uniformly spaced values of
“Petal.Width.” Each row is computed based on the average prediction for
the value of “Petal.Width” shown and all the observed values of the other
variables in the training data. As can be readily observed, partial
dependence of this sort can be easily visualized, as in Figure 2.
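One possible way to produce such a plot (a sketch assuming the data.table and ggplot2 packages; the reshaping step and aesthetics are illustrative rather than taken from the article) is to melt the class-probability columns of mp into long format and draw one line per species:

library(data.table)
library(ggplot2)
# Reshape the three class-probability columns and plot the marginalized
# probability of each species against Petal.Width.
mp.long = melt(mp, id.vars = "Petal.Width",
  variable.name = "species", value.name = "probability")
ggplot(mp.long, aes(Petal.Width, probability, color = species)) +
  geom_line() +
  labs(y = "marginalized predicted probability")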
In fact, any function of the marginalized prediction function can be
computed, not only its mean: the aggregate.fun argument controls how the
predictions are aggregated, and more than one feature can be marginalized
over jointly. Returning both the mean and the variance for a grid over
“Petal.Width” and “Petal.Length,” for example, gives a simple way to
detect interactions:
mp.int = marginalPrediction(data = iris.features,
  vars = c("Petal.Width", "Petal.Length"),
  n = c(10, nrow(iris)), model = fit, uniform = TRUE,
  predict.fun = function(object, newdata) predict(object, newdata, type = "prob"),
  aggregate.fun = function(x) list("mean" = mean(x), "variance" = var(x)))
head(mp.int)
## Petal.Width Petal.Length setosa.mean setosa.variance versicolor.mean
## 1: 0.1 1.000000 0.9549867 0.0011619193 0.04448000
## 2: 0.1 1.655556 0.9549867 0.0011619193 0.04448000
## 3: 0.1 2.311111 0.9530933 0.0011317899 0.04637333
## 4: 0.1 2.966667 0.4574667 0.0003524653 0.52818667
## 5: 0.1 3.622222 0.4550400 0.0002619447 0.53061333
## 6: 0.1 4.277778 0.4550400 0.0002619447 0.52472000
## versicolor.variance virginica.mean virginica.variance
## 1: 0.001141889 0.0005333333 0.00000239821
## 2: 0.001141889 0.0005333333 0.00000239821
## 3: 0.001112236 0.0005333333 0.00000239821
## 4: 0.001154918 0.0143466667 0.00054076492
## 5: 0.001016158 0.0143466667 0.00054076492
## 6: 0.001556364 0.0202400000 0.00093196886
The output now contains the mean and the variance of the marginalized
predictions over the grid of values of “Petal.Width” and “Petal.Length.”
Both surfaces can be visualized as functions of “Petal.Width” and
“Petal.Length.” Non-constant variance indicates interaction between these
variables and those marginalized out of the prediction function.
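For instance, a tile plot of one class's variance column is one way to look for such interactions (a sketch assuming the ggplot2 package; the choice of the virginica class and the styling are illustrative):

library(ggplot2)
# Variance of the marginalized "virginica" probability over the
# Petal.Width x Petal.Length grid; structure in this surface suggests
# interaction with the marginalized-out features.
ggplot(mp.int, aes(Petal.Width, Petal.Length, fill = virginica.variance)) +
  geom_tile() +
  labs(fill = "variance")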