PyTorch load_state_dict() does not load the exact values

Asked by Random Seed on 11/7/2023 · Last edited by Random Seed · Updated 11/7/2023 · Viewed 42 times

Q:

For simplicity, I want to set all the parameters of a torch model to the constant `72114982` with this code:

import torch

model = Net()
params = model.state_dict()

for k, v in params.items():
    params[k] = torch.full(v.shape, 72114982, dtype=torch.long) 

model.load_state_dict(params)
print(model.state_dict().values())

The print statement then shows that all the values are actually set to `72114984`, which is off by 2 from the value I originally expected.

For simplicity, `Net` is defined as follows:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(2, 2)

Tags: pytorch, precision



A:

0 votes · Karl · 11/7/2023 · #1

This is a data type issue.
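
The parameters created by nn.Conv2d and nn.Linear default to float32, which is what triggers the rounding. As a quick check (my addition, reusing the Net class from the question):

net = Net()
{v.dtype for v in net.state_dict().values()}
> {torch.float32}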

The model parameters are stored as float32 tensors, so load_state_dict() casts your long values to float when copying them in. `72114982` is large enough that its float32 representation rounds to `72114984`.
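
The off-by-2 follows from how densely float32 values are spaced at this magnitude: float32 has a 24-bit significand, so between 2**26 and 2**27 adjacent representable values are 2**(26 - 23) = 8 apart, and 72114982 is rounded to the nearest multiple of 8, which is 72114984. A minimal sketch of this using torch.nextafter (my addition, not part of the original answer):

import torch

y = torch.tensor(72114982, dtype=torch.long).float()  # rounds to 72114984.0
nxt = torch.nextafter(y, torch.tensor(float('inf')))  # next representable float32

# gap between adjacent float32 values at this magnitude
nxt - y
> tensor(8.)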

You can verify this with the following:

import torch

x = torch.tensor(72114982, dtype=torch.long)
y = x.float() # y will actually be `72114984.0`

# this returns `True` because x is cast to a float before evaluating
x == y
> tensor(True)

# for the same reason, this returns 0.
y - x
> tensor(0.)

# this returns `False` because the tensors have different values and we don't cast to float
x == y.long()
> tensor(False)

# as longs, the difference correctly evaluates to 2
y.long() - x
> tensor(2)
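
If you need the parameters to hold that exact integer, one workaround (my suggestion, not part of the original answer) is to cast the model to float64 before loading: float64 has a 53-bit significand, so every integer up to 2**53 is exactly representable.

import torch
import torch.nn as nn

model = nn.Linear(4, 4).double()  # hypothetical small model; .double() makes the params float64
state = {k: torch.full(v.shape, 72114982, dtype=torch.long)
         for k, v in model.state_dict().items()}
model.load_state_dict(state)

# every parameter now holds exactly 72114982.0
{v.unique().item() for v in model.state_dict().values()}
> {72114982.0}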