通过 PyTorch 中非常复杂的函数反向传播时遇到问题 - 无法避免就地操作

trouble backpropagating through a very complicated function in pytorch - no way to avoid inplace operations

提问人:poisonDartFrog 提问时间:9/14/2023 最后编辑:poisonDartFrog 更新时间:9/14/2023 访问量:30

问:

我想基于神经网络输出的一系列复杂变换来定义一个损失函数。这些转换涉及一些复杂的逻辑,如果没有就地操作,这似乎是不可能的(请参阅注释):

def get_X_torch(C, c_table):
    """
    _zmat_transformation.py line 57
    C = torch tensor of floats where rows are all bonds, then all angles, then all dihedrals
    c_table = torch tensor of ints where rows are all bond_idx, then all angle_idx, then all dihedral_idx
    c_table blank indices for beginning of z-matrix are labeled as -9223372036854775807
    """
    X = torch.zeros_like(C, device="cuda:0")  # ([b a d], n_atoms)
    n_atoms = X.shape[1]

    # this is some complicated logic - not vectorizable because the variables
    # all influence each other throughout the loop (it's a nonlinear transformation)
    j: int = 0
    for j in range(n_atoms):
        B, ref_pos = get_B_torch(X, c_table, j)
        S = get_S_torch(C, j)
        X[:, j] = torch.mv(B, S) + get_ref_pos_torch(X, c_table[0, j])    # X[:, j] depends on X's current value as a whole!!! This is the tricky step
    return X.T

训练代码片段如下所示。我需要使用 my_function_script 迭代构建clash_loss,这是上述功能的包装器,但由于我没有执行 clash_loss +=,这应该没问题。我认为问题出在上面的复杂逻辑上。错误消息是,由于路径中某处的就地操作,它无法采用梯度。


           reconstructed_angles = torch.atan2(internal_data_batch_reconstructed[:, 0:304], internal_data_batch_reconstructed[:, 304:])
            if clash_mode is True:
                clash_loss = torch.tensor(0.0, requires_grad=True, device="cuda")
                bonds = torch.tensor(init_z_mat["bond"].values, device="cuda", requires_grad=True)
                angles = torch.tensor(init_z_mat["angle"].values * (torch.pi / 180), device="cuda", requires_grad=True)

                for i in range(batch_size):
                    print(i + 1, batch_size)
                    C = torch.stack((bonds, angles, reconstructed_angles[i]))
                    xyz = my_function_script(C, construction_table)  # very complicated function but written in pure PyTorch
                    # this function cannot not involve inplace operations (see other bit of code)
                    temp_loss = 1.0 if get_clash_loss(xyz) > 0.0 else 0.0
                    clash_loss = clash_loss + temp_loss
                total_loss = clash_loss
                total_loss.backward()  # <--- this fails

我能做些什么来使这一系列逻辑可微分,以便 clash_loss.backward() 工作?对于如此复杂的函数集,手动查找导数是完全不可能的......

我尝试使用副本重写,而没有明显的就地编辑(见下文),但这仍然不起作用。

Xs = [X]
for j in range(n_atoms):
    B, ref_pos = get_B_torch(Xs[-1], c_table, j)
    S = get_S_torch(C, j)
    first = torch.mv(B, S)
    second = get_ref_pos_torch(Xs[-1], c_table[0, j])
    Xcopy = torch.cat((Xs[-1][:, 0:j - 1], (first + second).reshape((-1, 1)), Xs[-1][:, j + 1:]), -1)
    Xs = Xs + [Xcopy] 

return Xs[-1].T
pytorch 就地 反向传播

评论

0赞 Yakov Dan 9/14/2023
您能否分享重现该问题的最小代码片段?
0赞 poisonDartFrog 9/14/2023
此代码片段足以重现该问题,前提是您知道reconstructed_angles是神经网络的输出。我认为共享所有代码是不切实际的,因为它有数千行......基本上,函数get_X_torch通过调用函数 get_ref_pos 来构建数组 X,该函数将 X 作为参数。基本上,X 是通过参考当时 X 的其他部分来构建的。这是麻烦的逻辑部分......
0赞 Yakov Dan 9/14/2023
在 get_X_torch 中为每次迭代创建一个 X 的副本是否可行?
1赞 poisonDartFrog 9/14/2023
见上文,还是没用......
1赞 poisonDartFrog 9/14/2023
更新:我通过制作许多副本来避免就地操作,但最终的梯度全部为零,这表明模型无法适应以最小化这种损失,或者由于所有操作而存在梯度消失问题。此外,backward() 非常慢。看来这个损失函数对我来说效果不佳。

答: 暂无答案