追加以列出 JAX 的替代工作流

Appending to list alternative workflow for JAX

提问人:cluelessness 提问时间:10/25/2023 更新时间:10/25/2023 访问量:27

问:

我正在研究一个用 JAX 编写的微分方程求解器。我遇到的一个常见工作流程是这样的:

import jax.numpy as jnp
from jax import jit

# Function to integrate.
@jit
def dxdt(t, x):
   return -x**2

# Euler method for simplicity.
@jit
def integrator(f, t, x, dt):
    return x + f(t, x) * dt

t_arr = jnp.linspace(0, 10, 100)
dt = t_arr[1] - t_arr[0]

x_list = []

# initialize x.
x = 0.

for t in t_arr:
    x_list.append(x)
    x = integrator(f, t, x, dt)

x_arr = jnp.array(x_list)

我的问题是,是否有办法使用 JAX “矢量化”该 for 循环?

我认识到这在这里不合适,因为变量 x 在每次 for 循环迭代中都会发生变化。如果此工作流有更适合 JAX 的方法?jax.vmap()

python 数值方法 微分方程 jax

评论

0赞 Lutz Lehmann 10/25/2023
重命名时间循环并将其放入返回计算值列表的新函数中。然后,您可以使用此功能进行装饰。计算(在一般情况下)本质上是顺序的,没有什么可以并行化的。但是,如果您可以用一些编译的变体替换 python 循环,这应该会加快速度。integratorstepperintegratorjit

答:

1赞 jakevdp 10/25/2023 #1

这种顺序操作(其中每个步骤都依赖于最后一个步骤)在 JAX 中通过 jax.lax.scan 得到支持。以下是如何使用以下方法进行等效的计算:scan

import jax

def scan_body(carry, t):
  x, dt = carry
  new_x = integrator(dxdt, t, x, dt)
  return (new_x, dt), x

_, x_arr = jax.lax.scan(scan_body, (0., dt), t_arr)