用于在火炬张量之间移动向量的 Pytorch 操作

Pytorch operation for moving vectors between torch tensors

提问人:singa1994 提问时间:12/6/2021 更新时间:6/19/2022 访问量:150

问:

假设我们有火炬张量:

A: with shape BxHxW and values in {0,1}, where 0 and 1 are classes
B: with shape Bx2xD and real values, where D is the dimensionality of our vector

We want to create a new tensor of shape BxDxHxW that holds in each index specified in the spatial dimension (HxW), the vector that corresponds to its class (specified by A).

pytorch 中是否有实现该功能的函数?我尝试了火炬散射,但认为情况并非如此。

Python 矩阵 PyTorch 张量

评论

2赞 hdkrgr 12/7/2021
您似乎在答案中使用了两次,一次作为张量名称,一次作为维度大小。因此,让我解释一下您的问题:给定形状的类张量和形状的向量张量,其中第二维对应于可能的类或,您正在寻找一种有效的方法来计算形状的张量,使得 .正确?如果是的话,这实际上有些不同,而且非常有趣!Bc[B,H,W]v[B,2,D]01result[B,D,H,W]result[b,d,h,w] = v[b, c[h,w], d]torch.gather

答:

0赞 Ivan 6/19/2022 #1

您实际上是在寻找相反的操作,即使用包含在另一个张量中的索引从一个张量收集值。 这是一个规范的答案,可以处理这种索引场景,并毫不费力地应用torch.gather

让我们用虚拟数据设置一个最小的示例:

>>> b = 2; d = 3; h = 2; w = 1
>>> A = torch.randint(0, 2, (b,h,w)) # bhw
>>> B = torch.rand(b,2,d) # b2d
  1. 根据您的问题定义要执行的索引规则,如下所示:

    # out[b, d, h, w] = B[b, A[b, h, w]]
    
  2. 我们正在寻找 的 第二维的某种索引,使用 中的值。应用所有三个张量(输入、索引器和输出)时,必须具有相同数量的维度和相同的维度大小,但要索引的维度除外, here .观察我们的情况,我们必须坚持这种模式:BAtorch.gatherdim=1

    # out[b, 1, d, h, w] = B[b, A[b, 1, d, h, w], d, h, w]
    
  3. 因此,为了解释这种变化,我们需要在输入张量和索引张量上释放/扩展其他维度。因此,为了坚持上述形状,我们可以做到:

    首先,我们解开两个维度:A

    >>> A_ = A[:,None,None].expand(-1,1,d,-1,-1)
    

    其次,我们解开了两个维度:B

    >>> B_ = B[..., None, None].expand(-1,-1,-1,h,w)
    

    请注意,展开维度不会执行复制。它只是对张量基础数据的视图。在此步骤中,最终的形状为 ,而形状为 。A_(b, 1, d, h, w)B_(b, 2, d, h, w)

  4. 现在,我们可以简单地申请使用和:torch.gatherdim=1A_B_

    >>> out = B_.gather(dim=1, index=A_)
    

    我们必须使用 的单例维数 ,这样我们就可以将其压缩到生成的张量上。这是您想要的结果形状:dim=1(b, d, h, w)

    >>> out[:,0]