为什么使用 pytorch 广播的 KNN 这么慢?

Why is the KNN using pytorch broadcasting is so slow?

提问人:Hajin Lee 提问时间:11/15/2023 最后编辑:Hajin Lee 更新时间:11/16/2023 访问量:31

问:

我正在尝试为网格点找到 knn。这是用于生成网格的代码

def grid_by(lims=[[0, 1], [0, 1]], size=[28, 28]):
    """
    Creates a tensor of 2D grid points.
    Grid points have one-to-one correspondence with input pixel values that are flattened in row-major order.

    Args:
        lims: [[domain of H], [domain of W]]
        size: [H, W]
    Returns:
        grid: Tensor of shape [(H*W), 2]
    """
    assert len(size) == 2 and len(lims) == len(size)
    expansions = [torch.linspace(start, end, steps) if i != 0 else torch.linspace(end, start, steps) for i, ((start, end), steps) in enumerate(zip(lims, size))]
    grid = torch.index_select(torch.cartesian_prod(*expansions),
                        dim=1,
                        index=torch.tensor([1,0]))
    return grid

我制作了一个自定义的 KNN 函数,以便它可以在 pytorch 中的 gpu 上运行。 pytorch 代码如下所示,假设 L2 距离。

def knn(grid, k):
    """
    Brute Force KNN.

    Args:
            grid: Tensor of shape [(H*W), D]
            k: Int representing number of neighbors 
    """
    d = grid.shape[-1]
    Xr = grid.unsqueeze(1)
    Yr = grid.view(1, -1, d)
    distances = torch.sqrt(torch.sum((Xr - Yr)**2, -1))
    dist, index = distances.topk(k, largest=False, dim=-1)
    return dist, index

grid = grid_by().to('cuda')
knn_dist, knn_index = knn(grid, k=10)

sklearn 代码如下所示

grid = grid_by()
nn = NearestNeighbors(n_jobs=-1)
nn.fit(grid)
knn_dist, knn_index = nn.kneighbors(self.grid, n_neighbors=10)

我使用了广播,所以我希望它在 GPU 上运行得很快。但是,当我使用 .速度这么慢的原因是什么?timeit

scikit-learn pytorch knn 数组广播

评论

0赞 Karl 11/15/2023
您能否提供 PyTorch 代码和 SKLEARN 代码的完整示例?
0赞 Hajin Lee 11/15/2023
是的。我编辑了代码

答:

0赞 Karl 11/16/2023 #1

我无法复制。我在一台具有 64 个内核和 3090 GPU 的机器上运行了以下测试。timeit

knn上:cpu786 µs ± 74 µs per loop

knn上:cuda:0197 µs ± 437 ns per loop

sklearn上:cpun_jobs=-121.6 ms ± 212 µs per loop

sklearn上:cpun_jobs=None1.16 ms ± 299 ns per loop

评论

0赞 Hajin Lee 11/16/2023
为简单起见,我从原始手动 knn 代码中删除了批次尺寸。我想在批处理维度上广播是当时速度变慢的原因。谢谢你的回答。
0赞 Hajin Lee 11/16/2023
此外,当我将函数的参数增加到并运行代码时,我收到如下错误。CUDA 错误:遇到非法内存访问 CUDA 内核错误可能会在其他 API 调用时异步报告,因此下面的堆栈跟踪可能不正确。对于调试,请考虑传递 CUDA_LAUNCH_BLOCKING=1。这种情况也会发生吗?sizegrid_by(224,224)