提问人:Bern 提问时间:11/9/2023 最后编辑:Bern 更新时间:11/9/2023 访问量:65
Zygote 不支持突变数组
Mutating arrays is not supported in Zygote
问:
我正在尝试使用科学机器学习框架从 Julia 中的示例中编写螺旋示例
但是,当代码通过函数进行区分时,我收到错误:get_batch
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Int64}, ...) This error occurs when you ask Zygote to differentiate operations that change the elements of arrays in place (e.g. setting values with x .= ...)
下面是一个最小的可重现示例:
using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays, Random, StatsBase, MLUtils
using Zygote
true_y0 = [2., 0.];
true_A = [-0.1 2.; -2. -0.1];
data_size = 1000;
times = LinRange(0, 25, data_size);
function ground_truth!(du, u, p, t)
du .= true_A * (u.^3)
end
ground_truth_odeProb = ODEProblem(ground_truth!, true_y0, (0, times[end]));
sol_ode = Array(
solve(ground_truth_odeProb,
Tsit5(),
abstol = 1e-10, reltol = 1e-10,
saveat = times)
);
batch_time = 10
batch_size = 20
function get_batch()
s = sample(range(1, data_size - batch_time), batch_size, replace=false)
batch_y0 = sol_ode[ :, s];
batch_t = times[1:batch_time];
batch_y = stack([sol_ode[:, s .+ i] for i in range(1, batch_time)], dims=3);
return [batch_y0, batch_t, batch_y];
end
const neural_net = Lux.Chain(Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2))
rng = Random.default_rng();
p, st = Lux.setup(rng, neural_net)
const _st = st;
p_init = ComponentArray(p);
function neural_net_func!(du, u, p, t)
du .= neural_net(u.^3, p, st)[1];
end
prob_nn = ODEProblem(neural_net_func!, [0., 0.], [0., 0.], p);
function predict(θ, y0s, ts)
_prob = remake( prob_nn,
u0 = y0s,
tspan = (ts[1], ts[end]),
p = θ
);
Array(solve(_prob, Tsit5(), saveat = ts,
abstol = 1e-5, reltol = 1e-5));
end
function test(θ)
y0s, ts, targets = get_batch();
pred = predict(θ, y0s, ts);
loss = sum(abs2, targets .- pred)
end
x, lambda = pullback((θ) -> test(θ), p_init,);
lambda(x) # give the error
答:
2赞
max xilian
11/9/2023
#1
标题已经几乎告诉了你问题所在:Zygote,自动微分 (AD) 库,不支持突变数组,但您确实在 RHS 中突变数组。
事实上,许多 AD 库不支持突变。由于 AD 依赖于跟踪方法执行的每个操作,因此从技术上讲,支持突变操作非常具有挑战性。
你的例子并不是那么小,所以我没有运行它。但在我看来,似乎有一个简单的解决方法:如果你看一下 RHS 的定义 ,很容易避免突变:只需以不合适的样式定义它,所以编写一个函数neural_net_func!
neural_net_func(u,p,t) = neural_net(u.^3, p, st)[1];
我怀疑这已经可以解决您的问题。如果您查看错误消息,这也已经是错误消息建议您做的事情!
如果这没有帮助,通常还有其他方法可以解决这个问题:
- 使用 Enzyme 作为 AD 文库:Enzyme 是为数不多的真正支持突变的 AD 文库之一。为了使用它,您需要在命令中告诉。有关详细信息,请参阅文档:https://docs.sciml.ai/SciMLSensitivity/stable/
SciMLSensitivity
solve
- 当无法避免突变时定义自定义伴随,但您想继续使用 Zygote:ChainRules.jl 是该 https://github.com/JuliaDiff/ChainRules.jl 的库
评论
0赞
Bern
11/9/2023
嗨,感谢您的回复。这条线对我不起作用。当我运行代码时。它告诉我错误在里面,我不知道为什么。neural_net_func(u,p,t) = neural_net(u.^3, p, st)[1];
s = sample(range(1, data_size - batch_time), batch_size, replace=false)
0赞
max xilian
11/9/2023
然后进一步减少您的问题,并将功能置于您想要区分的功能之外。如果这可行(使用一些任意输入),那么您可以修复该功能。请仔细阅读错误消息。它在这条线上抱怨什么?get_batch
get_batch
评论