此代码将循环矩阵乘法转换为 einsum 是否正确?

Is this code correct to convert for loop matrix multiplication into einsum?

提问人:Sai Kishore 提问时间:11/17/2023 更新时间:11/18/2023 访问量:27

问:

我已经为我想执行的矩阵乘法编写了一个基于 for 循环和基于 einsum 的代码。你能帮我检查它的正确性吗?

`

w = torch.randn((10,32,32))
x = torch.randn((3,32,32))
x_c = x.clone()
z = torch.zeros(x.shape)
for i in range(x.shape[0]):
  dummy_x = torch.zeros((x.shape[1],w.shape[2]))
  for j in range(w.shape[0]):
    dummy_x += torch.matmul(x[i],w[j])
  z[i]=dummy_x

result = torch.einsum("ijk,lkm->ijm",x_c,w)
# result = torch.einsum("iljm->ijm",result)
torch.eq(result,z)

我尝试了上面的代码并使用torch.eq检查了相等性,但答案是错误的

PyTorch 火炬 Einsum

评论


答:

0赞 Adesoji Alu 11/18/2023 #1
import torch

# Initialize the tensors
w = torch.randn((10, 32, 32))
x = torch.randn((3, 32, 32))

# For loop based matrix multiplication
z = torch.zeros(x.shape)
for i in range(x.shape[0]):
    dummy_x = torch.zeros((x.shape[1], w.shape[2]))
    for j in range(w.shape[0]):
        dummy_x += torch.matmul(x[i], w[j])
    z[i] = dummy_x

# Expand dimensions to make the tensors broadcastable
x_expanded = x.unsqueeze(1)  # Shape: (3, 1, 32, 32)
w_expanded = w.unsqueeze(0)  # Shape: (1, 10, 32, 32)

# Perform batch matrix multiplication and sum over the second dimension
result = torch.matmul(x_expanded, w_expanded).sum(dim=1)  # Shape: (3, 32, 32)

# Check for equality
are_equal = torch.all(torch.eq(result, z))

print("Are the results equal: ", are_equal.item())

评论

0赞 Sai Kishore 11/18/2023
非常感谢您的回答,einsum 代码是否正确??