提问人:slow_learner 提问时间:12/21/2022 最后编辑:Philslow_learner 更新时间:12/21/2022 访问量:547
如何计算回归树的 R 中的均方误差?
How can I calculate the mean square error in R of a regression tree?
问:
我正在研究葡萄酒质量数据库。
我正在研究取决于不同变量的回归树,例如:
library(rpart)
library(rpart.plot)
library(rattle)
library(naniar)
library(dplyr)
library(ggplot2)
vinos <- read.csv(file = 'Wine.csv', header = T)
arbol0<-rpart(formula=quality~chlorides, data=vinos, method="anova")
fancyRpartPlot(arbol0)
arbol1<-rpart(formula=quality~chlorides+density, data=vinos, method="anova")
fancyRpartPlot(arbol1)
我想计算均方误差,看看 arbol1 是否比 arbol0 好。我将使用我自己的数据集,因为没有更多可用数据。我试着把它做成
aaa<-predict(object=arbol0, newdata=data.frame(chlorides=vinos$chlorides), type="anova")
bbb<-predict(object=arbol1, newdata=data.frame(chlorides=vinos$chlorides, density=vinos$density), type="anova")
然后从 和 中手动减去 DataFrame 的最后一列。但是,我收到错误。有人可以帮我吗?aaa
bbb
答:
1赞
Esther
12/21/2022
#1
这个网站可能对你有用。在训练模型之前,将数据集拆分为训练子集和测试子集非常重要。在下面的代码中,我用函数完成了它,但从 caTools 包中调用了另一个函数来执行相同的过程。我附上这个网站,您可以在其中看到在 R 中拆分数据的所有方法。base
sample.split
请记住,均方误差 (MSE) 的函数如下:
因此,使用 R 应用它非常简单。您只需计算观察到的(即来自测试子集的响应变量)和预测值(即您从带有函数的模型中预测的值)之间的平方差的平均值。predict
基于以前的网站,您的葡萄酒数据集的解决方案可能是这个。
library(rpart)
library(dplyr)
library(data.table)
vinos <- fread(file = 'Winequality-red.csv', header = TRUE)
# Split data into train and test subsets
sample_index <- sample(nrow(vinos), size = nrow(vinos)*0.75)
train <- vinos[sample_index, ]
test <- vinos[-sample_index, ]
# Train regression trees models
arbol0 <- rpart(formula = quality ~ chlorides, data = train, method = "anova")
arbol1 <- rpart(formula = quality ~ chlorides + density, data = train, method = "anova")
# Make predictions for each model
pred0 <- predict(arbol0, newdata = test)
pred1 <- predict(arbol1, newdata = test)
# Calculate MSE for each model
mean((pred0 - test$quality)^2)
mean((pred1 - test$quality)^2)
评论