提问人:Christoph 提问时间:9/18/2019 最后编辑:MassifoxChristoph 更新时间:9/18/2019 访问量:1237
如何避免 Numpy 类型转换?
How to avoid Numpy type conversions?
问:
是否可以避免或发出从整数和到的自动 Numpy 类型转换的警告?32 bit float arrays
64 bit float arrays
我的用例是,我正在开发一个大型分析包(20k 行 Python 和 Numpy),目前混合了 float 32 和 64 以及一些 int dtypes,这很可能导致性能欠佳和内存浪费,基本上我想在任何地方一致地使用 float32。
我知道在 Tensorflow 中,组合两个不同 dtype 的数组会产生错误 - 正是因为对 float64 的隐式转换会导致性能不佳,并且在所有计算张量上都具有“传染性”,并且如果隐式进行,很难找到它在哪里引入。
在 Numpy 中寻找一个选项或一种对 Numpy 进行猴子修补的方法,以便它在这方面的行为像 Tensorflow 一样,即在诸如 等操作的隐式类型转换时发出错误,或者更好的是,发出带有打印回溯的警告,以便继续执行,但我看到它发生在哪里。可能?np.add
np.mul
答:
0赞
Paul Panzer
9/18/2019
#1
免责声明:我没有以任何认真的方式测试过这一点,但这似乎是一条很有前途的路线。
一种相对轻松的操纵 ufunc 行为的方法似乎是子类化 ndarray
并覆盖。例如,如果您满足于捕捉任何产生的东西__array_ufunc__
float64
class no64(np.ndarray):
def __array_ufunc__(self, ufunc, method, *inputs, **kwds):
ret = getattr(ufunc, method)(*map(np.asarray,inputs), **kwds)
# some ufuncs return multiple arrays:
if isinstance(ret,tuple):
if any(x.dtype == np.float64 for x in ret):
raise ValueError
return (*(x.view(no64) for x in ret),)
if ret.dtype == np.float64:
raise ValueError
return ret.view(no64)
x = np.arange(6,dtype=np.float32).view(no64)
现在让我们看看我们的班级可以做什么:
x*x
# no64([ 0., 1., 4., 9., 16., 25.], dtype=float32)
np.sin(x)
# no64([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,
# -0.9589243 ], dtype=float32)
np.frexp(x)
# (no64([0. , 0.5 , 0.5 , 0.75 , 0.5 , 0.625], dtype=float32), no64([0, 1, 2, 2, 3, 3], dtype=int32))
现在让我们将它与一个 64 位参数配对:
x + np.arange(6)
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "<stdin>", line 9, in __array_ufunc__
# ValueError
np.multiply.outer(x, np.arange(2.0))
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "<stdin>", line 9, in __array_ufunc__
# ValueError
什么不起作用(我敢肯定还有更多)
np.outer(x, np.arange(2.0)) # not a ufunc, so slips through
# array([[0., 0.],
# [0., 1.],
# [0., 2.],
# [0., 3.],
# [0., 4.],
# [0., 5.]])
__array_function__
似乎就是抓住了那些。
评论
ufunc
就像取一个参数一样。看起来默认值是 casting='no''。np.add
casting
same_kind' https://docs.scipy.org/doc/numpy/reference/ufuncs.html#casting-rules, https://docs.scipy.org/doc/numpy/reference/generated/numpy.can_cast.html#numpy.can_cast. I think you want
out
np.multiply(x,2., casting='no')
np.array(2.)
x
dtype