Zygote 不支持突变数组

Mutating arrays is not supported in Zygote

提问人:Bern 提问时间:11/9/2023 最后编辑:Bern 更新时间:11/9/2023 访问量:65

问:

我正在尝试使用科学机器学习框架从 Julia 中的示例中编写螺旋示例

但是,当代码通过函数进行区分时,我收到错误:get_batch

ERROR: Mutating arrays is not supported -- called setindex!(Vector{Int64}, ...) This error occurs when you ask Zygote to differentiate operations that change the elements of arrays in place (e.g. setting values with x .= ...)

下面是一个最小的可重现示例:

  using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays, Random, StatsBase, MLUtils
  using Zygote

  true_y0 = [2., 0.];
  true_A = [-0.1 2.; -2. -0.1];

  data_size = 1000;
  times = LinRange(0, 25, data_size);

  function ground_truth!(du, u, p, t)
      du .= true_A * (u.^3) 
  end

  ground_truth_odeProb = ODEProblem(ground_truth!, true_y0, (0, times[end]));

  sol_ode = Array(
                    solve(ground_truth_odeProb,
                          Tsit5(),
                          abstol = 1e-10, reltol = 1e-10,
                          saveat = times)
                  );

  batch_time = 10
  batch_size = 20

  function get_batch()

    s = sample(range(1, data_size - batch_time), batch_size, replace=false)

    batch_y0 = sol_ode[ :, s];
    
    batch_t = times[1:batch_time];

    batch_y = stack([sol_ode[:, s .+ i] for i in range(1, batch_time)], dims=3);

    return [batch_y0, batch_t, batch_y];

  end

  const neural_net = Lux.Chain(Lux.Dense(2, 50, tanh),
                        Lux.Dense(50, 2))     

  rng = Random.default_rng();
  p, st = Lux.setup(rng, neural_net)

  const _st = st;

  p_init = ComponentArray(p);

  function neural_net_func!(du, u, p, t)
    du .= neural_net(u.^3, p, st)[1];
  end

  prob_nn = ODEProblem(neural_net_func!, [0., 0.], [0., 0.], p);

  function predict(θ, y0s, ts)
    
    _prob = remake( prob_nn, 
                    u0 = y0s, 
                    tspan = (ts[1], ts[end]), 
                    p = θ
                  );
    
    Array(solve(_prob, Tsit5(), saveat = ts,
              abstol = 1e-5, reltol = 1e-5)); 
  end

  function test(θ)
    y0s, ts, targets = get_batch();
    pred = predict(θ, y0s, ts);
    loss = sum(abs2, targets .- pred)
  end

  x, lambda = pullback((θ) -> test(θ), p_init,);

  lambda(x) # give the error
朱莉娅

评论


答:

2赞 max xilian 11/9/2023 #1

标题已经几乎告诉了你问题所在:Zygote,自动微分 (AD) 库,不支持突变数组,但您确实在 RHS 中突变数组。

事实上,许多 AD 库不支持突变。由于 AD 依赖于跟踪方法执行的每个操作,因此从技术上讲,支持突变操作非常具有挑战性。

你的例子并不是那么小,所以我没有运行它。但在我看来,似乎有一个简单的解决方法:如果你看一下 RHS 的定义 ,很容易避免突变:只需以不合适的样式定义它,所以编写一个函数neural_net_func!

neural_net_func(u,p,t) = neural_net(u.^3, p, st)[1];

我怀疑这已经可以解决您的问题。如果您查看错误消息,这也已经是错误消息建议您做的事情!

如果这没有帮助,通常还有其他方法可以解决这个问题:

评论

0赞 Bern 11/9/2023
嗨,感谢您的回复。这条线对我不起作用。当我运行代码时。它告诉我错误在里面,我不知道为什么。neural_net_func(u,p,t) = neural_net(u.^3, p, st)[1];s = sample(range(1, data_size - batch_time), batch_size, replace=false)
0赞 max xilian 11/9/2023
然后进一步减少您的问题,并将功能置于您想要区分的功能之外。如果这可行(使用一些任意输入),那么您可以修复该功能。请仔细阅读错误消息。它在这条线上抱怨什么?get_batchget_batch