JAX 中的矢量化最小化和寻根

vectorized minimization and root finding in jax

提问人:Dan Leonte 提问时间:10/25/2023 最后编辑:Dan Leonte 更新时间:10/25/2023 访问量:53

问:

我有一系列参数化的函数args

f(x, args)

并希望确定 的值的最小值。我可以访问该函数及其衍生物。我的第一次尝试是循环遍历 的不同值并在每次迭代中使用 scipy.optimizer,但这需要很长时间。我相信可以通过矢量化来加快操作速度。我的下一个尝试是在 or 中使用,但我似乎无法为 传递多个值。fxN = 1000argsargsjax.vmapjax.scipy.optimize.minimizejaxopt.ScipyMinimizeargs

或者,我可以编写自己的矢量化优化方法,例如平分,其中矢量化是指对数组进行固定次数的迭代操作,如果其中一个优化问题提前达到一定的容错水平,则不会提前停止。我希望使用一些优化的现成算法。

如果 jax 中有一个实现可用,我希望使用一些已经优化的现成算法。这个线程是相关的,但没有改变。args

优化 西皮 矢 量化 贾克斯

评论


答:

1赞 jakevdp 10/25/2023 #1

你可以定义一个函数来找到给定的最小值,然后将其包装起来以自动矢量化它。例如:argsjax.vmap

import jax
import jax.numpy as jnp
from jax.scipy import optimize

def f(x, args):
  a, b = args
  return jnp.sum(a + (x - b) ** 2)

def find_min(a, b):
  x0 = jnp.array([1.0])
  args = (a, b)
  return optimize.minimize(f, x0, (args,), method="BFGS")

a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))

results = jax.vmap(find_min)(a_grid.ravel(), b_grid.ravel())

print(results.success)
# [ True  True  True  True  True  True  True  True  True  True  True  True
#   True  True  True  True  True  True  True  True  True  True  True  True
#   True]

print(results.x.T)
# [[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2.
#   3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]]