嵌套浮点的 numpy assert_equals

numpy assert_equals for nested floating point

提问人:Nikaido 提问时间:4/27/2023 最后编辑:desertnautNikaido 更新时间:4/28/2023 访问量:74

问:

我遇到了一个奇怪的行为,即对 vgg16 机器学习模型的权重进行相等检查

加载两倍的模型

import torch
from torch import nn
from torchvision.models import vgg16
import numpy as np
import torchvision.models as models

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'vgg16_model.pth')

vgg = vgg16(pretrained=True)
vgg.load_state_dict(torch.load("vgg16_model.pth", map_location='cpu'), strict=True)
params1 = np.array([param.detach().numpy() for param in vgg.parameters()], dtype=object)

vgg2 = vgg16(pretrained=True)
vgg2.load_state_dict(torch.load("vgg16_model.pth", map_location='cpu'), strict=True)
params2 = np.array([param.detach().numpy() for param in vgg2.parameters()], dtype=object)

请注意,如果我使用np.assert_equals

np.array_equal(params1, params2)

我得到了False

但是,如果我以迭代方式检查嵌套数组,则数组是相等的:

for val1, val2 in zip(params1, params2):
    print(np.array_equal(val1, val2))

我错过了什么?是由于我在开始时创建数组的方式,作为?dtype=object

python version 3.9.13
numpy version 1.21.5
python numpy 机器学习 pytorch 相等

评论

0赞 hpaulj 4/27/2023
生产什么?有什么警告吗?params1==params2

答:

2赞 simon 4/27/2023 #1

事实上,类型似乎是问题所在——请注意,您甚至可以按如下方式简化您的示例(在 Python 3.10.10 上使用 Numpy 1.24.3 进行测试):object

vgg = vgg16(pretrained=True)
vgg.load_state_dict(torch.load("vgg16_model.pth", map_location='cpu'), strict=True)
params_list = [param.detach().numpy() for param in vgg.parameters()]
params1 = np.array(params_list, dtype=object)
params2 = np.array(params_list, dtype=object)
print(np.array_equal(params1, params2))
# >>> False

也就是说,即使 ur 和 object 数组中的元素是相同的浮点数组(不仅是相等/等效的数组),比较也会返回 。params1params2False

对我来说,Numpy 数组类型的比较似乎相当不直观,并且没有很好的文档记录。我能找到的唯一暗示这种行为的文档是在一些旧的 Numpy 发行说明中:object

对象数组相等性比较

在将来的对象数组比较中,==np.equal 将不再使用身份检查。例如:

>>> a = np.array([np.array([1, 2, 3]), 1])
>>> b = np.array([np.array([1, 2, 3]), 1])
>>> a == b

将始终返回 False(并且将来会出错),即使 ab 中的数组是同一个对象。

虽然你观察到的行为在那里被描述,但关于为什么这种比较返回的决定并不是真正的动机。False

底线:如果你想对 Numpy 数组进行元素比较,也许不要使用该类型。如需进一步阅读,也许还可以看看这个相关问题object

0赞 hpaulj 4/28/2023 #2

在 1.24 中

In [83]: >>> a = np.array([np.array([1, 2, 3]), 1],object)
    ...: >>> b = np.array([np.array([1, 2, 3]), 1],object)

In [84]: a==b
<ipython-input-84-a6f7ccb4d5ba>:1: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
  a==b
Out[84]: False