Pytorch 中的批处理矩阵乘法 - 与输出维度的处理混淆

Batch-Matrix multiplication in Pytorch - Confused with the handling of the output's dimension

提问人:singa1994 提问时间:6/11/2019 最后编辑:singa1994 更新时间:7/31/2020 访问量:13439

问:

我有两个数组:

A
B

数组包含一批RGB图像,形状为:A

[batch, Width, Height, 3]

而 Array 包含对图像进行“类似变换”操作所需的系数,其形状为:B

[batch, 4, 4, 3]

简单地说,单个图像的运算是输出环境映射()的乘法。normalMap * Coefficients

我想要的输出应该保持形状:

[batch, Width, Height, 3]

我尝试使用但失败了。这可能吗?torch.bmm

python 矢量化 pytorch 批处理 矩阵乘法

评论

1赞 Matan Danos 6/11/2019
我不明白矩阵乘法的维度?乘法需要在通道轴上工作吗?也许退房?torch.nn.functional.conv2d
0赞 singa1994 6/12/2019
@Danos我希望将来自张量 A 的批次中的每个图像分别与来自张量 B 的 4*4 矩阵相乘,在通道轴上是的。
0赞 Matan Danos 6/12/2019
根据 torch.bmm 的文档,矩阵尺寸必须一致(即,如果 A*B,则高度等于 4)。如果不是这种情况,则操作失败是有道理的。如果你想要元素乘法,请查看 torch.mul,在这种情况下,我认为你需要确保 B 是可广播的。

答:

5赞 prosti 6/13/2019 #1

我认为您需要计算 PyTorch 可以与

BxCxHxW : number of mini-batches, channels, height, width

格式,并使用 matmul,因为 bmm 使用张量或 ndim/dim/rank =3。

我知道您可能会在网上找到这个,但无论如何:

batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)
res = torch.matmul(batch1, batch2)
res.size() # torch.Size([10, 3, 20, 30])