提问人:Sasha 提问时间:9/4/2023 更新时间:9/4/2023 访问量:86
PyTorch 矩阵乘法不考虑切片
PyTorch matrix multiplication does not respect slicing
问:
我必须对转印器模型进行批量长输入,并注意到批处理和非批处理结果之间的差异,从而达到这一点。我最终隔离了我注意到的第一个差异,结果如下:
import torch
n = 20
vec = torch.rand(n, 20)
a = torch.rand(30, 20)
for i in range(1, n+1):
print(i, torch.equal(
torch.nn.functional.linear(vec, a)[:i],
torch.nn.functional.linear(vec[:i], a)))
产生输出:
1 False
2 False
3 False
4 True
5 True
6 True
7 False
8 False
9 False
10 True
11 True
12 True
13 False
14 False
15 False
16 True
17 True
18 True
19 True
20 True
这只是一个操作,当多次组合时(如在转换器中),它可能会导致较大的发散,从而扩大 torch.allclose 输出 True 的 atol。这是为什么呢,能做些什么吗?
答:
0赞
Alexey Birukov
9/4/2023
#1
欢迎来到浮点运算的勇敢世界! 运算引入了舍入误差,矩阵乘法将它们累积到有效值。https://pytorch.org/docs/stable/notes/numerical_accuracy.html如果避免使用以下方法进行不精确的舍入float
vec = torch.floor (torch.rand(n, 20)*10)
a = torch.floor( torch.rand(30, 20)*10 )
你会得到所有的 -s。True
可能的解决方案是使用 .torch.isclose
戴尔
评论
0赞
Sasha
9/4/2023
您好,谢谢。我不确定你的答案是否与手头的确切问题超级相关,我相信这个问题更微妙。我说的是“批处理”计算(我认为 vec 是 n 个较小张量 vec[i] 的批次),原则上,理想情况下,无论我一次执行还是每批元素执行,计算都应该完全相同。我们在这里看到一个非常奇怪的 False/True 模式,具体取决于批处理元素的数量,这可能与引擎盖下的一些奇怪的优化谈话场所有关......
0赞
Alexey Birukov
9/4/2023
计算是相同的,但它们的顺序可能不是,这很重要。 制动关联定律。floats
0赞
Sasha
9/4/2023
对于不同的 I,VEC[i] 的计算之间没有交互,因此顺序应该无关紧要。
0赞
Alexey Birukov
9/4/2023
首先,for 仍然是矩阵。其次,即使将向量乘以矩阵,里面也有标量乘积,它的计算顺序由机器选择。n>0
vec[:i]
0赞
Sasha
9/4/2023
如果你看到涉及的计算是什么,那么说,它是,简化,就像计算三个浮点数 a、b 和 c 的 ab 和 ac 一样。顺序应该无关紧要,理想情况下,如果 a、b 和 c 是确定性的,则结果应该是确定性的。然而,在这里我们看到(粗略地),要求计算机立即计算 (ab, ac) 并提取结果 ab(第一个坐标)给出的答案与要求它单独计算 ab 不同。
评论
2.0.1+cpu
20
30
vec = torch.rand(30, 30); a = torch.rand(20, 30)
True
i >= 20
a = torch.rand(25, 30)
True
i >= 25
2.0.1
input.shape[0] < weight.shape[0]