提问人:singa1994 提问时间:12/6/2021 更新时间:6/19/2022 访问量:150
用于在火炬张量之间移动向量的 Pytorch 操作
Pytorch operation for moving vectors between torch tensors
问:
假设我们有火炬张量:
A: with shape BxHxW and values in {0,1}, where 0 and 1 are classes
B: with shape Bx2xD and real values, where D is the dimensionality of our vector
We want to create a new tensor of shape BxDxHxW that holds in each index specified in the spatial dimension (HxW), the vector that corresponds to its class (specified by A).
pytorch 中是否有实现该功能的函数?我尝试了火炬散射,但认为情况并非如此。
答:
0赞
Ivan
6/19/2022
#1
您实际上是在寻找相反的操作,即使用包含在另一个张量中的索引从一个张量收集值。
这是一个规范的答案,可以处理这种索引场景,并毫不费力地应用torch.gather
。
让我们用虚拟数据设置一个最小的示例:
>>> b = 2; d = 3; h = 2; w = 1
>>> A = torch.randint(0, 2, (b,h,w)) # bhw
>>> B = torch.rand(b,2,d) # b2d
根据您的问题定义要执行的索引规则,如下所示:
# out[b, d, h, w] = B[b, A[b, h, w]]
我们正在寻找 的 第二维的某种索引,使用 中的值。应用所有三个张量(输入、索引器和输出)时,必须具有相同数量的维度和相同的维度大小,但要索引的维度除外,即 here .观察我们的情况,我们必须坚持这种模式:
B
A
torch.gather
dim=1
# out[b, 1, d, h, w] = B[b, A[b, 1, d, h, w], d, h, w]
因此,为了解释这种变化,我们需要在输入张量和索引张量上释放/扩展其他维度。因此,为了坚持上述形状,我们可以做到:
首先,我们解开两个维度:
A
>>> A_ = A[:,None,None].expand(-1,1,d,-1,-1)
其次,我们解开了两个维度:
B
>>> B_ = B[..., None, None].expand(-1,-1,-1,h,w)
请注意,展开维度不会执行复制。它只是对张量基础数据的视图。在此步骤中,最终的形状为 ,而形状为 。
A_
(b, 1, d, h, w)
B_
(b, 2, d, h, w)
现在,我们可以简单地申请使用和:
torch.gather
dim=1
A_
B_
>>> out = B_.gather(dim=1, index=A_)
我们必须使用 的单例维数 ,这样我们就可以将其压缩到生成的张量上。这是您想要的结果形状:
dim=1
(b, d, h, w)
>>> out[:,0]
评论
B
c
[B,H,W]
v
[B,2,D]
0
1
result
[B,D,H,W]
result[b,d,h,w] = v[b, c[h,w], d]
torch.gather