Applying a match within a function for a data.table

Asked by: TvCasteren; asked: 11/9/2023; last edited by: r2evans, TvCasteren; updated: 11/10/2023; viewed: 50 times

Question:

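I have a data.table that holds, for each ID, its X and Y coordinates plus several columns of neighbouring IDs. The neighbouring IDs refer to other observations (rows) in this same DT.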

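The goal is to set a neighbour ID to NA when its distance from the row's own point is too large. To do that, I have to apply a distance function to the X and Y coordinates of the matching IDs in the DT.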
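The distance used here is the Chebyshev distance, max(|x1 - x2|, |y1 - y2|). A quick base-R check, using the same definition as the question's `chebyshev_distance` and the coordinates of ID 1 and its three neighbours from the mock data. Because `pmax()` works element-wise, one call handles a whole vector of neighbours; note that the threshold test `> 10` is strict, so a distance of exactly 10 is kept.

```r
# Chebyshev (L-infinity) distance in the form the question uses:
# the larger of |x1 - x2| and |y1 - y2|, element-wise via pmax()
chebyshev_distance <- function(x1, y1, data) {
  pmax(abs(x1 - data$x), abs(y1 - data$y))
}

# ID 1 sits at (10, 5); its neighbours (IDs 2, 4, 3) sit at these coordinates
neighbours <- data.frame(x = c(20, 40, 30), y = c(10, 5, 25))
d <- chebyshev_distance(10, 5, neighbours)
print(d)       # 10 30 20
print(d > 10)  # FALSE TRUE TRUE: only the first neighbour is within the limit
```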

#the update function
update_columns <- function(dt, columns_to_update) {
   for (col in columns_to_update) {
      dt[, (col) := ifelse(chebyshev_distance(x,y, dt.nearest_neighbours[match.SD, c("x", "y"), on="id"]) > 10, NA, dt[[col]])]
   }
   return(dt.nearest_neighbours)
}

#the chebyshev distance function
chebyshev_distance <- function(x1, y1, data) {
   pmax(abs(x1-data$x), abs(y1-data$y))
}


# I created a mock data.table to reproduce the problem:

library(data.table)

dt.nearest_neighbours <- data.table(
  id = c(1,2,3,4), # the ID
  x = c(10, 20, 30, 40), # the X coordinate of the ID
  y = c(5, 10,25,5), # the Y coordinate of the ID
  V1 = c(2,3,2,1), # a neighbour of the ID -> the numbers in the V columns refer to other IDs in this dt
  V2 = c(4,1,4,2), # a second neighbour of the ID
  V3 = c(3,1,1,3) # a third neighbour of the ID
)

The current code gives the following error:

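'match.SD' is not found in calling scope and it is not a column name either. When the first argument inside DT[...] is a single symbol (e.g. DT[var]), data.table looks for var in calling scope.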

It looks like the match is not being done correctly. How can I fix this?
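The error comes from `match.SD`: it is neither an object in scope nor a column, so `dt.nearest_neighbours[match.SD, ...]` has nothing to subset with. If a `match()`-style lookup was the intent, base R's `match()` can return, for each neighbour ID, the row it refers to. A minimal sketch; `update_columns_match` and its `limit` argument are hypothetical names, not part of the question or the answers:

```r
library(data.table)

chebyshev_distance <- function(x1, y1, data) {
  pmax(abs(x1 - data$x), abs(y1 - data$y))
}

# Hypothetical helper: look up neighbour rows with base match() instead of a join
update_columns_match <- function(dt, columns_to_update, limit = 10) {
  for (col in columns_to_update) {
    # position of each neighbour ID within dt$id; NA when the ID does not exist
    idx <- match(dt[[col]], dt$id)
    # neighbour coordinates, aligned row by row with dt (NA where idx is NA)
    nb <- list(x = dt$x[idx], y = dt$y[idx])
    far <- chebyshev_distance(dt$x, dt$y, nb) > limit
    # an unknown neighbour (NA distance) becomes NA, as with the ifelse() versions
    dt[, (col) := replace(dt[[col]], far | is.na(far), NA)]
  }
  dt
}

dt.nearest_neighbours <- data.table(
  id = c(1, 2, 3, 4), x = c(10, 20, 30, 40), y = c(5, 10, 25, 5),
  V1 = c(2, 3, 2, 1), V2 = c(4, 1, 4, 2), V3 = c(3, 1, 1, 3)
)
# copy() keeps the original table intact, since := modifies by reference
res <- update_columns_match(copy(dt.nearest_neighbours), c("V1", "V2", "V3"))
res
```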

r data.table


Answers:

0 votes, TvCasteren, 11/9/2023, answer #1

This was solved with the following code:

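update_columns <- function(dt, columns_to_update) {
  for (col in columns_to_update) {
    # join on "id" to fetch, row by row, the x/y of the neighbour stored in column `col`;
    # neighbours farther than 10 (Chebyshev distance) are replaced with NA
    dt[, (col) := ifelse(chebyshev_distance(x, y, dt[.(id = get(col)), .(id, x, y), on = "id"]) > 10, NA, dt[[col]])]
  }
  return(dt)
}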
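The key piece of this fix is the lookup join `dt[.(id = get(col)), .(id, x, y), on = "id"]`: it returns one row per neighbour ID, in the order of the column being updated, so its `x` and `y` line up with the outer table's rows. A sketch of that join for `V1` of the mock data, outside the function (here `dt$V1` stands in for `get(col)`):

```r
library(data.table)

dt <- data.table(id = c(1, 2, 3, 4), x = c(10, 20, 30, 40), y = c(5, 10, 25, 5),
                 V1 = c(2, 3, 2, 1))

# One row per entry of V1, holding that neighbour's own id and coordinates
nb <- dt[.(id = dt$V1), .(id, x, y), on = "id"]

# Chebyshev distance from each row to its V1 neighbour, aligned row by row
d <- pmax(abs(dt$x - nb$x), abs(dt$y - nb$y))
d  # 10 15 15 30, so with a limit of 10 only row 1 keeps its V1 neighbour
```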

2 votes, r2evans, 11/10/2023, answer #2

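An alternative solution, preferring joins and a fully vectorized distance calculation (the distance function takes plain vectors, no data "frame").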

chebyshev_distance2 <- function(x1, y1, x2, y2) pmax(abs(x1-x2), abs(y1-y2))
update_columns2 <- function(DT, columns, limit = 10) {
  if (is.numeric(columns)) columns <- names(DT)[columns]
  stopifnot(
    "'x2' and 'y2' are used internally, they must not be in 'DT'" =
      !any(c("x2", "y2") %in% names(DT)),
    "'columns' must be length 1 or more, non-NA, and character/integer" =
      length(columns) > 0 && !anyNA(columns) && (is.character(columns) || is.numeric(columns)),
    "'limit' must be length-1, non-NA, and numeric/integer" =
      !anyNA(limit) && length(limit) == 1 && is.numeric(limit)
  )
  # drop the temporary x2/y2 columns on exit, even if an error occurs
  on.exit(suppressWarnings(DT[, c("x2", "y2") := NULL]), add = TRUE)
  for (col in columns) {
    # reset x2/y2 so a neighbour id with no matching row gets NA coordinates,
    # not the ones left over from the previous column
    suppressWarnings(DT[, c("x2", "y2") := NULL])
    # update join: copy each neighbour's x/y into x2/y2
    DT[DT, c("x2", "y2") := .(i.x, i.y), on = paste(col, "== id")]
    # vectorized distance; neighbours beyond 'limit' are set to NA
    DT[, (col) := replace(get(col), chebyshev_distance2(x, y, x2, y2) > limit, NA)]
  }
  DT
}
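The heart of `update_columns2` is the update join `DT[DT, c("x2", "y2") := .(i.x, i.y), on = paste(col, "== id")]`. With `col = "V1"` it matches each row's `V1` (outer table) against `id` (inner table, referenced with the `i.` prefix) and writes the neighbour's coordinates into new columns by reference; the distance is then plain column arithmetic. A sketch of that single step on the mock data (`dist` is an illustrative column name, not from the answer):

```r
library(data.table)

DT <- data.table(id = c(1, 2, 3, 4), x = c(10, 20, 30, 40), y = c(5, 10, 25, 5),
                 V1 = c(2, 3, 2, 1))

# Update join: outer V1 == inner id; copy the inner row's x/y into x2/y2
DT[DT, c("x2", "y2") := .(i.x, i.y), on = "V1 == id"]

# Each row now carries its V1 neighbour's coordinates, so the distance is vectorized
DT[, dist := pmax(abs(x - x2), abs(y - y2))]
DT
```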

Demo:

DT
#       id     x     y    V1    V2    V3
#    <num> <num> <num> <num> <num> <num>
# 1:     1    10     5     2     4     3
# 2:     2    20    10     3     1     1
# 3:     3    30    25     2     4     1
# 4:     4    40     5     1     2     3
update_columns2(DT, c("V1","V2","V3"))[]
#       id     x     y    V1    V2    V3
#    <num> <num> <num> <num> <num> <num>
# 1:     1    10     5     2    NA    NA
# 2:     2    20    10    NA     1     1
# 3:     3    30    25    NA    NA    NA
# 4:     4    40     5    NA    NA    NA

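This produces the same results as yours and, on this (small?) dataset, runs slightly faster (median 2.33 ms vs 2.81 ms in the benchmark):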

update_columns(DT, c("V1","V2","V3"))[]
#       id     x     y    V1    V2    V3
#    <num> <num> <num> <num> <num> <num>
# 1:     1    10     5     2    NA    NA
# 2:     2    20    10    NA     1     1
# 3:     3    30    25    NA    NA    NA
# 4:     4    40     5    NA    NA    NA

bench::mark(
  TvCasteren = update_columns(DT, c("V1","V2","V3"))[],
  r2evans = update_columns2(DT, c("V1","V2","V3"))[]
)
# # A tibble: 2 × 13
#   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result       memory time             gc                
#   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>       <list> <list>           <list>            
# 1 TvCasteren   2.53ms   2.81ms      356.        NA     8.58   166     4      466ms <dt [4 × 6]> <NULL> <bench_tm [170]> <tibble [170 × 3]>
# 2 r2evans      2.15ms   2.33ms      430.        NA     6.29   205     3      477ms <dt [4 × 6]> <NULL> <bench_tm [208]> <tibble [208 × 3]>
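Both `update_columns()` and `update_columns2()` assign with `:=`, so they change `DT` by reference: in the demo, `update_columns()` runs on the table `update_columns2()` has already updated, and the benchmark iterations reuse that updated table too. The outputs still agree here because an NA neighbour stays NA under either function, but for a like-for-like comparison each call can be given its own `data.table::copy()`. A minimal sketch of the difference; `blank_v1` is a hypothetical stand-in for either function:

```r
library(data.table)

DT <- data.table(id = c(1, 2), V1 = c(2, 1))

# Hypothetical stand-in for update_columns()/update_columns2(): changes a column with :=
blank_v1 <- function(d) {
  d[, V1 := NA_real_]
  d
}

res <- blank_v1(copy(DT))  # works on a deep copy
v1_after_copy <- DT$V1     # still 2 1: the original is untouched
invisible(blank_v1(DT))    # no copy: DT itself is modified by reference
v1_after_ref <- DT$V1      # now NA NA
```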

Data:

DT <- data.table::as.data.table(structure(list(id = c(1, 2, 3, 4), x = c(10, 20, 30, 40), y = c(5, 10, 25, 5), V1 = c(2, 3, 2, 1), V2 = c(4, 1, 4, 2), V3 = c(3, 1, 1, 3)), row.names = c(NA, -4L), class = c("data.table", "data.frame")))
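A further option, not taken from either answer: reshaping the neighbour columns to long format with `melt()` replaces the per-column loop with a single update join, and `dcast()` restores the wide layout. A sketch assuming the column names of the data above; `long`, `wide`, `col`, `nb`, `nx` and `ny` are illustrative names:

```r
library(data.table)

DT <- data.table(id = c(1, 2, 3, 4), x = c(10, 20, 30, 40), y = c(5, 10, 25, 5),
                 V1 = c(2, 3, 2, 1), V2 = c(4, 1, 4, 2), V3 = c(3, 1, 1, 3))

# One row per (id, neighbour column) pair; DT itself is not modified
long <- melt(DT, id.vars = c("id", "x", "y"), variable.name = "col", value.name = "nb")

# Single update join: fetch each neighbour's coordinates (long$nb == DT$id)
long[DT, c("nx", "ny") := .(i.x, i.y), on = .(nb = id)]

# Blank out neighbours whose Chebyshev distance exceeds 10
long[pmax(abs(x - nx), abs(y - ny)) > 10, nb := NA_real_]

# Back to one column per neighbour slot: id, x, y, V1, V2, V3
wide <- dcast(long, id + x + y ~ col, value.var = "nb")
wide
```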