提问人:Abhishek Bhatia 提问时间:3/5/2017 最后编辑:desertnautAbhishek Bhatia 更新时间:10/25/2022 访问量:51515
数值稳定的 softmax
Numerically stable softmax
问:
下面有没有一种数值稳定的方法来计算 softmax 函数? 我得到的值在神经网络代码中成为 Nans。
np.exp(x)/np.sum(np.exp(y))
答:
计算 softmax 函数没有错,因为它是你的情况。问题似乎来自梯度爆炸或训练方法的此类问题。通过“裁剪值”或“选择正确的初始权重分布”来关注这些问题。
评论
softmax(800)
softmax exp(x)/sum(exp(x)) 实际上在数值上表现良好。它只有正项,所以我们不必担心失去意义,而且分母至少和分子一样大,所以结果保证在 0 到 1 之间。
唯一可能发生的事故是指数中的过度或流入不足。x 的单个元素的溢出或所有元素的下溢将使输出或多或少无用。
但是,通过使用恒等式 softmax(x) = softmax(x + c) 可以很容易地防止这种情况,该恒等式适用于任何标量 c:从 x 中减去 max(x) 会留下一个只有非正条目的向量,从而排除溢出和至少一个为零的元素排除消失的分母(某些但不是所有条目中的下溢是无害的)。
脚注:从理论上讲,灾难性事故的总和是可能的,但你需要一个荒谬的术语。例如,即使使用只能解析 3 位小数的 16 位浮点数---与“正常”64 位浮点数的 15 位小数相比---我们也需要在 2^1431 (~6 x 10^431) 和 2^1432 之间得到相差 2 倍的总和。
Softmax 函数容易出现两个问题:上溢和下溢
溢出:当非常大的数字近似为infinity
下溢:当非常小的数字(数字行中接近零)近似(即四舍五入)为zero
为了在进行softmax计算时解决这些问题,一个常见的技巧是通过从所有元素中减去其中的最大元素来移动输入向量。对于输入向量,定义如下:x
z
z = x-max(x)
然后取新(稳定)向量的softmaxz
例:
def stable_softmax(x):
z = x - max(x)
numerator = np.exp(z)
denominator = np.sum(numerator)
softmax = numerator/denominator
return softmax
# input vector
In [267]: vec = np.array([1, 2, 3, 4, 5])
In [268]: stable_softmax(vec)
Out[268]: array([ 0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865])
# input vector with really large number, prone to overflow issue
In [269]: vec = np.array([12345, 67890, 99999999])
In [270]: stable_softmax(vec)
Out[270]: array([ 0., 0., 1.])
在上述案例中,我们通过使用stable_softmax()
评论
扩展 @kmario23 的答案以支持 1 维或 2 维 numpy 数组或列表。2D 张量(假设第一个维度是批量维度)在通过 softmax 传递一批结果时很常见:
import numpy as np
def stable_softmax(x):
z = x - np.max(x, axis=-1, keepdims=True)
numerator = np.exp(z)
denominator = np.sum(numerator, axis=-1, keepdims=True)
softmax = numerator / denominator
return softmax
test1 = np.array([12345, 67890, 99999999]) # 1D numpy
test2 = np.array([[12345, 67890, 99999999], # 2D numpy
[123, 678, 88888888]]) #
test3 = [12345, 67890, 999999999] # 1D list
test4 = [[12345, 67890, 999999999]] # 2D list
print(stable_softmax(test1))
print(stable_softmax(test2))
print(stable_softmax(test3))
print(stable_softmax(test4))
[0. 0. 1.]
[[0. 0. 1.]
[0. 0. 1.]]
[0. 0. 1.]
[[0. 0. 1.]]
评论
np.seterr(all='raise')
上一个:沿轴的火炬和张量
评论