Asked by: Penguin Asked: 11/18/2023 Last edited by: Penguin Updated: 11/20/2023 Views: 169
How to modify parameters that require gradients in meta learning?
Q:
I have a neural network that is trained to output a learning rate:
import torch
import torch.nn as nn
import torch.optim as optim

criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Meta_Model(nn.Module):
    def __init__(self):
        super(Meta_Model, self).__init__()
        self.fc1 = nn.Linear(1, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, 32)
        self.fc4 = nn.Linear(32, 1)
        self.lky = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.lky(self.fc1(x))
        x = self.lky(self.fc2(x))
        x = self.lky(self.fc3(x))
        x = self.fc4(x)
        return x  # x should be some learning rate

meta_model = Meta_Model().to(device)
meta_model_opt = optim.Adam(meta_model.parameters(), lr=1e-1)
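For reference, the meta model maps a 1-element tensor (the detached scalar loss, unsqueezed to shape (1,)) to a 1-element output that the training loop below interprets as a learning rate. A quick shape check, with dummy_loss as a made-up stand-in value:

dummy_loss = torch.tensor([0.5], device=device)  # stand-in for loss.detach().unsqueeze(0)
lr_out = meta_model(dummy_loss)
print(lr_out.shape)  # torch.Size([1]); note the output is unconstrained and can be negative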
I have some inputs and a function I'm trying to learn:
input_tensor = torch.rand(1000,1) # some inputs
label_tensor = 2 * input_tensor # function to learn
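Since the target function is y = 2x, the inner problem has a known optimum (w1 = 2), which gives an easy way to check whether the inner updates do anything at all:

# the MSE is exactly zero at the optimal weight w1 = 2
print(criterion(2.0 * input_tensor, label_tensor))  # tensor(0.)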
I'm trying to update a trainable parameter to fit this function:
meta_model_epochs = 10
w_epochs = 5

for _ in range(meta_model_epochs):
    torch.manual_seed(42)  # reset seed for reproducibility
    w1 = torch.rand(1, requires_grad=True)  # reset trainable weight
    weight_opt = optim.SGD([w1], lr=1e-1)  # reset weight optimizer
    meta_loss = 0  # reset meta loss
    for _ in range(w_epochs):
        predicted_tensor = w1 * input_tensor
        loss = criterion(predicted_tensor, label_tensor)
        meta_loss += loss  # add to meta loss
        meta_model_output = meta_model(loss.detach().unsqueeze(0))  # input to the meta model is the loss
        weight_opt.zero_grad()
        loss.backward(retain_graph=True)  # get grads
        w1 = w1 - meta_model_output * w1.grad  # step --> this is the issue
    meta_model_opt.zero_grad()
    meta_loss.backward()
    meta_model_opt.step()
    print('meta_loss', meta_loss.item())
So the setup is that the meta model should learn to output the optimal learning rate for updating the trainable parameter w1 based on the current loss.
The problem is that I get RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 2; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
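Following the hint in the traceback, anomaly detection can be enabled before the loop runs so that the failing backward call points at the forward operation that produced the modified tensor; a minimal sketch:

# debug only: run once before the training loop, slows training down considerably
torch.autograd.set_detect_anomaly(True)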
I also tried replacing the update step with w1.data = w1.data - meta_model_output * w1.grad # step, which resolves the error, but then the meta model is not updated (i.e., the loss stays the same).
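For context on why the .data variant stops the meta model from training: writes through .data are not recorded by autograd, so meta_model_output never enters the graph that meta_loss.backward() differentiates. A minimal illustration with made-up tensors:

w = torch.ones(1, requires_grad=True)
lr = torch.ones(1, requires_grad=True)  # stand-in for meta_model_output
w.data = w.data - lr * 0.5              # .data update: invisible to autograd
loss = (3.0 * w).sum()
loss.backward()
print(lr.grad)  # None -- no gradient ever reaches the stand-in meta output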
Update 1:
Tried @VonC's idea of computing the updated value of w1 (using a clone of w1: w1_updated_value) and setting it as w1's data:
w1_clone = w1.clone()
w1_clone = w1_clone - meta_model_output * w1.grad # step
w1.data = w1_clone
While this removes the error, it leads to the same problem of the meta model not being updated (i.e., the loss stays the same).
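Since there is no answer yet, the following is only a sketch of one common way to structure this kind of unrolled inner loop: drop the SGD optimizer and loss.backward()/w1.grad on the inner parameter, and instead take a differentiable gradient with torch.autograd.grad(..., create_graph=True), so that the out-of-place update containing meta_model_output stays inside the graph that meta_loss.backward() later differentiates. It assumes the same tensors as above, moved onto device:

input_tensor = input_tensor.to(device)
label_tensor = label_tensor.to(device)

for _ in range(meta_model_epochs):
    torch.manual_seed(42)  # reset seed for reproducibility
    w1 = torch.rand(1, device=device, requires_grad=True)  # fresh leaf each outer epoch
    meta_loss = 0
    for _ in range(w_epochs):
        predicted_tensor = w1 * input_tensor
        loss = criterion(predicted_tensor, label_tensor)
        meta_loss = meta_loss + loss
        meta_model_output = meta_model(loss.detach().unsqueeze(0))
        # differentiable gradient of loss w.r.t. w1; create_graph=True keeps
        # the update step itself inside the autograd graph
        (grad_w1,) = torch.autograd.grad(loss, w1, create_graph=True)
        w1 = w1 - meta_model_output * grad_w1  # w1 becomes a non-leaf tensor here
    meta_model_opt.zero_grad()
    meta_loss.backward()
    meta_model_opt.step()
    print('meta_loss', meta_loss.item())

The inner optimizer is dropped deliberately: once the update is an out-of-place expression, the meta model receives its gradient through the later losses that depend on the updated w1, which is exactly the path the .data variant severed.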
A: No answers yet
Comments