在分组列表列中捕获最大值

capture max values in grouped list column

提问人:Joe 提问时间:11/5/2023 最后编辑:Joe 更新时间:11/5/2023 访问量:38

问:

我有几个相同长度的列表,可以在所有列表中捕获每个模拟的最大值:

library(purrr)

POST_SIMS <- 3

set.seed(42)

x <- list(rnorm(POST_SIMS, 0, 1))
y <- list(rnorm(POST_SIMS, 0, 1))
z <- list(rnorm(POST_SIMS, 0, 1))

x;y;z

pmap(list(x,y,z), pmax)

我正在努力将此过程应用于列表列中的组。我生成一个名为列表列的列表列,该每天为每个条件存储 4 个模拟值。post

我想以 post_max 为单位保存每天每次模拟的最大值,即每个单元格 4 个值的列表列。

library(tidyverse)

POST_SIMS <- 4
CONDITIONS <- 3
DURATION <- 2

df <- 
    tibble(
        day = rep(1:DURATION, each = CONDITIONS),
        condition = rep(LETTERS[1:CONDITIONS], times = DURATION),
    ) |> 
    rowwise() |>
    mutate(post = list(rnorm(POST_SIMS, 0, 1))) 

df |>
    group_by(day) |>
    mutate(post_max = pmap(list(post), pmax)) |> 
    unnest(cols = c(post, post_max))

pmax()应用于 3 个帖子列表的分组并没有产生我预期的,即每天 4 个最大模拟值的一组一致。相反,我只是复制所有值:

如果我和使用,我会得到我想要的结果:每天 4 个最大值的一致集合unnest()max()

df |> 
    unnest() |> 
    filter(day == 1) |> 
    mutate(post_max = max(post))  # updated to max()

但是,由于各种预期原因,使用列表列单元格不起作用:max 返回单个值,而我想要列出值。max()

df |> 
    group_by(day) |> 
    mutate(post_max = max(post)) # updated to max()

在计算最大值之前取消嵌套在预期的应用程序中不是一个可行的解决方案,我还试图更好地理解如何在数据保留在列表列中时处理数据。

如果我将数据传播得更广,我可以在...pmax()pmap()

df |> 
    pivot_wider(
            id_cols = c(day), 
            names_from = "condition",
            values_from = 'post'
        ) |> 
    mutate(post_max = pmap(list(A,B,C), pmax)) |> 
    unnest(cols = c(A, B, C, post_max))

...但这也是不可取的,因为它需要在“列表(A,B,C)”中手动列出条件,并且预期的应用程序需要灵活地适应任意数量的条件> 1。因此,在仅对数据进行分组时执行 pmax 操作是可取的,因为这会自动适应任意数量的条件。

R 列表

评论


答:

1赞 jay.sf 11/5/2023 #1

用。Reduce

> by(df, ~day, \(x) cbind(day=x[1, 1], setNames(as.data.frame(x$post), x$condition), 
+                         p_max=Reduce(pmax, x$post))) |> do.call(what='rbind')
    day          A          B          C      p_max
1.1   1 -0.0627141 -0.2787888 -2.6564554 -0.0627141
1.2   1  1.3048697 -0.1333213 -2.4404669  1.3048697
1.3   1  2.2866454  0.6359504  1.3201133  2.2866454
1.4   1 -1.3888607 -0.2842529 -0.3066386 -0.2842529
2.1   2 -1.7813084 -0.4304691 -0.6399949 -0.4304691
2.2   2 -0.1719174 -0.2572694  0.4554501  0.4554501
2.3   2  1.2146747 -1.7631631  0.7048373  1.2146747
2.4   2  1.8951935  0.4600974  1.0351035  1.8951935

数据:

> dput(df)
structure(list(day = c(1L, 1L, 1L, 2L, 2L, 2L), condition = c("A", 
"B", "C", "A", "B", "C"), post = list(c(-0.608926375407211, 0.50495512329797, 
-1.71700867907334, -0.784459008379496), c(-0.850907594176518, 
-2.41420764994663, 0.0361226068922556, 0.205998600200254), c(-0.361057298548666, 
0.758163235699517, -0.726704827076575, -1.36828104441929), c(0.432818025888717, 
-0.811393176186672, 1.44410126172125, -0.431446202613345), c(0.655647883402207, 
0.321925265203947, -0.783838940880375, 1.57572751979198), c(0.642899305717316, 
0.0897606465996057, 0.276550747291463, 0.679288816055271))), class = c("rowwise_df", 
"tbl_df", "tbl", "data.frame"), row.names = c(NA, -6L), groups = structure(list(
    .rows = structure(list(1L, 2L, 3L, 4L, 5L, 6L), ptype = integer(0), class = c("vctrs_list_of", 
    "vctrs_vctr", "list"))), row.names = c(NA, -6L), class = c("tbl_df", 
"tbl", "data.frame")))