Python 在部分更改的数组中查找最大值索引的最有效方法

Python most efficient way to find index of maximum in partially changed array

提问人:bproxauf 提问时间:7/19/2022 最后编辑:bproxauf 更新时间:7/22/2022 访问量:624

问:

我有一个包含大约 750000 个元素的复值数组,我反复(比如 10^6 次或更多次)更新 1000 个(或更少)不同的元素。在绝对平方数组中,我需要找到最大值的索引。这是较大代码的一部分,大约需要 ~700 秒才能运行。其中,通常 75%(~550 秒)用于查找最大值的索引。尽管根据 https://stackoverflow.com/a/26820109/5269892“速度极快”,但在 750000 个元素的数组上重复运行它(即使只更改了 1000 个元素)会花费太多时间。ndarray.argmax()

下面是一个最小的、完整的例子,我在其中使用了随机数和索引。您不能假设实值数组在更新后如何变化(即值可能更大、更小或相等),除非必须,更新后前一个最大值 () 索引处的数组可能会更小。'b''b[imax]'

我尝试使用排序数组,其中仅将更新的值(按排序顺序)插入到正确的位置以保持排序,因为这样我们知道最大值始终在索引处,我们不必重新计算它。下面的最小示例包括计时。不幸的是,选择未更新的值并插入更新的值需要太多时间(所有其他步骤加起来只需要 ~210 us,而不是 的 ~580 us)。-1ndarray.argmax()

上下文:这是在高效的 Clark (1980) 变体中实现反卷积算法 CLEAN (Hoegbom, 1974) 的一部分。由于我打算实现序列 CLEAN 算法 (Bose+, 2002),其中需要更多的迭代,或者可能想要使用更大的输入数组,我的问题是:

问题:在更新的数组中查找最大值的索引的最快方法是什么(不在每次迭代中应用于整个数组)?ndarray.argmax()

最小示例代码(在 python 3.7.6、numpy 1.21.2、scipy 1.6.0 上运行):

import numpy as np

# some array shapes ('nnu_use' and 'nm'), number of total values ('nvals'), number of selected values ('nsel'; here
# 'nsel' == 'nvals'; in general 'nsel' <= 'nvals') and number of values to be changed ('nchange')
nnu_use, nm = 10418//2 + 1, 144
nvals = nnu_use * nm
nsel = nvals
nchange = 1000

# fix random seed, generate random 2D 'Fourier transform' ('a', complex-valued), compute power ('b', real-valued), and
# two 2D arrays for indices of axes 0 and 1
np.random.seed(100)
a = np.random.rand(nsel) + 1j * np.random.rand(nsel)
b = a.real ** 2 + a.imag ** 2
inu_2d = np.tile(np.arange(nnu_use)[:,None], (1,nm))
im_2d = np.tile(np.arange(nm)[None,:], (nnu_use,1))

# select 'nsel' random indices and get 1D arrays of the selected 2D indices
isel = np.random.choice(nvals, nsel, replace=False)
inu_sel, im_sel = inu_2d.flatten()[isel], im_2d.flatten()[isel]

def do_update_iter(a, b):
    # find index of maximum, choose 'nchange' indices of which 'nchange - 1' are random and the remaining one is the
    # index of the maximum, generate random complex numbers, update 'a' and compute updated 'b'
    imax = b.argmax()
    ichange = np.concatenate(([imax],np.random.choice(nsel, nchange-1, replace=False)))
    a_change = np.random.rand(nchange) + 1j*np.random.rand(nchange)
    a[ichange] = a_change
    b[ichange] = a_change.real ** 2 + a_change.imag ** 2
    return a, b, ichange

# do an update iteration on 'a' and 'b'
a, b, ichange = do_update_iter(a, b)

# sort 'a', 'b', 'inu_sel' and 'im_sel'
i_sort = b.argsort()
a_sort, b_sort, inu_sort, im_sort = a[i_sort], b[i_sort], inu_sel[i_sort], im_sel[i_sort]

# do an update iteration on 'a_sort' and 'b_sort'
a_sort, b_sort, ichange = do_update_iter(a_sort, b_sort)
b_sort_copy = b_sort.copy()

ind = np.arange(nsel)

def binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange):
    # binary insertion as an idea to save computation time relative to repeated argmax over entire (large) arrays
    # find updated values for 'a_sort', compute updated values for 'b_sort'
    a_change = a_sort[ichange]
    b_change = a_change.real ** 2 + a_change.imag ** 2
    # sort the updated values for 'a_sort' and 'b_sort' as well as the corresponding indices
    i_sort = b_change.argsort()
    a_change_sort = a_change[i_sort]
    b_change_sort = b_change[i_sort]
    inu_change_sort = inu_sort[ichange][i_sort]
    im_change_sort = im_sort[ichange][i_sort]
    # find indices of the non-updated values, cut out those indices from 'a_sort', 'b_sort', 'inu_sort' and 'im_sort'
    ind_complement = np.delete(ind, ichange)
    a_complement = a_sort[ind_complement]
    b_complement = b_sort[ind_complement]
    inu_complement = inu_sort[ind_complement]
    im_complement = im_sort[ind_complement]
    # find indices where sorted updated elements would have to be inserted into the sorted non-updated arrays to keep
    # the merged arrays sorted and insert the elements at those indices
    i_insert = b_complement.searchsorted(b_change_sort)
    a_updated = np.insert(a_complement, i_insert, a_change_sort)
    b_updated = np.insert(b_complement, i_insert, b_change_sort)
    inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
    im_updated = np.insert(im_complement, i_insert, im_change_sort)

    return a_updated, b_updated, inu_updated, im_updated

# do the binary insertion
a_updated, b_updated, inu_updated, im_updated = binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)

# do all the steps of the binary insertion, just to have the variable names defined
a_change = a_sort[ichange]
b_change = a_change.real ** 2 + a_change.imag ** 2
i_sort = b_change.argsort()
a_change_sort = a_change[i_sort]
b_change_sort = b_change[i_sort]
inu_change_sort = inu_sort[ichange][i_sort]
im_change_sort = im_sort[ichange][i_sort]
ind_complement = np.delete(ind, i_sort)
a_complement = a_sort[ind_complement]
b_complement = b_sort[ind_complement]
inu_complement = inu_sort[ind_complement]
im_complement = im_sort[ind_complement]
i_insert = b_complement.searchsorted(b_change_sort)
a_updated = np.insert(a_complement, i_insert, a_change_sort)
b_updated = np.insert(b_complement, i_insert, b_change_sort)
inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
im_updated = np.insert(im_complement, i_insert, im_change_sort)

# timings for argmax and for sorting
%timeit b.argmax()             # 579 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit b_sort.argmax()        # 580 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.sort(b)             # 70.2 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.sort(b_sort)        # 25.2 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit b_sort_copy.sort()     # 14 ms ± 78.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# timings for binary insertion
%timeit binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)          # 33.7 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit a_change = a_sort[ichange]                                         # 4.28 µs ± 40.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change = a_change.real ** 2 + a_change.imag ** 2                 # 8.25 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit i_sort = b_change.argsort()                                        # 35.6 µs ± 529 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_change_sort = a_change[i_sort]                                   # 4.2 µs ± 62.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change_sort = b_change[i_sort]                                   # 2.05 µs ± 47 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit inu_change_sort = inu_sort[ichange][i_sort]                        # 4.47 µs ± 38 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit im_change_sort = im_sort[ichange][i_sort]                          # 4.51 µs ± 48.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit ind_complement = np.delete(ind, ichange)                           # 1.38 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit a_complement = a_sort[ind_complement]                              # 3.52 ms ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_complement = b_sort[ind_complement]                              # 1.44 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_complement = inu_sort[ind_complement]                          # 1.36 ms ± 6.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_complement = im_sort[ind_complement]                            # 1.31 ms ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit i_insert = b_complement.searchsorted(b_change_sort)                # 148 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_updated = np.insert(a_complement, i_insert, a_change_sort)       # 3.08 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_updated = np.insert(b_complement, i_insert, b_change_sort)       # 1.37 ms ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_updated = np.insert(inu_complement, i_insert, inu_change_sort) # 1.41 ms ± 28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_updated = np.insert(im_complement, i_insert, im_change_sort)    # 1.52 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

更新:正如下面 @Jérôme Richard 所建议的,在部分更新的数组中重复查找最大值索引的一种快速方法是将数组拆分为多个块,预先计算块的最大值,然后在每次迭代中仅重新计算(或更少)更新的块的最大值,然后计算块最大值上的 argmax, 返回块索引,并在该块索引的块内查找 argmax。nchange

