提问人:Leigh Gable 提问时间:3/18/2023 最后编辑:Leigh Gable 更新时间:3/18/2023 访问量:51
这个 Jax 函数的纯函数版本是什么?
What is a pure functional version of this Jax function?
问:
我正在编写一个简单的分类器,试图掌握 Jax。我想对数据进行归一化(它是来自 sklearn 的 Iris 数据集)并且我的小函数可以工作,但从我在 Jax the Sharp Bits 文档中读到的内容来看,我应该避免使用 lambda 函数和迭代向量。我不精通函数式编程,我很好奇是否有更好、更像 Jax 的方法来做到这一点。这是我到目前为止的代码:
import jax.numpy as jnp
from jax import jit, vmap
# lots of imports ...
iris = load_it('data', 'iris.pkl')
def normalize(data):
return jnp.apply_along_axis(lambda x: x/jnp.linalg.norm(x), 1, data)
# TODO: use a functional style, maybe use partial
# and get rid of the lambda ...
tic = time.perf_counter()
iris_data_normal = normalize(iris.data)
toc = time.perf_counter()
print(f"It took jax {toc - tic:0.4f} seconds.")
当我运行它时,我得到:任何指导都是最值得赞赏的!It took jax 0.0677 seconds.
答:
0赞
jakevdp
3/18/2023
#1
您目前的方法在我看来很好:它是纯粹的(该功能没有副作用),并且在 JAX 中是根据 vmap
实现的,因此在计算效率方面没有问题。apply_along_axis
如果你想使用直接数组操作编写一个类似的函数,你可以等效地做这样的事情:
def normalize(data):
return data / jnp.linalg.norm(data, axis=1, keepdims=True)
评论
0赞
Leigh Gable
3/18/2023
谢谢!太棒了。而且速度更快。
评论