获取具有最大组平均值的组行的最快解决方案

Fastest solution to get rows of group with largest group average

提问人:LMc 提问时间:11/16/2023 更新时间:11/16/2023 访问量:59

问:

Reprex

假设我有一个数值矩阵:m

m <- as.matrix(iris[-5])
#      Sepal.Length Sepal.Width Petal.Length Petal.Width
# [1,]          5.1         3.5          1.4         0.2
# [2,]          4.9         3.0          1.4         0.2
# [3,]          4.7         3.2          1.3         0.2
# ...

以及一个将以下行分组的向量:groupsm

groups <- as.character(iris$Species)
# [1] "setosa" "setosa" "setosa" ...

问题

返回所有列中平均值最大的组的行的最方法是什么?如果可能的话,我还想跟踪原始行号。m

预期输出

在此示例中,该组具有最大的平均值:"virginica"

sapply(split(m, groups), mean)
#     setosa versicolor  virginica 
#     2.5355     3.5730     4.2850 

因此,预期输出为:

m[groups == "virginica", ]
#       Sepal.Length Sepal.Width Petal.Length Petal.Width
#  [1,]          6.3         3.3          6.0         2.5
#  [2,]          5.8         2.7          5.1         1.9
#  [3,]          7.1         3.0          5.9         2.1
# ...

如前所述,在子设置后重置行号,我会对一个同时跟踪原始行号的解决方案感兴趣(在这种情况下,这些行号将是)。101-150


我读到的大多数 SO 问题都倾向于以每组具有最大值的行为中心,而不是返回具有最大平均值的组的行。

因此,一个潜在的解决方案(不跟踪行号)是:

m[groups == names(which.max(sapply(split(m, groups), mean))),]

但我很好奇是否有更快的选择(我猜最快的选择可能是)。请附上基准。在我的实际问题中,我有一个包含数千个矩阵的列表,每个矩阵都有几千行。data.table

R 性能 分组

评论


答:

2赞 Rui Barradas 11/16/2023 #1

这个基本的 R 解决方案很简单,而且速度不应该很慢。 用 C 语言编码,是一个快速函数。
仅打印前 6 行。
rowMeans

vec <- as.matrix(iris[-5]) |> rowMeans()
s <- tapply(vec, iris$Species, mean) |> which.max() |> names()

iris[iris$Species == s, ] |> head()
#>     Sepal.Length Sepal.Width Petal.Length Petal.Width   Species
#> 101          6.3         3.3          6.0         2.5 virginica
#> 102          5.8         2.7          5.1         1.9 virginica
#> 103          7.1         3.0          5.9         2.1 virginica
#> 104          6.3         2.9          5.6         1.8 virginica
#> 105          6.5         3.0          5.8         2.2 virginica
#> 106          7.6         3.0          6.6         2.1 virginica

创建于 2023-11-15 使用 reprex v2.0.2


上面的代码是一个管道,后跟最后一个子集。

s <- iris[-5] |>
  as.matrix() |> 
  rowMeans() |>
  tapply(iris$Species, mean) |> 
  which.max() |> 
  names()

iris[iris$Species == s, ] |> head()
#>     Sepal.Length Sepal.Width Petal.Length Petal.Width   Species
#> 101          6.3         3.3          6.0         2.5 virginica
#> 102          5.8         2.7          5.1         1.9 virginica
#> 103          7.1         3.0          5.9         2.1 virginica
#> 104          6.3         2.9          5.6         1.8 virginica
#> 105          6.5         3.0          5.8         2.2 virginica
#> 106          7.6         3.0          6.6         2.1 virginica

创建于 2023-11-15 使用 reprex v2.0.2

2赞 Maël 11/16/2023 #2

以下是使用 .使用 和 代替 和 的执行速度是基准测试的 3 倍。collapse::fmeanfmeanrowMeansmeansplit

library(collapse)
library(microbenchmark)

microbenchmark(
  base = m[groups == names(which.max(sapply(split(m, groups), mean))), ],
  fmean = m[groups == names(which.max(fmean(split(m, groups)))), ],
  "fmean + rowMeans" = m[groups == names(which.max(rowMeans(fmean(m, groups)))),]
)

# Unit: microseconds
#              expr     min       lq      mean   median       uq     max neval
#              base 217.801 226.9510 260.34502 237.6000 259.8510 698.400   100
#             fmean 155.301 165.0010 192.98291 173.3515 197.0010 444.701   100
#  fmean + rowMeans  61.500  70.2505  83.35604  78.8510  87.4515 187.101   100
2赞 jblood94 11/16/2023 #3

调整@Maël的答案以获得额外的加速:

library(Rfast)

microbenchmark::microbenchmark(
  "fmean + rowMeans" = m[groups == names(which.max(rowMeans(fmean(m, groups)))),],
  "fmean + rowSums" = m[groups == names(which.max(fmean(rowSums(m), groups))),],
  "fmean + rowsums" = m[groups == names(which.max(fmean(rowsums(m), groups))),],
  check = "identical",
  unit = "relative",
  times = 1e3
)
#> Unit: relative
#>              expr      min       lq     mean   median       uq       max neval
#>  fmean + rowMeans 1.203125 1.205882 1.101343 1.200000 1.234899 0.9703661  1000
#>   fmean + rowSums 1.171875 1.176471 1.112008 1.171429 1.194631 0.9947705  1000
#>   fmean + rowsums 1.000000 1.000000 1.000000 1.000000 1.000000 1.0000000  1000

随着矩阵的增大,加速变得更加明显:

m <- matrix(runif(1e6), 1e3, 1e3)
groups <- sample(LETTERS, 1e3, 1)

microbenchmark::microbenchmark(
  "fmean + rowMeans" = m[groups == names(which.max(rowMeans(fmean(m, groups)))),],
  "fmean + rowSums" = m[groups == names(which.max(fmean(rowSums(m), groups))),],
  "fmean + rowsums" = m[groups == names(which.max(fmean(rowsums(m), groups))),],
  check = "identical",
  unit = "relative"
)
#> Unit: relative
#>              expr      min       lq     mean   median       uq       max neval
#>  fmean + rowMeans 7.206646 6.533759 5.260140 6.170580 5.935219 0.8081945   100
#>   fmean + rowSums 6.070651 5.497177 4.481413 5.150618 4.902984 1.5396386   100
#>   fmean + rowsums 1.000000 1.000000 1.000000 1.000000 1.000000 1.0000000   100