这个 Jax 函数的纯函数版本是什么?

What is a pure functional version of this Jax function?

提问人:Leigh Gable 提问时间:3/18/2023 最后编辑:Leigh Gable 更新时间:3/18/2023 访问量:51

问:

我正在编写一个简单的分类器,试图掌握 Jax。我想对数据进行归一化(它是来自 sklearn 的 Iris 数据集)并且我的小函数可以工作,但从我在 Jax the Sharp Bits 文档中读到的内容来看,我应该避免使用 lambda 函数和迭代向量。我不精通函数式编程,我很好奇是否有更好、更像 Jax 的方法来做到这一点。这是我到目前为止的代码:

import jax.numpy as jnp 
from jax import jit, vmap 
# lots of imports ... 

iris = load_it('data', 'iris.pkl')

def normalize(data): 
    return jnp.apply_along_axis(lambda x: x/jnp.linalg.norm(x), 1, data) 

# TODO: use a functional style, maybe use partial
# and get rid of the lambda ...

tic = time.perf_counter() 
iris_data_normal = normalize(iris.data) 
toc = time.perf_counter() 
print(f"It took jax {toc - tic:0.4f} seconds.")

当我运行它时,我得到:任何指导都是最值得赞赏的!It took jax 0.0677 seconds.

lambda 迭代器 闭包 jax

评论


答:

0赞 jakevdp 3/18/2023 #1

您目前的方法在我看来很好:它是纯粹的(该功能没有副作用),并且在 JAX 中是根据 vmap 实现的,因此在计算效率方面没有问题。apply_along_axis

如果你想使用直接数组操作编写一个类似的函数,你可以等效地做这样的事情:

def normalize(data):
  return data / jnp.linalg.norm(data, axis=1, keepdims=True)

评论

0赞 Leigh Gable 3/18/2023
谢谢!太棒了。而且速度更快。