如何避免numpy.random.choice中的舍入错误?

How to avoid roundoff errors in numpy.random.choice?

提问人:Fırat Kıyak 提问时间:2/25/2022 最后编辑:Fırat Kıyak 更新时间:6/11/2022 访问量:1229

问:

比如说x_1,x_2,...,x_n是n个对象,人们想选择其中一个,这样选择x_i的概率就成正比,u_i某个数字。Numpy为此提供了一个函数:

x, u = np.array([x_1, x_2, ..., x_n]), np.array([u_1, ..., u_n])
np.random.choice(x, p = u/np.sum(u))

但是,我观察到此代码有时会抛出一个 ValueError,说“概率总和不等于 1”。这可能是由于有限精度算术的舍入误差所致。应该怎么做才能使此功能正常工作?

python numpy 随机 浮点 精度

评论

0赞 Mortz 2/25/2022
您担心什么类型的错误?
1赞 Pychopath 2/25/2022
类似问题
0赞 Fırat Kıyak 3/4/2022
@Mortz:“ValueError:概率总和不等于 1”
1赞 Fırat Kıyak 3/9/2022
@Mortz stackoverflow.com/a/60386427/6087087 提供了一个解决方案。numpy.random.multinomial (docs.scipy.org/doc/numpy-1.15.0/reference/generated/...) 会自动调整最后的概率来解决问题,但需要注意的是,不应依赖此概率。其他答案,不要给出满意的答案。例如,该问题 stackoverflow.com/a/46539921/6087087 公认的解决方案建议对概率进行归一化,这可能由于舍入误差而无法解决问题。请参阅 pd shah 对该答案的评论。
1赞 Leopd 9/15/2022
这一切都引出了一个问题,为什么numpy不只在内部做这些事情。我的意思是numpy的一个关键点是使进行复杂的数值计算变得容易,而不必成为IEEE-754舍入bs的专家。

答:

0赞 vovakirdan 2/25/2022 #1

根据 NumPy 文档,我们必须使用 . 所以我认为如果 u-array 是概率数组,那么你可以尝试一下:p1-D array-like

x, u = np.array([x_1, x_2, ..., x_n]), np.array([u_1, ..., u_n])
np.random.choice(x, p = u)

x, u = np.array([x_1, x_2, ..., x_n]), np.array([u_1, ..., u_n])
s = sum(u)
u1 = [i/s for i in u]
np.random.choice(x, p = u1)

评论

0赞 Fırat Kıyak 3/4/2022
这不能回答我的问题。第二个代码与我发布的代码几乎相同。我担心由于除法过程中的有限精度算术而发生的累积误差。这可能导致概率总和不等于(确切地)1。
5赞 Fırat Kıyak 3/9/2022 #2

在阅读了@Pychopath指出的问题的答案 https://stackoverflow.com/a/60386427/6087087 后,我找到了以下解决方案,其灵感来自numpy.random.multinomial https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.random.multinomial.html 的文档

Say 是概率数组,即使我们用 对其进行归一化,也可能不完全是由于舍入误差造成的。这并不罕见,请参阅@pd Shah在答案 https://stackoverflow.com/a/46539921/6087087 的评论。p1p = p/np.sum(p)

只是做

p[-1] = 1 - np.sum(p[0:-1])
np.random.choice(x, p = p)

问题就解决了!由于减法导致的舍入误差将比归一化导致的舍入误差小得多。此外,人们不必担心 p 的变化,它们属于舍入误差的顺序。

评论

1赞 Leopd 9/15/2022
最好使用,因为有时舍入错误会导致最终数字为负数(如 -1e-16),这也将失败,但p[-1] = max(0, 1 - np.sum(p[0:-1]))np.random.choiceValueError: probabilities are not non-negative
0赞 Reza Roboubi 1/15/2023
那好吧......这是生成该错误的源代码,但我不确定解决问题的最佳方法是什么,或者为什么 numpy 还没有修复它。github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx
1赞 Reza Roboubi 1/15/2023
好的,我的问题似乎已经解决了:p = np.array(p, dtype=numpy.float64),即类型转换。我使用的是 Jax 数组。我的错。