我从理查德@Jérôme的回答中复制了代码。在实践中,当在我的系统上运行时,他的解决方案会导致大约 7.3 的速度提升,需要 46.6 + 33 = 79.6 musec 而不是 580 musecb.argmax()

import numba as nb

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk(b):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    max_per_chunk = np.empty(b.size // 32)

    for chunk_idx in nb.prange(b.size//32):
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk
# OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks(b, max_per_chunk):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    return offset + np.argmax(b[offset:offset+32])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk(b, max_per_chunk, ichange):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    for idx in ichange:
        chunk_idx = idx // 32
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

b = np.random.rand(nsel)
max_per_chunk = precompute_max_per_chunk(b)
a, b, ichange = do_update_iter(a, b)
argmax_from_chunks(b, max_per_chunk)
update_max_per_chunk(b, max_per_chunk, ichange)

%timeit max_per_chunk = precompute_max_per_chunk(b)     # 77.3 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks(b, max_per_chunk)            # 46.6 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk(b, max_per_chunk, ichange) # 33 µs ± 40.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

更新2:我现在修改了 @Jérôme Richard 的解决方案,以处理大小不等于块大小的整数倍的数组。此外,如果更新的值小于上一个块的最大值,则代码仅访问所有块值,否则直接将更新的值设置为新的块最大值。与更新的值大于上一个最大值时节省的时间相比,if 查询需要的时间应该很短。在我的代码中,随着迭代次数的增加,这种情况的可能性会越来越大(更新的值越来越接近噪声,即随机)。在实践中,对于随机数,执行时间会进一步减少,从 ~33 us 减少到 ~27 us。代码和新计时是:bupdate_max_per_chunk()

import math

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk_bp(b):
    nchunks = math.ceil(b.size/32)
    imod = b.size % 32
    max_per_chunk = np.empty(nchunks)
    
    for chunk_idx in nb.prange(nchunks):
        offset = chunk_idx * 32
        maxi = b[offset]
        if (chunk_idx != (nchunks - 1)) or (not imod):
            iend = 32
        else:
            iend = imod
        for j in range(1, iend):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks_bp(b, max_per_chunk):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    if (chunk_idx != (nchunks - 1)) or (not imod):
        return offset + np.argmax(b[offset:offset+32])
    else:
        return offset + np.argmax(b[offset:offset+imod])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk_bp(b, max_per_chunk, ichange):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    for idx in ichange:
        chunk_idx = idx // 32
        if b[idx] < max_per_chunk[chunk_idx]:
            offset = chunk_idx * 32
            if (chunk_idx != (nchunks - 1)) or (not imod):
                iend = 32
            else:
                iend = imod
            maxi = b[offset]
            for j in range(1, iend):
                maxi = max(b[offset + j], maxi)
            max_per_chunk[chunk_idx] = maxi
        else:
            max_per_chunk[chunk_idx] = b[idx]

%timeit max_per_chunk = precompute_max_per_chunk_bp(b)     # 74.6 µs ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks_bp(b, max_per_chunk)            # 46.6 µs ± 9.92 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk_bp(b, max_per_chunk, ichange) # 26.5 µs ± 19.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
python 数组 numpy 性能 argmax

评论

0赞 Bill 7/21/2022
有人提到稀疏矩阵吗?您可以将更新保存在单独的稀疏矩阵中,并且以这种方式只找到更新元素的最大值。(我假设 max(sparse_array) 只检查非零元素)。
0赞 bproxauf 7/21/2022
仅查找已更新元素的最大值是行不通的,因为整体最大值可能属于未更新的元素。而且我想在每次迭代中从 numpy 数组到稀疏矩阵的转换,反之亦然会花费太长时间(但我不确定)......
0赞 Bill 7/21/2022
事实上。您需要保留先前迭代的最大值记录,并将最新稀疏数组的最大值与如下所示的值进行比较:.我不知道创建稀疏矩阵是否开销很大。也许不是。current_max = max(previous_max, sparse_array.max())
0赞 bproxauf 7/21/2022
到目前为止还没有使用稀疏矩阵,但我可以使用它。创建单个稀疏矩阵可能不会花费太多时间,但是随着迭代的增加,由于更新的索引不断变化,矩阵的稀疏度越来越低,当许多元素不为零时,它应该会变得很慢。当然,除非我们在找到最大值后将矩阵重置为零。
1赞 Bill 7/21/2022
对不起,我错过了每次迭代都会替换值的事实,因此会变得无效。您必须保留“前 N”个值的记录,然后在每次前 N 个值耗尽时定期对整个数组进行重新排序。previous_max

答:

2赞 Jérôme Richard 7/20/2022 #1

ndarray.argmax()根据 https://stackoverflow.com/a/26820109/5269892 的说法,速度“极快

Argmax 不是最佳的,因为它无法成功地使我的机器上的 RAM 带宽饱和(这是可能的),但它非常好,因为它在您的案例中饱和了总 RAM 吞吐量的 ~40%,在我的机器上顺序饱和了大约 65%-70%(一个内核无法使大多数机器上的 RAM 饱和)。大多数机器的吞吐量较低,因此在这些机器上应该更接近最佳。np.argmax

使用多个线程找到最大值有助于达到最佳值,但就功能的当前性能而言,不应期望在大多数 PC 上加速大于 2(在计算服务器上更多)。

在更新的数组中查找最大值索引的最快方法是什么

无论进行何种计算,读取内存中的整个数组至少需要几秒钟。使用非常好的 2 通道 DDR4 RAM,最佳时间约为 ~125 us,而最好的 1 通道 DDR4 RAM 达到 ~225 us。如果数组是就地写入的,则最佳时间是原来的两倍,如果创建了一个新数组(异地计算),则在 x86-64 平台上的最佳时间要大 3 倍。事实上,对于后者来说,情况更糟,因为操作系统虚拟内存的开销很大。b.size * 8 / RAM_throughput

这意味着,在主流 PC 上,任何读取整个数组的异地计算都无法击败 np.argmax。这也解释了为什么排序解决方案如此缓慢:它会创建许多临时数组。即使是一个完美的排序阵列策略也不会比这里快多少(因为在最坏的情况下,所有项目都需要在RAM中移动,平均情况下远远超过一半)。事实上,任何就地方法写入整个数组的好处都很低(仍然在主流 PC 上):它只会比 稍微快一点。获得显著加速的唯一解决方案是不对整个阵列进行操作。np.argmaxnp.argmax

解决此问题的一种有效解决方案是使用平衡的二叉搜索树。实际上,您可以及时从包含节点的树中删除节点。然后,您可以同时插入更新的值。这比您情况下的解决方案要好得多,因为 n ~= 750_000 和 k ~= 1_000。不过,请注意,复杂性背后有一个隐藏因素,二叉搜索树在实践中可能不会那么快,尤其是在它们不是很优化的情况下。另请注意,更新树值比删除节点并插入新节点更好。在这种情况下,纯 Python 实现几乎不够快(并且占用大量内存)。只有 **Cython 或天然溶液可以快速(例如。C / C++,或任何本地实现的Python模块,但我找不到任何快速的模块)。knO(k log n)O(n)

另一种选择是基于静态 n 元树的偏最大值数据结构。它包括将数组拆分为块,并首先预先计算每个块的最大值。更新值时(假设项目数是恒定的),您需要 (1) 重新计算每个块的最大值。要计算全局最大值,您需要 (2) 计算每个块最大值的最大值。该解决方案还需要(半)本机实现,因此要快速,因为 Numpy 在更新每个块的最大值期间引入了大量开销(因为它对这种情况不是很优化),但人们肯定会看到速度加快。例如,Numba 和 Cython 可用于执行此操作。块的大小需要仔细选择。在你的情况下,16到32之间的东西应该会给你带来巨大的加速。

对于大小为 32 的块,最多只需要读取 32*k=32_000 个值来重新计算总最大值(最多写入 1000 个值)。这远远小于 750_000。部分最大值的更新需要计算 n/32 ~= 23_400 值的最大值,该值仍然相对较小。我预计通过优化的实现,这将快 5 倍,但在实践中甚至可能快 >10 倍,尤其是在使用并行实现时。这当然是最好的解决方案(没有额外的假设)。


在努巴的实施

下面是一个(几乎没有测试过的)Numba 实现:

import numba as nb

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk(arr):
    # Required for this simplied version to work and be simple
    assert b.size % 32 == 0
    max_per_chunk = np.empty(b.size // 32)

    for chunk_idx in nb.prange(b.size//32):
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks(arr, max_per_chunk):
    # Required for this simplied version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    return offset + np.argmax(b[offset:offset+32])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk(arr, max_per_chunk, ichange):
    # Required for this simplied version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    for idx in ichange:
        chunk_idx = idx // 32
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

以下是如何在我的(6 核)机器上使用它和计时的示例:

# Precomputation (306 µs)
max_per_chunk = precompute_max_per_chunk(b)

# Computation of the global max from the chunks (22.3 µs)
argmax_from_chunks(b, max_per_chunk)

# Update of the chunks (25.2 µs)
update_max_per_chunk(b, max_per_chunk, ichange)

# Initial best implementation: 357 µs
np.argmax(b)

如您所见,它非常快。更新应该需要 22.3+25.2 = 47.5 μs,而 Numpy 朴素实现需要 357 μs。因此,Numba 的实现速度提高了 7.5 倍!我认为它可以进一步优化,但这并不简单。请注意,更新是顺序的,预计算是并行的。有趣的事实:预计算后调用比使用多个线程更快!argmax_from_chunksnp.argmax


进一步改进

由于 SIMD 指令,可以改进。事实上,当前的实现在 x86-64 机器上生成标量/指令,这是次优的。该操作可以通过使用基于瓦片的 argmin 来矢量化,该 argmin 使用 x4 展开循环(在最近的 512 位宽 SIMD 机器上甚至可能是 x8)计算最大值。在支持 AVX 指令集的处理器上,实验表明 Numba 可以生成以 6-7 us 运行的代码(大约快 4 倍)。话虽如此,这很难实现,而且生成的功能有点丑陋。argmax_from_chunksmaxsdvmaxsd

同样的方法也可用于加速,不幸的是,默认情况下也没有矢量化。我还希望在最近的 x4-86 机器上将速度提高 ~64 倍。然而,Numba 在许多情况下生成了一种非常低效的矢量化方法(它试图矢量化外部循环而不是内部循环)。结果,我对 Numba 的最佳尝试达到了 16.5 我们。update_max_per_chunk

从理论上讲,在主流的 x86-64 机器上,整个更新的速度可以提高 4 倍左右,尽管在实践中,代码至少可以快 2 倍!

评论

1赞 Jérôme Richard 7/20/2022
事实上,在这种情况下,这甚至更好:) .我建议您使用最后一个解决方案,在 Python 环境中的实践中,它更简单,当然也更快。如果实现没有经过仔细优化,并且大多数实现没有(尤其是 Python 模块),那么二叉树转换确实会很昂贵。我尝试了一个基本的 Numpy 解决方案,但 Numpy 的当前实现显然没有针对这种情况进行优化,结果证明效率非常低。如果您不熟悉 Cython,Numba 无疑是最佳选择,因为您可以用 Python 编写所有内容。
1赞 Jérôme Richard 7/20/2022
我编写了一个 Numba 实现,并得到了很大的加速。调用 + 给了我正确的值,所以它应该没问题,但我没有彻底测试代码,所以你需要检查。请注意,您当然需要调整实现,以支持不是 32 倍数的数组大小。precompute_max_per_chunkargmax_from_chunks
1赞 Jérôme Richard 7/21/2022
np.max也在这里工作。请注意,我使用了循环,因为即使在 Numba 中,Numpy 函数通常也会引入少量开销。虽然它在 Numba 中通常是可以接受的,但这是一个非常关键的循环,所以我更喜欢在运行它时不要感到惊讶。请注意,即使在 Numba 中,列表理解也往往很慢,因为它们会增加一些开销(也因为创建可变大小的列表本质上比固定大小的预分配数组慢)。
1赞 Jérôme Richard 7/21/2022
arr在您的示例中(IDK 表示真实世界的代码)。这可以在最后一个示例代码中看到。请注意,这是一个函数参数,而参数是参数,它们不应混淆(但请随意更改名称)。OpenMP 警告很奇怪。这似乎表明 Numba 函数正在并行上下文中运行。如果是这样,我建议您测试以删除 and,因为并行嵌套循环通常效率低下。它也可能只是一个内部无用的警告......barrbparallel=Trueprange
1赞 bproxauf 7/21/2022
我想我的意思是函数中没有使用。这不是一个大问题,因为它将使用 全局变量 ,但无论如何。你的其他解释也很有帮助。我将尝试删除明天重新上班,看看是否进一步提高了速度,以防它使解决方案比必要的速度慢。谢谢!arrbparallelprange