提问人:FedeMatte 提问时间:9/25/2023 更新时间:9/25/2023 访问量:33
列表而不是元组作为返回:Numba jit-ted 函数中的失败类型推断
List instead of a Tuple as a return: failed type inference in Numba jit-ted function
问:
我正在尝试使用 numba.jit 编译一个函数,该函数接受多个输入并返回一个元组。我在 jit 装饰中指定了输入和输出的类型,认为这会有所帮助。但是,我收到一条警告,说包含元组的列表无法转换为元组,并且 numba.jit 会回退到对象模式。
我正在构建一个脚本来模拟物理 2D 现象。下面是我尝试使用 numba.jit 编译的函数的代码(有点简化)。输入中未列出一些参数。这些是之前定义的常量参数,类型为 float,大小为 1。 这里的关键是我声明我将收到一个类型 Tuple(array(float64, 3d, C), array(float64, 2d, A))。
@nb.jit(nb.types.containers.Tuple((nb.types.float64[:,:,::1],nb.types.float64[:,:]))(nb.types.float64,nb.types.containers.Tuple((nb.types.float64[:,:,:],nb.types.float64[:,:])),nb.types.float64[:,:,:],nb.types.float64[:,:,:],nb.types.float64[:,:],nb.types.float64[:,:,:],nb.types.float64[:,:]))
def model_f(t,dvar,dNdy,N_y,N_x,r,coeff2):
X = dvar[0]
X_ch = dvar[1]
# trivial operations on matrix element-wise:
i = 0
for row in X[0,:,0]:
j = 0
for col in X[0,0,:]:
N_y[:,i,j] = DGM_N(X,i,j)
[r[:,i,j]] = RRates(X,T,i,j)
j += 1
i += 1
# gradient of N_y:
comp = 0
for c_plane in X[:,0,0]:
i = 0
for row in X[0,:,0]:
if i == 0:
dNdy[comp,i,:] = (N_y[comp,i+1,:]-N_y[comp,i,:])/dy
elif i == NY-1:
dNdy[comp,i,:] = (N_y[comp,i,:]-N_y[comp,i-1,:])/dy
else:
dNdy[comp,i,:] = (N_y[comp,i+1,:]-N_y[comp,i-1,:])/(2*dy)
i += 1
comp += 1
coeff1 = 1/(-dNdy+r)
N_x = speed*X_ch
i = 0
for cell in x_axis:
if i == 0:
coeff2[:,i] = (-N_x[:,i]+No)/dx-N_y[:,0,i]/channel_height
else:
coeff2[:,i] = (-N_x[:,i]+N_x[:,i-1])/dx-N_y[:,0,i]/channel_height
i += 1
return [(coeff1,coeff2)]
但是,当我尝试编译函数时,出现以下错误:
NumbaWarning:
Compilation is falling back to object mode WITH looplifting enabled because Function "model_f" failed type inference due to: No conversion from list(Tuple(array(float64, 3d, C), array(float64, 2d, A)))<iv=None> to Tuple(array(float64, 3d, C), array(float64, 2d, A)) for '$1656return_value.4', defined at None
File "soe_model_v2.1.py", line 338:
def model_f(t,dvar,T,mu_fuel,mu_air,dXdy,dNdy,N_y,N_x,r,Dkn,Deffbin,coeff2,result):
<source elided>
return [(coeff1,coeff2)]
^
During: typing of assignment at c:\users\matfe\...\soe_model_v2.1.py (338)
File "soe_model_v2.1.py", line 338:
def model_f(t,dvar,T,mu_fuel,mu_air,dXdy,dNdy,N_y,N_x,r,Dkn,Deffbin,coeff2,result):
<source elided>
return [(coeff1,coeff2)]
^
@nb.jit(nb.types.containers.Tuple((nb.types.float64[:,:,::1],nb.types.float64[:,:]))(nb.types.float64,nb.types.containers.Tuple((nb.types.float64[:,:,:],nb.types.float64[:,:])),nb.types.float64,nb.types.float64,nb.types.float64,nb.types.float64[:,:,:],nb.types.float64[:,:,:],nb.types.float64[:,:,:],nb.types.float64[:,:],nb.types.float64[:,:,:],nb.types.float64[:],nb.types.float64[:,:],nb.types.float64[:,:],nb.types.containers.Tuple((nb.types.float64[:,:,:],nb.types.float64[:,:]))))
您可以看到,在原始函数代码中,我的函数有更多的输入,这些输入也是在 jit 修饰期间推断出来的(这使得错误消息很长),但这在这里应该无关紧要。
因此,我的“return”对象似乎被视为包含元组的列表,而不仅仅是元组。这是它不编译的原因吗?有人有什么建议吗?
答:
0赞
Rutger Kassies
9/25/2023
#1
在 Python 中返回多个值已经默认为元组,因此您应该能够将其简化为类似于下面的玩具示例:
import numba as nb
from numba.types import Tuple, float64, int64
import numpy as np
@nb.njit(Tuple((float64[:,:,::1], float64[:,:]))(int64))
def model_f(x):
a = np.random.randn(x, x, x)
b = np.random.randn(x, x)
return a, b
model_f(2)
这将正确编译,并返回一个包含两个数组的元组。
评论
(coeff1,coeff2)