提问人:Nikaido 提问时间:4/27/2023 最后编辑:desertnautNikaido 更新时间:4/28/2023 访问量:74
嵌套浮点的 numpy assert_equals
numpy assert_equals for nested floating point
问:
我遇到了一个奇怪的行为,即对 vgg16 机器学习模型的权重进行相等检查
加载两倍的模型
import torch
from torch import nn
from torchvision.models import vgg16
import numpy as np
import torchvision.models as models
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'vgg16_model.pth')
vgg = vgg16(pretrained=True)
vgg.load_state_dict(torch.load("vgg16_model.pth", map_location='cpu'), strict=True)
params1 = np.array([param.detach().numpy() for param in vgg.parameters()], dtype=object)
vgg2 = vgg16(pretrained=True)
vgg2.load_state_dict(torch.load("vgg16_model.pth", map_location='cpu'), strict=True)
params2 = np.array([param.detach().numpy() for param in vgg2.parameters()], dtype=object)
请注意,如果我使用np.assert_equals
np.array_equal(params1, params2)
我得到了False
但是,如果我以迭代方式检查嵌套数组,则数组是相等的:
for val1, val2 in zip(params1, params2):
print(np.array_equal(val1, val2))
我错过了什么?是由于我在开始时创建数组的方式,作为?dtype=object
python version 3.9.13
numpy version 1.21.5
答:
2赞
simon
4/27/2023
#1
事实上,类型似乎是问题所在——请注意,您甚至可以按如下方式简化您的示例(在 Python 3.10.10 上使用 Numpy 1.24.3 进行测试):object
vgg = vgg16(pretrained=True)
vgg.load_state_dict(torch.load("vgg16_model.pth", map_location='cpu'), strict=True)
params_list = [param.detach().numpy() for param in vgg.parameters()]
params1 = np.array(params_list, dtype=object)
params2 = np.array(params_list, dtype=object)
print(np.array_equal(params1, params2))
# >>> False
也就是说,即使 ur 和 object 数组中的元素是相同的浮点数组(不仅是相等/等效的数组),比较也会返回 。params1
params2
False
对我来说,Numpy 数组类型的比较似乎相当不直观,并且没有很好的文档记录。我能找到的唯一暗示这种行为的文档是在一些旧的 Numpy 发行说明中:object
对象数组相等性比较
在将来的对象数组比较中,== 和 np.equal 将不再使用身份检查。例如:
>>> a = np.array([np.array([1, 2, 3]), 1]) >>> b = np.array([np.array([1, 2, 3]), 1]) >>> a == b
将始终返回 False(并且将来会出错),即使 a 和 b 中的数组是同一个对象。
虽然你观察到的行为在那里被描述,但关于为什么这种比较返回的决定并不是真正的动机。False
底线:如果你想对 Numpy 数组进行元素比较,也许不要使用该类型。如需进一步阅读,也许还可以看看这个相关问题。object
0赞
hpaulj
4/28/2023
#2
在 1.24 中
In [83]: >>> a = np.array([np.array([1, 2, 3]), 1],object)
...: >>> b = np.array([np.array([1, 2, 3]), 1],object)
In [84]: a==b
<ipython-input-84-a6f7ccb4d5ba>:1: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
a==b
Out[84]: False
评论
params1==params2