Python 将数组与零比 np.any(array) 快

Python comparing array to zero faster than np.any(array)

提问人:bproxauf 提问时间:7/15/2022 更新时间:7/16/2022 访问量:302

问:

我想测试数组的所有元素是否都为零。根据 StackOverflow 帖子 测试 numpy 数组是否只包含零https://stackoverflow.com/a/72976775/5269892,与 相比,应该是内存效率最高且速度最快的方法。(array == 0).all()not array.any()

我用随机数浮动数组测试了性能,见下文。不知何故,至少对于给定的数组大小,甚至将数组转换为布尔类型似乎也比 慢。怎么会这样?not array.any()(array == 0).all()

np.random.seed(100)
a = np.random.rand(10418*144)

%timeit (a == 0)
%timeit (a == 0).all()
%timeit a.astype(bool)
%timeit a.any()
%timeit not a.any()

# 711 µs ± 192 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 740 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 1.69 ms ± 587 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 1.71 ms ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 1.71 ms ± 2.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
python 数组 numpy boolean

评论

1赞 7/15/2022
我和你得到的不一样。(Python 3.9.13、1.23.0) 617 μs ± 每个循环 270 ns(平均 7 次运行±标准开发,每次 1,000 个循环) 每个循环 624 μs ± 1.16 μs(平均±标准开发,7 次运行,每次 1,000 个循环) 每个循环 254 μs ± 702 ns(平均 7 次运行±标准开发,每次 1,000 个循环) 262 μs ± 655 ns(平均±标准开发 7 次运行, 每个循环 1,000 个环路) 每个环路 262 μs ± 714 ns(平均±标准开发 7 次,每次 1,000 个循环)
1赞 Jérôme Richard 7/15/2022
请注意,在 1.23 版本中,我们改进了基本的缩减函数,包括 np.all 和 np.any(参见 github.com/numpy/numpy/pull/21001),尽管对 np.any 和 np.all 的影响应该很小。更新 Numpy 可能会有所帮助。
1赞 Jérôme Richard 7/15/2022
@Murali 这是令人惊讶的。你在 Windows 上运行吗?AFAIK Windows 版本的行为通常不同(令人惊讶)。您的处理器架构是什么?
2赞 7/15/2022
@JérômeRichard 我正在使用 Mac Os(v 12.4)、arm 架构(M1 芯片)。
1赞 Salvatore Daniele Bianco 7/15/2022
一个题外话提示:如果你知道你的数组是正数(可能不是你的情况)会更快。a.sum()==0

答:

2赞 Jérôme Richard 7/16/2022 #1

该问题是由于前两个操作使用 SIMD 指令进行矢量化,而后三个操作则不然。更具体地说,最后三个调用对尚未矢量化的 bool () 进行隐式转换。这是一个已知问题,我已经为此提出了一个拉取请求(由于未定义的行为,它揭示了一些意外问题,现在已经修复)。如果一切正常,它应该在 Numpy 的下一个主要版本中可用。_aligned_contig_cast_double_to_bool

请注意,并隐式地对布尔数组执行强制转换,以便更快地执行操作。这不是很有效,但这是这样做的,这样可以减少生成的函数变体的数量(Numpy 是用 C 语言编写的,因此必须为每种类型生成不同的实现,并且很难优化许多变体,因此我们更喜欢在这里执行隐式转换,更不用说这也减小了生成的二进制文件的大小)。如果这还不够,则不能使用 Cython 来生成更快的特定优化代码。a.any()not a.any()any