提问人:cluelessness 提问时间:10/25/2023 更新时间:10/25/2023 访问量:27
追加以列出 JAX 的替代工作流
Appending to list alternative workflow for JAX
问:
我正在研究一个用 JAX 编写的微分方程求解器。我遇到的一个常见工作流程是这样的:
import jax.numpy as jnp
from jax import jit
# Function to integrate.
@jit
def dxdt(t, x):
return -x**2
# Euler method for simplicity.
@jit
def integrator(f, t, x, dt):
return x + f(t, x) * dt
t_arr = jnp.linspace(0, 10, 100)
dt = t_arr[1] - t_arr[0]
x_list = []
# initialize x.
x = 0.
for t in t_arr:
x_list.append(x)
x = integrator(f, t, x, dt)
x_arr = jnp.array(x_list)
我的问题是,是否有办法使用 JAX “矢量化”该 for 循环?
我认识到这在这里不合适,因为变量 x 在每次 for 循环迭代中都会发生变化。如果此工作流有更适合 JAX 的方法?jax.vmap()
答:
1赞
jakevdp
10/25/2023
#1
这种顺序操作(其中每个步骤都依赖于最后一个步骤)在 JAX 中通过 jax.lax.scan
得到支持。以下是如何使用以下方法进行等效的计算:scan
import jax
def scan_body(carry, t):
x, dt = carry
new_x = integrator(dxdt, t, x, dt)
return (new_x, dt), x
_, x_arr = jax.lax.scan(scan_body, (0., dt), t_arr)
上一个:Python 中的一维热感应方程
评论
integrator
stepper
integrator
jit