提问人:Klaus3 提问时间:7/17/2023 最后编辑:Klaus3 更新时间:7/18/2023 访问量:102
为什么 numba 切片比 numpy 切片快得多?
Why is numba slicing so much faster than numpy slicing?
问:
def test(x):
k = x[1:2]
l = x[0:3]
m = x[0:1]
@njit
def test2(x):
k = x[1:2]
l = x[0:3]
m = x[0:1]
x = np.arange(5)
test2(x)
%timeit test(x)
%timeit test2(x)
776 ns ± 1.83 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
280 ns ± 2.53 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
随着切片的增加,它们之间的差距会扩大
def test(x):
k = x[1:2]
l = x[0:3]
m = x[0:1]
n = x[1:3]
o = x[2:3]
@njit
def test2(x):
k = x[1:2]
l = x[0:3]
m = x[0:1]
n = x[1:3]
o = x[2:3]
test2(x)
%timeit test(x)
%timeit test2(x)
1.18 µs ± 1.82 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
279 ns ± 0.562 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
numpy 函数似乎变得线性变慢,并且 numba 函数并不关心你切片了多少次(这是我期望在这两种情况下都会发生的情况)
编辑:
在 chrslg 回答之后,我决定放置两个函数的返回语句。只需输入两者
return k,l,m,n,o
时间是:
1.23 µs ± 2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
1.61 µs ± 9.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
所以 numba 函数现在变得更慢了,这似乎使它确实只是一个死代码。但是,在看到用户 jared 评论后,我决定用他尝试过的切片测试相同的产品操作:
def test5(x):
k = x[1:2]
l = x[0:3]
m = x[0:1]
n = x[1:4]
o = x[2:3]
return (k*l*m*n*o)
@njit
def test6(x):
k = x[1:2]
l = x[0:3]
m = x[0:1]
n = x[1:4]
o = x[2:3]
return (k*l*m*n*o)
test6(x)
%timeit test5(x)
%timeit test6(x)
5.79 µs ± 202 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
787 ns ± 1.52 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
现在,numba 函数变得比简单的返回函数 (!) 更快,速度差距再次扩大。老实说,我现在更困惑了。
答:
因为它可能什么都不做。
对于此功能,我得到了相同的时间,对于 numba
def test3(x):
pass
请注意,这几乎什么也做不了。这些只是切片,没有任何与之相关的操作。因此,没有数据传输或任何东西。
只是创建了 3 个变量,并进行了一些边界调整。test
如果代码大约是 5000000 个元素的数组,以及其中 1000000 个元素的切片,它不会更慢。因此,我想,当你想在“更大”的情况下缩放一些东西时,你决定不增加数据大小(因为你可能知道数据大小在这里无关紧要),而是增加行数。
但是,好吧,即使几乎什么都不做,仍然在做这 3 个未使用的切片。test
其中 numba 编译一些生成的 C 代码。而编译器,使用优化器,没有理由保留那些以后从未使用过的切片变量。
我在这里完全推测(我从未见过 numba 生成的代码)。但我想代码可以看成线条
void test2(double *x, int xstride, int xn){
double *k = x+1*xstride;
int kstride=xstride;
int kn=1;
double *l=x;
int lstride=xstride;
int ln=3;
double *m=x;
int mstride=xstride;
int mn=1;
// And then it would be possible, from there to iterates those slices
// for example k[:]=m[:] could generate this code
// for(int i=0; i<kn; i++) k[i*stride] = m[i*stride]
}
(这里我用算术中的大小,而实际上步幅是以字节为单位,但没关系,那只是伪代码)stride
double *
我的观点是,如果之后有什么东西(就像我在评论中所说的那样),那么,这段代码,即使只是一些算术运算,仍然会“几乎没有,但不是没有”。
但之后什么都没有。所以,它只是一些局部变量初始化,代码显然没有副作用。编译器优化器很容易删除所有代码。并编译一个空函数,其效果和结果完全相同。
所以,再说一遍,只是代表我的猜测。但是任何像样的代码生成器+编译器都应该只为 编译一个空函数。所以和是一样的东西。test2
test2
test3
虽然解释器没有,但通常这种优化(首先,很难提前知道即将发生的事情,其次,优化所花费的时间是在运行时,因此需要权衡,对于编译器来说,即使需要 1 小时的编译时间来节省 1ns 的运行时间,它仍然值得)
编辑:更多实验
Jared 和我都有的想法,那就是做一些事情,不管它是什么,迫使切片存在,并比较当它必须做某事时会发生什么,因此真正做切片,是自然的。问题是,一旦你开始做某事,任何事情,切片本身的时间就会变得微不足道。因为切片不算什么。
但是,从统计学上讲,您可以删除它,并且仍然以某种方式测量“切片”部分。
以下是一些时间数据
空函数
在我的计算机上,一个空函数在纯 python 中的成本为 130 ns。和 540 ns 的 numba。
这并不奇怪。什么都不做,但在跨越“python/C 边界”的同时这样做可能会花费一点,只是为了那个“python/C”。 不多
时间与切片数
下一个实验是你所做的确切实验(因为,顺便说一句,你的帖子包含我自己对我的答案的证明:你已经看到在纯 python 时间是 O(n),n 是切片数,而在 numba 中它是 O(1)。仅此一项就证明根本没有发生切片。如果切片完成,在numba中,就像在任何其他非量子计算机:D中一样,成本必须为O(n)。当然,如果它 t=100+0.000001*n,可能很难区分 O(n) 和 O(1)。因此,我从评估“空”案例开始的原因
在纯 python 中,仅切片,明显 O(n) 中的切片数量不断增加,确实:
线性回归表明,这大约是 138+n×274,单位为 ns。
这与“空”时间一致
另一方面,对于 numba,我们得到
因此,无需线性回归来证明这一点
- 它确实是 O(1)
- 时序与“空”情况下的 540 ns 一致
请注意,这意味着,对于 n=2 个或更多切片,在我的计算机上,numba 变得具有竞争力。以前,它不是。 但是,好吧,“无所事事”的竞争......
使用切片
当然,如果我们之后添加代码来强制使用切片,事情就会发生变化。编译器不能只删除切片。
但我们必须小心
- 避免在操作本身中添加 O(n)
- 为了区分操作的时间和切片的时间,切片的时间可以忽略不计
我当时所做的是计算一些加法slice1[1]+slice2[2]+slice3[3]+...
但是无论切片的数量是多少,我在这个加法中都有 1000 个术语。因此,对于 n=2(2 个切片),该加法是 1000 项。slice1[1]+slice2[2]+slice1[3]+slice2[4]+...
由于添加,这应该有助于删除 O(n) 部分。然后,有了足够大的数据,我们可以从围绕它的变化中提取一些值,即使在加法时间本身之前(因此甚至在加法时间的噪音之前),这些变化也可以忽略不计。但是通过足够的测量,噪音会变得足够低,可以开始看到东西)
在纯 python 中
线性回归给出 199000 + 279×n ns
我们从中学到的是,我的实验设置是可以的。279 与之前的 274 非常接近,可以说,事实上,加法部分,尽管它很大 (200000 ns) 是 O(1),因为与仅切片相比,O(n) 部分保持不变。因此,我们只是具有与之前相同的时间+加法部分的巨大常数。
用麻木
所有这些都只是证明实验设置的合理性的序言。现在是有趣的部分,实验本身
线性回归告诉 1147 + 1.3×n
所以,在这里,它确实是 O(n)。
结论
在 numba 中切片确实需要一些成本。它是 O(n)。 但是如果不使用它,编译器只需删除它,我们就会得到一个 O(1) 操作。
证明原因确实是,在您的版本中,numba 代码什么都不做
2. 无论它是什么操作,你对切片所做的强制使用它,并防止编译器删除它,成本要大得多,如果没有统计预防措施,它会掩盖 O(n) 部分。因此,感觉“当我们使用变量时是一样的”。
3. 无论如何,大多数时候 numba 比 numpy 快。
我的意思是,numpy 是获得“编译语言速度”的好方法,而无需使用编译语言。但它并没有击败真正的编译。 因此,在 numba 中使用朴素的算法击败 numpy 中非常智能的矢量化是相当经典的。(古典的,对于像我这样的人来说非常令人失望,他以成为知道谁在 numpy 中矢量化事物的人为生。有时,我觉得使用 numba,最幼稚的嵌套 for 循环更好)。
它不再是这样,当
- Numpy 使用多个内核(您也可以使用 numba 来做到这一点。但不仅仅是使用朴素的算法)
- 您正在执行存在非常智能算法的操作。Numpy算法经过数十年的优化。不能用 3 个嵌套循环来击败它。除了有些任务太简单了,以至于无法真正优化。
所以,我仍然更喜欢 numpy 而不是 numba。更喜欢在 numpy 背后使用数十年的优化,在 numba 中重新发明轮子。另外,有时最好不要依赖编译器。
但是,好吧,让 numba 击败 numpy 是经典的。
只是,不是你案件的比例。因为在你的情况下,你正在比较(我想我现在已经证明了这一点,正如你自己所证明的那样,当 numba 情况是 O(1) 时,看到 numpy 情况是 O(n) ),“用 numpy 做切片与用 numba 什么都不做”
评论
n = x[1:4]
评论
print(test6.inspect_llvm(test6.signatures[0]))