提问人:stats_noob 提问时间:1/12/2021 最后编辑:stats_noob 更新时间:1/15/2021 访问量:637
R:从函数获取规则
R: Obtaining Rules from a Function
问:
我正在使用 R 编程语言。我使用了“rpart”库,并使用一些数据拟合了决策树:
#from a previous question : https://stackoverflow.com/questions/65678552/r-changing-plot-sizes
library(rpart)
car.test.frame$Reliability = as.factor(car.test.frame$Reliability)
z.auto <- rpart(Reliability ~ ., car.test.frame)
plot(z.auto)
text(z.auto, use.n=TRUE, xpd=TRUE, cex=.8)
这很好,但我正在寻找一种更简单的方法来总结这棵树的结果,以防树变得太大、太复杂和杂乱(并且无法可视化)。我在这里找到了另一个stackoverflow帖子,展示了如何获取规则列表: 从 rpart 包中的决策规则中提取信息
library(party)
library(partykit)
party_obj <- as.party.rpart(z.auto, data = TRUE)
decisions <- partykit:::.list.rules.party(party_obj)
cat(paste(decisions, collapse = "\n"))
这将返回以下规则列表(每行是与“z.auto”图相对应的规则):
Country %in% c("NA", "Germany", "Korea", "Mexico", "Sweden", "USA") & Weight >= 3167.5
Country %in% c("NA", "Germany", "Korea", "Mexico", "Sweden", "USA") & Weight < 3167.5
Country %in% c("NA", "Japan", "Japan/USA")>
但是,从此列表中,无法知道哪个规则导致哪个值为“可靠性”。目前,我正在手动解释树并手动跟踪每个规则到结果,但是有没有办法在每行中添加“相应的可靠性值”?
例如,是否有可能生产这样的东西?
Country %in% c("NA", "Germany", "Korea", "Mexico", "Sweden", "USA") & Weight >= 3167.5 then reliability = 3,7,4,0
(注1:我也不确定为什么这些国家显示为“befgh”而不是他们的真实名称。
注2:我知道有一个库“rpart.plot”,它有一种更简单的方法来获取这些规则。但是,我使用的是一台没有互联网接入或 USB 端口的计算机,因此我无法下载 rpart.plot 库。我有带有一些预加载包的 R。我正在尝试使用诸如 rpart、dplyr、purr、party、partykit、base R 中的函数等库来获取决策规则)
谢谢
答:
2赞
jared_mamrot
1/15/2021
#1
这不是我的专业领域,但也许这个函数(来自 https://www.togaware.com/datamining/survivor/Convert_Tree.html)可以做你想做的事情:
library(rpart)
car.test.frame$Reliability = as.factor(car.test.frame$Reliability)
z.auto <- rpart(Reliability ~ ., car.test.frame)
plot(z.auto, margin = 0.25)
text(z.auto, pretty = TRUE, cex = 0.8,
splits = TRUE, use.n = TRUE, all = FALSE)
list.rules.rpart <- function(model)
{
if (!inherits(model, "rpart")) stop("Not a legitimate rpart tree")
#
# Get some information.
#
frm <- model$frame
names <- row.names(frm)
ylevels <- attr(model, "ylevels")
ds.size <- model$frame[1,]$n
#
# Print each leaf node as a rule.
#
for (i in 1:nrow(frm))
{
if (frm[i,1] == "<leaf>")
{
# The following [,5] is hardwired - needs work!
cat("\n")
cat(sprintf(" Rule number: %s ", names[i]))
cat(sprintf("[yval=%s cover=%d (%.0f%%) prob=%0.2f]\n",
ylevels[frm[i,]$yval], frm[i,]$n,
round(100*frm[i,]$n/ds.size), frm[i,]$yval2[,5]))
pth <- path.rpart(model, nodes=as.numeric(names[i]), print.it=FALSE)
cat(sprintf(" %s\n", unlist(pth)[-1]), sep="")
}
}
}
list.rules.rpart(z.auto)
>Rule number: 4 [yval=3 cover=10 (20%) prob=0.00]
> Country=Germany,Korea,Mexico,Sweden,USA
> Weight>=3168
>
> Rule number: 5 [yval=2 cover=18 (37%) prob=4.00]
> Country=Germany,Korea,Mexico,Sweden,USA
> Weight< 3168
>
> Rule number: 3 [yval=5 cover=21 (43%) prob=2.00]
> Country=Japan,Japan/USA
评论
0赞
stats_noob
1/15/2021
谢谢!这太完美了,我一直在寻找这样的东西!只是为了澄清:“规则编号:4”,“4”真的没有任何意义?yval=3 ...“3”指的是类变量?覆盖=10 (20%) ...表示根据此规则分类的数据的百分比?我不确定“prob=0.00”是什么意思,你知道吗?非常感谢您的帮助!
上一个:根据列中的条件为组赋值
下一个:按 row.names 子集矩阵
评论