Asked by: poisonDartFrog  Asked: 9/14/2023  Last edited by: poisonDartFrog  Updated: 9/14/2023  Views: 30
Trouble backpropagating through a very complicated function in PyTorch - no way to avoid in-place operations
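Q:
I want to define a loss function based on a series of complicated transformations of a neural network's output. These transformations involve some complicated logic that seems impossible to write without in-place operations (see the comments):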
def get_X_torch(C, c_table):
    """
    _zmat_transformation.py line 57
    C = torch tensor of floats where rows are all bonds, then all angles, then all dihedrals
    c_table = torch tensor of ints where rows are all bond_idx, then all angle_idx, then all dihedral_idx
    c_table blank indices for beginning of z-matrix are labeled as -9223372036854775807
    """
    X = torch.zeros_like(C, device="cuda:0")  # ([b a d], n_atoms)
    n_atoms = X.shape[1]
    # this is some complicated logic - not vectorizable because the variables
    # all influence each other throughout the loop (it's a nonlinear transformation)
    j: int = 0
    for j in range(n_atoms):
        B, ref_pos = get_B_torch(X, c_table, j)
        S = get_S_torch(C, j)
        X[:, j] = torch.mv(B, S) + get_ref_pos_torch(X, c_table[0, j])  # X[:, j] depends on X's current value as a whole!!! This is the tricky step
    return X.T
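A likely mechanism behind the error, sketched on a toy tensor rather than the real z-matrix code: indexing `X[:, k]` returns a view that shares a version counter with `X`, so if an operation saves that view for its backward pass and `X` is later written in place, autograd refuses to use the stale tensor. The names `forward`, `x`, and `col` below are hypothetical, and whether `get_B_torch` saves views of `X` in this way is an assumption, since its body is not shown. Cloning the columns that are read is one possible workaround.

```python
import torch

def forward(x: torch.Tensor, clone_reads: bool) -> torch.Tensor:
    X = torch.zeros(3, 2)
    X[:, 0] = 2.0 * x        # in-place write into column 0
    col = X[:, 0]            # view: shares X's version counter
    if clone_reads:
        col = col.clone()    # independent tensor with its own version counter
    y = col * col            # mul saves `col` for its backward pass
    X[:, 1] = y              # in-place write bumps X's version counter
    return X.sum()

x = torch.ones(3, requires_grad=True)

# Without cloning, backward should hit the stale saved view and raise.
try:
    forward(x, clone_reads=False).backward()
    failed = False
except RuntimeError as err:
    failed = True
    print("backward failed:", str(err)[:70])

# With cloning, the same in-place writes are tolerated.
x.grad = None
forward(x, clone_reads=True).backward()
print(x.grad)
```

Wrapping the forward pass in `torch.autograd.detect_anomaly()` should additionally print the traceback of the forward operation whose backward failed, which may help locate the offending write in the real function.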
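The training snippet is shown below. I need to build clash_loss iteratively using my_function_script, which is a wrapper around the function above, but since I am not doing clash_loss +=, that part should be fine. I think the problem lies in the complicated logic above. The error message says that the gradient cannot be computed because of an in-place operation somewhere along the path.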
reconstructed_angles = torch.atan2(internal_data_batch_reconstructed[:, 0:304], internal_data_batch_reconstructed[:, 304:])
if clash_mode is True:
    clash_loss = torch.tensor(0.0, requires_grad=True, device="cuda")
    bonds = torch.tensor(init_z_mat["bond"].values, device="cuda", requires_grad=True)
    angles = torch.tensor(init_z_mat["angle"].values * (torch.pi / 180), device="cuda", requires_grad=True)
    for i in range(batch_size):
        print(i + 1, batch_size)
        C = torch.stack((bonds, angles, reconstructed_angles[i]))
        xyz = my_function_script(C, construction_table)  # very complicated function but written in pure PyTorch
        # this function cannot not involve inplace operations (see other bit of code)
        temp_loss = 1.0 if get_clash_loss(xyz) > 0.0 else 0.0
        clash_loss = clash_loss + temp_loss
total_loss = clash_loss
total_loss.backward()  # <--- this fails
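One thing worth checking in the snippet above: `1.0 if get_clash_loss(xyz) > 0.0 else 0.0` evaluates to a plain Python float, so as written it appears that no gradient could flow from `clash_loss` back to `C`, independently of the in-place issue. A minimal sketch of a smooth alternative follows. `soft_clash_penalty` and its `min_dist` threshold are hypothetical, since the body of `get_clash_loss` is not shown.

```python
import torch

def soft_clash_penalty(xyz: torch.Tensor, min_dist: float = 1.0) -> torch.Tensor:
    """Hypothetical clash score: sum of relu(min_dist - d_ij) over all atom pairs."""
    n = xyz.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)   # index pairs with i < j
    d = (xyz[i] - xyz[j]).norm(dim=1)           # pairwise distances
    return torch.relu(min_dist - d).sum()       # zero when no pair is closer than min_dist

# Three atoms on a line; only the first two are closer than min_dist.
xyz = torch.tensor([[0.0, 0.0, 0.0],
                    [0.5, 0.0, 0.0],
                    [3.0, 0.0, 0.0]], requires_grad=True)

score = soft_clash_penalty(xyz)
hard = 1.0 if score > 0.0 else 0.0   # plain float: detached from the autograd graph

print(type(hard).__name__, score.requires_grad)
score.backward()
print(xyz.grad)   # pushes the two clashing atoms apart; the third gets zero gradient
```

A hinge on pairwise distances is only one possible choice. The point is that the quantity added to the loss stays a tensor produced by differentiable operations instead of being collapsed to 0 or 1.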
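What can I do to make this chain of logic differentiable so that clash_loss.backward() works? Deriving the gradient by hand for such a complicated set of functions is completely out of the question...
I tried rewriting it with copies and no explicit in-place edits (see below), but that still does not work.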
    Xs = [X]
    for j in range(n_atoms):
        B, ref_pos = get_B_torch(Xs[-1], c_table, j)
        S = get_S_torch(C, j)
        first = torch.mv(B, S)
        second = get_ref_pos_torch(Xs[-1], c_table[0, j])
        Xcopy = torch.cat((Xs[-1][:, 0:j - 1], (first + second).reshape((-1, 1)), Xs[-1][:, j + 1:]), -1)
        Xs = Xs + [Xcopy]
    return Xs[-1].T
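One possible reason the copy-based attempt still misbehaves is the slice `0:j - 1`: it drops column `j - 1`, and for `j = 0` it becomes `0:-1`, so the concatenated tensor no longer has `n_atoms` columns. The intended slice was probably `:j`. A sketch that sidesteps slicing entirely keeps the finished columns in a Python list and stacks them once at the end. The update rule inside the loop is a hypothetical stand-in for `torch.mv(B, S) + get_ref_pos_torch(...)`, and `build_columns` is an assumed name. Only the pattern is meant to carry over.

```python
import torch

def build_columns(C: torch.Tensor) -> torch.Tensor:
    """Toy recurrence in which column j depends on every previously built column."""
    n_atoms = C.shape[1]
    cols = []  # finished columns; never modified after being appended
    for j in range(n_atoms):
        if cols:
            prev = torch.stack(cols, dim=1)   # new (3, j) tensor of the columns built so far
            ref = prev.mean(dim=1)            # hypothetical "reference position"
        else:
            ref = torch.zeros(3, dtype=C.dtype, device=C.device)
        cols.append(torch.sin(C[:, j]) + ref)  # creates a new tensor, no in-place write
    return torch.stack(cols, dim=1).T          # (n_atoms, 3)

C = torch.rand(3, 5, dtype=torch.float64, requires_grad=True)
xyz = build_columns(C)
xyz.sum().backward()
print(xyz.shape, C.grad.shape)
```

Because every column is a fresh tensor, nothing that autograd saved for the backward pass is overwritten later. The cost is some extra memory for the intermediate stacks.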
A: No answers yet