numpy nanmean in numba

Asked by: user9875321__ · Asked: 11/9/2023 · Modified: 11/10/2023 · Viewed: 47 times

Q:

I'm trying to write a simpler version of numpy.nanmean for numba. Here is my code:

from numba import jit, prange
import numpy as np

@jit(nopython=True)
def nanmeanMY(a, axis=None):
    if a.ndim>1:
        ncols = a.shape[1]
        nrows = a.shape[0]
        a = a.T.flatten()
        res = np.zeros(ncols)
        for i in prange(ncols):
            col_no_nan = a[i*nrows:(i+1)*nrows]
            res[i] = np.mean(col_no_nan[~np.isnan(col_no_nan)])
        return res
    else:
        return np.mean(a[~np.isnan(a)])

The code should check whether it is dealing with a vector or a matrix and, for a matrix, return the column means. With the test matrix

X = np.array([[1,2], [3,4]])
nanmeanMY(X)

I get the following error:

Traceback (most recent call last):

  Cell In[157], line 1
    nanmeanMY(a)

  File ~\anaconda3\Lib\site-packages\numba\core\dispatcher.py:468 in _compile_for_args
    error_rewrite(e, 'typing')

  File ~\anaconda3\Lib\site-packages\numba\core\dispatcher.py:409 in error_rewrite
    raise e.with_traceback(None)

TypingError: No implementation of function Function(<built-in function getitem>) found for signature:
 
getitem(array(int32, 2d, C), array(bool, 2d, C))
 
There are 22 candidate implementations:
      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(int32, 2d, C), array(bool, 2d, C))':
       No match.
      - Of which 2 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 209.
        With argument(s): '(array(int32, 2d, C), array(bool, 2d, C))':
       Rejected as the implementation raised a specific error:
         NumbaTypeError: Multi-dimensional indices are not supported.
  raised from C:\Users\*****\anaconda3\Lib\site-packages\numba\core\typing\arraydecl.py:89

During: typing of intrinsic-call at C:\Users\*****\AppData\Local\Temp\ipykernel_10432\1652358289.py (22)

What is the problem here?
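As a reference point, the intended behavior (column means for a matrix, a single mean for a vector, NaNs ignored) is what plain NumPy's `np.nanmean` already produces outside Numba, which makes it handy for checking a jitted version. `reference_nanmean` is a hypothetical name used only for this sketch.

```python
import numpy as np


def reference_nanmean(a):
    # Hypothetical helper: the behavior the question describes, in plain NumPy.
    # Matrix -> NaN-ignoring mean of each column; vector -> one NaN-ignoring mean.
    if a.ndim > 1:
        return np.nanmean(a, axis=0)
    return np.nanmean(a)


X = np.array([[1, 2], [3, 4]])                 # test matrix from the question
X_nan = np.array([[1.0, np.nan], [3.0, 4.0]])  # same idea, with a NaN in it
v = np.array([1.0, np.nan, 5.0])

print(reference_nanmean(X))      # [2. 3.]
print(reference_nanmean(X_nan))  # [2. 4.]
print(reference_nanmean(v))      # 3.0
```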

numpy mean missing-data numba

A:

1 vote · ken · 11/10/2023 · #1

Apparently, because you reuse the variable a, numba cannot correctly infer its type.

Instead of reusing the variable, create a new one.
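The rejected signature in the traceback, `getitem(array(int32, 2d, C), array(bool, 2d, C))`, is indexing a 2-D array with a 2-D boolean mask, which numba's nopython mode reports as "Multi-dimensional indices are not supported". In the question's function that pattern only arises from the `else` branch (`a[~np.isnan(a)]`) being typed with a 2-D `a`. A plausible reading, which is my assumption and not something stated in this thread, is that numba resolves `a.ndim > 1` at compile time and drops the dead branch only when `a` is never reassigned, so renaming lets the vector-only branch disappear for 2-D input. The plain-NumPy sketch below shows what that indexing does and a 1-D-only equivalent that avoids the multi-dimensional mask.

```python
import numpy as np

a = np.array([[1.0, np.nan], [3.0, 4.0]])
mask = ~np.isnan(a)  # 2-D boolean mask with the same shape as a

# NumPy accepts a 2-D mask and returns the selected values as a flat 1-D array
# (row-major order). This is the form numba's nopython mode rejects.
picked = a[mask]
print(picked)       # [1. 3. 4.]
print(picked.ndim)  # 1

# 1-D-only equivalent: flatten first, then index with a 1-D mask.
flat = a.ravel()
picked_1d = flat[~np.isnan(flat)]
print(picked_1d)    # [1. 3. 4.]
```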

from numba import jit, prange
import numpy as np

@jit(nopython=True)
def nanmeanMY(a):
    if a.ndim > 1:
        ncols = a.shape[1]
        nrows = a.shape[0]
        a_flatten = a.T.flatten()  # Renamed a to a_flatten.
        res = np.zeros(ncols)
        for i in prange(ncols):
            col_no_nan = a_flatten[i * nrows : (i + 1) * nrows]  # Use a_flatten.
            res[i] = np.mean(col_no_nan[~np.isnan(col_no_nan)])
        return res
    else:
        return np.mean(a[~np.isnan(a)])
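The renamed version still copies the whole array (`a.T.flatten()`) and builds a boolean mask per column. A different approach, which is my own sketch rather than part of this answer, is to walk each column directly and keep a running sum and count of the non-NaN values. `nanmean_cols` is a hypothetical name; `parallel=True` is added so that `prange` actually spreads the column loop over threads (without it, `prange` behaves like `range`). The `try`/`except` fallback is only there so the snippet also runs as plain Python when numba is not installed.

```python
import numpy as np

try:
    from numba import jit, prange
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs without numba.
    prange = range

    def jit(*args, **kwargs):
        # Mimics the @jit(...) call form by returning a no-op decorator.
        def wrap(func):
            return func
        return wrap


@jit(nopython=True, parallel=True)
def nanmean_cols(a):
    # Hypothetical helper: NaN-ignoring mean of each column of a 2-D array,
    # computed without copying the array or building a boolean mask.
    nrows, ncols = a.shape
    res = np.empty(ncols)
    for j in prange(ncols):  # columns are independent, so they can run in parallel
        total = 0.0
        count = 0
        for i in range(nrows):
            v = a[i, j]
            if not np.isnan(v):
                total += v
                count += 1
        # Like np.nanmean, an all-NaN column yields NaN.
        res[j] = total / count if count > 0 else np.nan
    return res


X = np.array([[1.0, np.nan], [3.0, 4.0], [np.nan, 8.0]])
print(nanmean_cols(X))  # [2. 6.]
```

For 1-D input, the question's `np.mean(a[~np.isnan(a)])` expression works as is; keeping the 1-D and 2-D paths in separate functions should also sidestep the branch-on-ndim typing issue altogether.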

Comments

0 votes · user9875321__ · 11/14/2023
Works great!