Asked by: Coder Boy   Asked: 5/20/2023   Last edited by: Coder Boy   Updated: 5/22/2023   Views: 116
Saving and Loading a PyTorch NN model (.nnet or .onnx format)
Question:
I am trying to train a PyTorch model locally on my computer and save it, preferably in .nnet or .onnx format.
import torch
import torch.nn as nn

# Defining the neural network class
class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        x = self.output(x)
        return x

# Defining the input size, hidden layer sizes, and output size
input_size = 5
hidden_size1 = 2
hidden_size2 = 3
output_size = 5

# Creating an instance of the neural network
model = Net(input_size, hidden_size1, hidden_size2, output_size)

# Printing the model architecture
print(model)
I saved the model in .nnet format with the following code:
torch.save(model,'theModel.nnet')
Later, I want to load the model back into a PyTorch object and use it on its own, without having to write the same code again. How can I do this?
I tried
saved_model=torch.load('theModel.nnet')
and it throws this error:
AttributeError Traceback (most recent call last)
Cell In[7], line 1
----> 1 saved_model=torch.load('theModel.nnet')
File ~\anaconda3\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
710 opened_file.seek(orig_position)
711 return torch.jit.load(opened_file)
--> 712 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File ~\anaconda3\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
1051 torch._utils._validate_loaded_sparse_tensors()
1053 return result
File ~\anaconda3\lib\site-packages\torch\serialization.py:1042, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
1040 pass
1041 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1042 return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Net' on <module '__main__'>
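The error arises because torch.save(model, ...) pickles the whole module object, and pickle records the class by reference (module name plus class name, here __main__.Net) instead of storing its code. torch.load can only rebuild the object if a class called Net can be found at that same location when loading; the .nnet extension makes no difference to what torch.save writes. A minimal sketch of a working full-model round trip, assuming PyTorch 1.13 or newer (for the weights_only argument) and using a temporary file as a stand-in for theModel.nnet:

```python
import os
import tempfile

import torch
import torch.nn as nn


# The loading script must define (or import) the same class, because the
# pickled file only refers to "__main__.Net"; it does not contain its code.
class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)


model = Net(5, 2, 3, 5)
path = os.path.join(tempfile.mkdtemp(), "theModel.nnet")  # temp stand-in path

# Pickle the whole module: a class reference plus all parameters.
torch.save(model, path)

# weights_only=False is required to unpickle arbitrary classes on releases
# where weights_only=True is the default. Only use it on files you trust.
saved_model = torch.load(path, weights_only=False)
saved_model.eval()

# The reloaded model should produce the same outputs as the original.
x = torch.randn(1, 5)
with torch.no_grad():
    same = torch.allclose(model(x), saved_model(x))
```

In practice the usual way to satisfy the "class must be findable" requirement is to keep Net in its own module and import it in both the training and the loading script.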
Is there another way to do this?
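If the goal is to load the model without the Net definition at all, one option is TorchScript: torch.jit.script compiles the module, and torch.jit.save stores the compiled forward code together with the weights, so torch.jit.load needs no Python class. A minimal sketch, assuming the same Net architecture and a hypothetical file name theModel.pt; recent PyTorch documentation steers users toward torch.export instead, so treat this as one possible approach rather than the only one:

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)


model = Net(5, 2, 3, 5)
path = os.path.join(tempfile.mkdtemp(), "theModel.pt")  # hypothetical name

# Compile to TorchScript and save the code and the weights in one archive.
scripted = torch.jit.script(model)
torch.jit.save(scripted, path)

# In a separate script, this line alone would be enough: no Net class needed.
loaded = torch.jit.load(path)

x = torch.randn(1, 5)
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
```

An actual .onnx file would come from torch.onnx.export rather than torch.save, and is normally run with an ONNX runtime instead of being loaded back into PyTorch.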
Answer:
0 votes
blue_lama
5/22/2023
#1
Try saving only the model's state_dict:
torch.save(model.state_dict(),'theModel.nnet')
and loading it with
state_dict = torch.load('theModel.nnet')
model.load_state_dict(state_dict)
where model has first been instantiated as above:
model = Net(...)
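The state_dict approach stores only the parameter tensors, so the loading code still has to define Net and construct it with the same layer sizes before calling load_state_dict. A self-contained sketch of both sides, using a temporary file as an assumed stand-in for theModel.nnet:

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)


path = os.path.join(tempfile.mkdtemp(), "theModel.nnet")  # temp stand-in path

# Saving side: store only the parameters (a mapping of name -> tensor).
model = Net(5, 2, 3, 5)
torch.save(model.state_dict(), path)

# Loading side: rebuild the same architecture, then copy the weights in.
restored = Net(5, 2, 3, 5)
state_dict = torch.load(path)
restored.load_state_dict(state_dict)
restored.eval()

x = torch.randn(1, 5)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
```

Because the file contains only tensors, it also loads under torch.load's weights_only=True mode, which newer PyTorch releases use by default.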