在 R 中,在同一图上绘制 LDA 和 QDA 分区线

in R, draw LDA and QDA partition lines on the same plot

提问人:Alex 提问时间:8/9/2023 最后编辑:Sandipan DeyAlex 更新时间:8/10/2023 访问量:78

问:

我很难理解如何在 R 中“导出”由线性 (LDA) 或二次判别分析 (QDA) 产生的分界线方程。

理想情况下,我想在同一张图上比较两种描述(线性 + 二次曲线)。下面是一个输出示例(取自 http://adjchen.com/wiki/classification),有 2 个组。但它应该与多个组相同:

enter image description here

我已经研究了 stackoverflow 上的这个部分解决方案,我认为它有点太复杂了,无法绘制简单的线性划界?

如何在 R 中的线性判别分析图上绘制分类边界

以下是他们在该示例中用于使用的数据集的改编:

library(MASS)

# generate data
set.seed(123)
Ng <- 100 # number of cases per group
group.a.x <- rnorm(n = Ng, mean = 2, sd = 3)
group.a.y <- rnorm(n = Ng, mean = 2, sd = 3)

group.b.x <- rnorm(n = Ng, mean = 11, sd = 3)
group.b.y <- rnorm(n = Ng, mean = 11, sd = 3)

group.a <- data.frame(x = group.a.x, y = group.a.y, group = "A")
group.b <- data.frame(x = group.b.x, y = group.b.y, group = "B")

my.xy <- rbind(group.a, group.b)

# construct models
mdlLDA <- lda(group ~ x + y, data = my.xy)

mdlQDA <- qda(group ~ x + y, data = my.xy)

我最接近的替代方案是包中的函数,但是很难修改或自定义。有没有其他使用基本 R 图和分界线/曲线方程的选项?或者也许是 ggplot 函数?提前感谢您对本:)的任何想法最好。partimat()klaR

R 机器学习 ggplot2 分类 统计

评论


答:

0赞 Sandipan Dey 8/10/2023 #1

下面介绍如何使用 绘制 LDA 和 QDA 决策边界。请注意,颜色较深的点是训练数据点,紫色线是 LDA 的线性决策边界,绿色曲线是 QDA 的非线性决策边界。ggplot2

library(MASS)

# generate data with nonlinear decision boundary
set.seed(123)
Ng <- 500 # number of cases per group

my.xy <- data.frame(x = rnorm(Ng, 2, 3), y = rnorm(Ng, 2, 3))

my.xy$group = 'A'
my.xy[my.xy$x^2/36 + my.xy$y^2/64 < 1,]$group = 'B'

# construct models
mdlLDA <- lda(group ~ x + y, data = my.xy)
mdlQDA <- qda(group ~ x + y, data = my.xy)


#create test data
np = 50
x = seq(from = min(my.xy$x), to = max(my.xy$x), length.out = np)
y = seq(from = min(my.xy$y), to = max(my.xy$y), length.out = np)
df <- expand.grid(x = x, y = y)

df$classL <- as.numeric(predict(mdlLDA, newdata = df)$class)
df$classQ <- as.numeric(predict(mdlQDA, newdata = df)$class)

df$classLf <- ifelse(df$classL==1, 'A', 'B') 
df$classQf <- ifelse(df$classQ==1, 'A', 'B') 

library(tidyverse)
my.xy %>% ggplot() + geom_point(aes(x,y,col=group)) + 
       geom_point(data=df, aes(x,y,col=classQf), alpha=0.1) + 
       geom_contour(data=df, aes(x,y,z=classQ), col='green') + 
       geom_contour(data=df, aes(x,y,z=classL), col='purple', lty=4) + 
       scale_colour_manual(values=c("blue", "red")) + 
  theme_bw() 

enter image description here