计算 2 个 numpy 数组之间的最接近的对位 - KDTree

Compute closest neghtbor between 2 numpy arrays - KDTree

提问人:Arun 提问时间:8/31/2023 更新时间:8/31/2023 访问量:38

问:

我有 2 个 numpy 数组:一个由 int 值组成的(较小的)数组,b(较大的)数组由浮点值组成。这个想法是 b 包含接近 a 中某些 int 值的浮点值。例如,作为玩具示例,我有下面的代码。数组不是这样排序的,我在 a 和 b 上使用 np.sort() 来获取:

a = np.array([35, 11, 48, 20, 13, 31, 49])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])

对于 a 中的每个元素,b 中有多个浮点值,目标是为 a 中的每个元素获取 b 中最接近的值

为了天真地实现这一点,我使用了 for 循环:

for e in a:
    idx = np.abs(e - b).argsort()
    print(f"{e} has nearest match = {b[idx[0]]:.4f}")
'''
11 has nearest match = 11.2890
13 has nearest match = 12.8700
20 has nearest match = 20.0500
31 has nearest match = 31.0300
35 has nearest match = 34.9900
48 has nearest match = 48.1000
49 has nearest match = 49.2000
'''

a 中可以有 b 中不存在的值,反之亦然。

a.size = 2040 和 b.size = 1041901

要构造 KD-Tree:

# Construct KD-Tree using and query nearest neighnor-
kd_tree = KDTree(data = np.expand_dims(a, 1))
dist_nn, idx_nn = kd_tree.query(x = np.expand_dims(b, 1), k = [1])


dist.shape, idx.shape
# ((19, 1), (19, 1))

为了获得“b”相对于“a”的最近邻,我这样做:

b[idx]
'''
array([[10.7  ],
       [10.7  ],
       [10.7  ],
       [11.289],
       [11.289],
       [11.289],
       [11.3  ],
       [11.3  ],
       [11.3  ],
       [12.32 ],
       [12.32 ],
       [12.32 ],
       [12.87 ],
       [12.87 ],
       [12.87 ],
       [12.87 ],
       [13.5  ],
       [13.5  ],
       [18.78 ]])
'''

问题:

  • KD-Tree 似乎没有超过“a”中的值 20。[31, 25, 48, 49] 在 A 中完全错过
  • 与for循环的输出相比,它找到的大多数最近邻都是错误的!

怎么了?

python 数组 numpy 最近邻

评论


答:

1赞 Homer512 8/31/2023 #1

如果要获取 中每个条目的最接近的元素,请为 构建 KD-Tree,然后查询 。aba

from scipy import spatial

kd = spatial.KDTree(b[:,np.newaxis])
distances, indices = kd.query(a[:, np.newaxis])
values = b[indices]

for ai, bi in zip(a, values):
    print(f"{ai} has nearest match = {bi:.4f}")
35 has nearest match = 34.9900
11 has nearest match = 11.2890
48 has nearest match = 48.1000
20 has nearest match = 20.0500
13 has nearest match = 12.8700
31 has nearest match = 31.0300
49 has nearest match = 49.2000