Saving and Loading a PyTorch NN model (.nnet or .onnx format)

Asked by: Coder Boy  Asked: 5/20/2023  Last edited by: Coder Boy  Updated: 5/22/2023  Views: 116

Q:

I am trying to train a PyTorch model and save it locally on my computer (preferably in .nnet or .onnx format).

import torch
import torch.nn as nn

# Defining the neural network class
class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        x = self.output(x)
        return x

# Defining the input size, hidden layer sizes, and output size
input_size = 5
hidden_size1 = 2
hidden_size2 = 3
output_size = 5

# Creating an instance of the neural network
model = Net(input_size, hidden_size1, hidden_size2, output_size)

# Printing the model architecture
print(model)

I saved the model with a .nnet extension using the following code:

torch.save(model,'theModel.nnet')

Later, I want to load the model back into a PyTorch object and use it on its own, without having to write the same code again. How can I do that?

I tried using

saved_model=torch.load('theModel.nnet')

which throws this error:

AttributeError                            Traceback (most recent call last)
Cell In[7], line 1
----> 1 saved_model=torch.load('theModel.nnet')

File ~\anaconda3\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
    710             opened_file.seek(orig_position)
    711             return torch.jit.load(opened_file)
--> 712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File ~\anaconda3\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
   1051 torch._utils._validate_loaded_sparse_tensors()
   1053 return result

File ~\anaconda3\lib\site-packages\torch\serialization.py:1042, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
   1040         pass
   1041 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1042 return super().find_class(mod_name, name)

AttributeError: Can't get attribute 'Net' on <module '__main__'>

Is there any other way?

Tags: python, deep-learning, pytorch, onnx, nnet

Comments

0 votes  Prayson W. Daniel  5/20/2023
Shouldn't it be torch.load('theModel.nnet')?

A:

0 votes  blue_lama  5/22/2023  #1

Try

torch.save(model.state_dict(),'theModel.nnet')

state_dict = torch.load('theModel.nnet')
model.load_state_dict(state_dict)

Here, model must first be instantiated as above, i.e. model = Net(...), before load_state_dict is called.
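
As a minimal sketch of the round trip described above: torch.save(model, ...) pickles a reference to the class __main__.Net, so loading it in a session where Net is not defined raises the AttributeError shown in the question. Saving only the state_dict avoids pickling the class, but the class definition still has to be available when loading. If the goal is to load without the class code, one possible alternative (not part of the answer above, added here as an assumption about what the asker needs) is to export with TorchScript via torch.jit.script and reload with torch.jit.load. The helper names roundtrip_state_dict and roundtrip_torchscript and the file names are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Same architecture as in the question."""

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)


def roundtrip_state_dict(path):
    """Save only the weights, then load them into a fresh Net instance."""
    model = Net(5, 2, 3, 5)
    torch.save(model.state_dict(), path)

    # The class definition is still required here: we rebuild the
    # architecture first and then fill in the saved parameters.
    restored = Net(5, 2, 3, 5)
    restored.load_state_dict(torch.load(path))
    restored.eval()
    return model, restored


def roundtrip_torchscript(path):
    """Alternative (assumption): export with TorchScript so that loading
    does not need the Net class to be defined in the loading session."""
    model = Net(5, 2, 3, 5)
    torch.jit.script(model).save(path)

    loaded = torch.jit.load(path)  # returns a ScriptModule, not a Net
    loaded.eval()
    return model, loaded


if __name__ == "__main__":
    x = torch.randn(4, 5)  # batch of 4 inputs with input_size = 5
    with tempfile.TemporaryDirectory() as tmp:
        original, restored = roundtrip_state_dict(os.path.join(tmp, "weights.pt"))
        print(torch.allclose(original(x), restored(x)))

        original, scripted = roundtrip_torchscript(os.path.join(tmp, "scripted.pt"))
        print(torch.allclose(original(x), scripted(x)))
```

Note that the file extension does not determine the format: both files above are PyTorch-specific archives whatever they are named, so a file saved with torch.save as 'theModel.nnet' is not in the NNet or ONNX format.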