提问人:MWB 提问时间:10/3/2023 最后编辑:MWB 更新时间:10/7/2023 访问量:139
鉴于 PyTorch 不允许多个可变引用,类似 PyTorch 的自动微分如何在 Rust 中工作?
How can PyTorch-like automatic differentiation work in Rust, given that it does not allow multiple mutable references?
问:
我主要是一个局外人,试图了解 Rust 是否适合我的项目。
在 Rust 中有一些框架可以进行自动区分。具体来说,根据他们的描述,我认为 candle 和其他一些项目以某种方式以类似于 PyTorch 的方式做到这一点。
但是,我知道 Rust 不允许多个可变引用。这似乎是类似 PyTorch 的自动区分所需要的:
x = torch.rand(10) # an array of 10 elements
x.requires_grad = True
y = x.sin()
z = x**2
和 都必须保留对 的可变引用,因为您可能希望反向传播它们,这将修改 .例如:y
z
x
x.grad
(y.dot(z)).backwards()
print(x.grad) # .backwards() adds a new field (an array) to x, without modifying it otherwise
那么,鉴于 Rust 不允许多个可变引用,如何在 Rust 中实现类似的行为呢?
答:
在 Rust 中提供看似多个可变引用的方法是通过内部可变性,它允许通过共享引用进行突变。Rust 仍然有一些要求不允许同时发生突变,但有几种方法可以确保这一点,因此有几种类型通常提供内部可变性:Cell、RefCell
、Mutex
、RwLock
。 它们以
UnsafeCell
为基础构建,作为核心原语,告诉编译器这并不一定意味着所包含值不可变。&
如果我们看一下蜡烛的源,基本面包含一个弧
,它允许多个张量“句柄”引用相同的值 - 提供共享所有权(源):Tensor
Arc
pub struct Tensor(Arc<Tensor_>);
隐藏的内部类型如下所示(源代码):Tensor_
pub struct Tensor_ {
id: TensorId,
// As we provide inner mutability on the tensor content, the alternatives are:
// - Using a mutex, this would have the highest cost when retrieving the storage but would
// prevent errors when concurrent access takes place. Mutex would also be subject to
// deadlocks for example using the current code if the same tensor is used twice by a single
// binary op.
// - Using a refcell unsafe cell would have some intermediary cost, borrow checking would be
// verified dynamically, but the resulting tensors would not be send or sync.
// - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
// accesses.
// Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
// and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
// that's tricky to encode in the current setup.
storage: Arc<RwLock<Storage>>,
layout: Layout,
op: BackpropOp,
is_variable: bool,
dtype: DType,
device: Device,
}
这方便地有一个评论权衡了内部可变性的选项。是允许多个句柄访问和/或改变张量内容的部分。storage
RwLock
因此,当反向传播发生时,它通过获取 a 来访问相关张量,以便访问数据以执行操作,然后在对生成的张量执行任何操作之前释放这些守卫,以避免死锁(因为如果在持有现有守卫时尝试突变,则会阻塞)。storage
RwLockReadGuard
RwLock
该库似乎没有利用这种内部可变性,因为它更喜欢创建新的张量而不是改变现有的张量,除非它是一个应该更新以反映新数据的变量。在这种情况下,它获取 a 以用新值交换数据,并再次快速释放保护。RwLockWriteGuard
很难在蜡烛的来源中给出确切的线条,因为有许多层用于反向传播、存储操作和跟踪结果。我也无法用公式进行具体演示,因为我不太精通这个主题。但我希望这一点是清楚的,可以帮助你进行自己的冒险。
评论
你说得对,rust 编译器强制要求一次只能有一个对一个值的可变引用,但有一个转义舱口:内部可变性模式。
此模式允许程序员构造数据结构,在运行时而不是在编译时检查规则。
标准库提供了许多实现内部可变性的容器,具有不同的使用模式,适用于不同的场景。主要示例包括:
RefCell<T>
,允许对单线程使用情况进行运行时借用检查RwLock<T>
,允许对多线程使用情况进行运行时借用检查Mutex<T>
,一次只允许对其内容进行一次引用
这如何适用于蜡烛?让我们来看看引擎盖下:
pub struct Tensor_ {
...
storage: Arc<RwLock<Storage>>,
...
支持张量的存储内容受 .事实上,紧挨着这上面的代码中有一些注释,描述了选择这个特定解决方案的原因 - 值得一读。RwLock
不仅如此,这反过来又包含在 Arc<T>
中 - 这意味着它实际上是一个堆分配的引用计数值。此值可以有多个“所有者”,并且只有在最后一个所有者超出范围时才会解除分配。
在反向传播的情况下如何使用它?好吧,backward()
方法不会直接修改张量,而是返回一个包含计算梯度的 GradStore
。A 可能反过来被优化器
使用。 是一个特征,具有几种不同的实现,因此让我们看一下 SGD
优化器:Tensor
GradStore
Optimizer
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
for var in self.vars.iter() {
if let Some(grad) = grads.get(var) {
var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
}
}
Ok(())
}
好的,所以这里的梯度被应用于某些实例 - 这些(在这里定义)是什么?Var
pub struct Var(Tensor);
好的,一个包装器。set
方法如何完成它的工作?这一行是关键:Tensor
let (mut dst, layout) = self.storage_mut_and_layout();
这给了我们一个可变变量,它似乎代表了集合操作的目的地。这个 storage_mut_and_layout()
方法有什么作用?
let storage = self.storage.write().unwrap();
啊哈!它调用我们上面看到的 write()
方法,存储位于其中。此方法的文档说:RwLock
使用独占写入访问权限锁定此 RwLock,阻塞电流 线程,直到可以获取它。
当其他编写器或其他读取器时,此函数不会返回 当前可以访问该锁。
所以总而言之:
- 该方法本身似乎没有修改输入,但它返回了一个包含梯度的数据结构
backward()
Tensor
- 渐变将使用 .
Tensor
Optimizer
- 使用该方法来更改 ,在后台,它使用保护它的方法对 的数据存储进行可变访问。
Optimizer
set
Tensor
Tensor
write()
RwLock
评论
的 RefCell
。或者,您可能一开始就不让变量成为正确的堆栈变量,而总是必须通过或类似方式引用它们。some_manager_object.set_variable("a", 42f32)
RefCell
RefCell
实际上不是这样,它应该被应用于使你的变量可变,但它们之间的关系必须单独表达。这种分为 rust-unofficial.github.io/too-many-lists 类。你可以用 s 来做到这一点,但有很多陷阱。在 Rust 中处理图形的大多数理智方法都归结为拥有它,无论是作为 petgraph 还是 bump 分配器 - 这使得循环引用成为可能。Rc
some_manager_object
Dict
HashMap