在 pytorch 中编制索引和切片

Indexing and slicing in pytorch

提问人:mt-clemente 提问时间:3/25/2023 更新时间:3/26/2023 访问量:124

问:

我有一批维数 [B,n,n] 的二维张量和维数 [B,2] 的坐标张量。

a = torch.arange(48).reshape((3,4,4))
coords = torch.tensor([[0,1],[1,2],[1,3]],dtype=int)

给我想要的结果是:

a[torch.arange(3),coords[:,0],coords[:,1]]

我只是不明白为什么我不能只使用以下内容,因为我认为“:”意味着获取所有索引:

a[:,coords[:,0],coords[:,1]]

我在这里错过了什么?

Python 索引 PyTorch 切片

评论


答:

0赞 Ivan 3/26/2023 #1

您可以将这种差异视为并行索引按组合选择

  • 在前者中,使用与 your 和 张量索引器长度相同的排列方式来表示要在每个批处理元素上选择的内容。如果您详细说明了 发生的索引,则您正在执行:dim=1dim=2a[torch.arange(3),coords[:,0],coords[:,1]]

    >>> a[[0, 1, 2], [0, 1, 1], [1, 2, 3])
    tensor([ 1, 22, 39])
    

    因此,您将我所说的“并行”,这意味着首先,然后,最后。这将对应于正在选择和堆叠的单个值:a[0,0,1]a[1,1,2]a[2,1,3]

    >>> torch.stack([a[0,0,1], a[1,1,2], a[2,1,3]])
    tensor([ 1, 22, 39])
    
  • 在后者中,您可以执行组合,因为指的是 上的“全选”。同样,如果我们详细说明索引,我们有::dim=0

    • a[:,0,1] 即。 屈服。a[[0,1,2],0,1]tensor([ 1, 17, 33])
    • a[:,1,2] 即。 屈服a[[0,1,2],1,2]tensor([ 6, 22, 38])
    • a[:,1,3] 即。 屈服。a[[0,1,2],1,3]tensor([ 7, 23, 39])

    这是按列执行的,而不是按批处理执行的。对于所有批处理元素,我们首先取所有 s,然后是所有 s,最后是所有 s。可以使用以下方式执行相应的操作:(0,1)(1,2)(1,3)

    torch.dstack([a[:,0,1], a[:,1,2], a[:,1,3]])
    tensor([[[ 1,  6,  7],
             [17, 22, 23],
             [33, 38, 39]]])