鉴于 PyTorch 不允许多个可变引用,类似 PyTorch 的自动微分如何在 Rust 中工作?

How can PyTorch-like automatic differentiation work in Rust, given that it does not allow multiple mutable references?

提问人:MWB 提问时间:10/3/2023 最后编辑:MWB 更新时间:10/7/2023 访问量:139

问:

我主要是一个局外人,试图了解 Rust 是否适合我的项目。

在 Rust 中有一些框架可以进行自动区分。具体来说,根据他们的描述,我认为 candle 和其他一些项目以某种方式以类似于 PyTorch 的方式做到这一点。

但是,我知道 Rust 不允许多个可变引用。这似乎是类似 PyTorch 的自动区分所需要的:

x = torch.rand(10) # an array of 10 elements
x.requires_grad = True

y = x.sin()
z = x**2

都必须保留对 的可变引用,因为您可能希望反向传播它们,这将修改 .例如:yzxx.grad

(y.dot(z)).backwards()
print(x.grad) # .backwards() adds a new field (an array) to x, without modifying it otherwise

那么,鉴于 Rust 不允许多个可变引用,如何在 Rust 中实现类似的行为呢?

源反向传播 自动微分

评论

3赞 Caesar 10/3/2023
显而易见的方法是一个内部的 RefCell。或者,您可能一开始就不让变量成为正确的堆栈变量,而总是必须通过或类似方式引用它们。some_manager_object.set_variable("a", 42f32)
0赞 MWB 10/3/2023
@Caesar谢谢。我想知道是否适用于在任意有向无环图中向后行走。似乎如果没有一些额外的机制,你最终可能会对同一个对象进行多次可变借用。RefCell
0赞 Caesar 10/4/2023
RefCell实际上不是这样,它应该被应用于使你的变量可变,但它们之间的关系必须单独表达。这种分为 rust-unofficial.github.io/too-many-lists 类。你可以用 s 来做到这一点,但有很多陷阱。在 Rust 中处理图形的大多数理智方法都归结为拥有它,无论是作为 petgraph 还是 bump 分配器 - 这使得循环引用成为可能。Rcsome_manager_object
1赞 MWB 10/4/2023
@Caesar “使循环引用成为可能” 这里的计算图本质上是非循环的。新变量需要引用计算它们的变量,但反之则不然。
1赞 Caesar 10/5/2023
@MateenUlhaq 你的(ii)本质上是petgraph所做的。(只是不带字符串。哦,开销。;)) 但它也比使用 dumpalo 更不符合人体工程学:您始终需要完成它(或),并且您需要手动管理哪些条目是实时的或管理 GC 根。DictHashMap

答:

2赞 kmdreko 10/7/2023 #1

在 Rust 中提供看似多个可变引用的方法是通过内部可变性,它允许通过共享引用进行突变。Rust 仍然有一些要求不允许同时发生突变,但有几种方法可以确保这一点,因此有几种类型通常提供内部可变性:Cell、RefCellMutexRwLock它们以 UnsafeCell 为基础构建,作为核心原语,告诉编译器这并不一定意味着所包含值不可变。&

如果我们看一下蜡烛的源,基本面包含一个,它允许多个张量“句柄”引用相同的 - 提供共享所有权():TensorArc

pub struct Tensor(Arc<Tensor_>);

隐藏的内部类型如下所示(源代码):Tensor_

pub struct Tensor_ {
    id: TensorId,
    // As we provide inner mutability on the tensor content, the alternatives are:
    // - Using a mutex, this would have the highest cost when retrieving the storage but would
    //   prevent errors when concurrent access takes place. Mutex would also be subject to
    //   deadlocks for example using the current code if the same tensor is used twice by a single
    //   binary op.
    // - Using a refcell unsafe cell would have some intermediary cost, borrow checking would be
    //   verified dynamically, but the resulting tensors would not be send or sync.
    // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
    //   accesses.
    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
    // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
    // that's tricky to encode in the current setup.
    storage: Arc<RwLock<Storage>>,
    layout: Layout,
    op: BackpropOp,
    is_variable: bool,
    dtype: DType,
    device: Device,
}

这方便地有一个评论权衡了内部可变性的选项。是允许多个句柄访问和/或改变张量内容的部分。storageRwLock

因此,当反向传播发生时,它通过获取 a 来访问相关张量,以便访问数据以执行操作,然后在对生成的张量执行任何操作之前释放这些守卫,以避免死锁(因为如果在持有现有守卫时尝试突变,则会阻塞)。storageRwLockReadGuardRwLock

该库似乎没有利用这种内部可变性,因为它更喜欢创建新的张量而不是改变现有的张量,除非它是一个应该更新以反映新数据的变量。在这种情况下,它获取 a 以用新值交换数据,并再次快速释放保护。RwLockWriteGuard

很难在蜡烛的来源中给出确切的线条,因为有许多层用于反向传播、存储操作和跟踪结果。我也无法用公式进行具体演示,因为我不太精通这个主题。但我希望这一点是清楚的,可以帮助你进行自己的冒险。

评论

0赞 harmic 10/7/2023
打败我 3 小时!那会教我输入答案,然后去吃午饭!
2赞 harmic 10/7/2023 #2

你说得对,rust 编译器强制要求一次只能有一个对一个值的可变引用,但有一个转义舱口:内部可变性模式。

此模式允许程序员构造数据结构,在运行时而不是在编译时检查规则。

标准库提供了许多实现内部可变性的容器,具有不同的使用模式,适用于不同的场景。主要示例包括:

  • RefCell<T>,允许对单线程使用情况进行运行时借用检查

  • RwLock<T>,允许对多线程使用情况进行运行时借用检查

  • Mutex<T>,一次只允许对其内容进行一次引用

还有其他的 - 请参阅单元同步的模块级文档。

这如何适用于蜡烛?让我们来看看引擎盖下

pub struct Tensor_ {
    ...
    storage: Arc<RwLock<Storage>>,
    ...

支持张量的存储内容受 .事实上,紧挨着这上面的代码中有一些注释,描述了选择这个特定解决方案的原因 - 值得一读。RwLock

不仅如此,这反过来又包含在 Arc<T> 中 - 这意味着它实际上是一个堆分配的引用计数值。此值可以有多个“所有者”,并且只有在最后一个所有者超出范围时才会解除分配。

在反向传播的情况下如何使用它?好吧,backward() 方法不会直接修改张量,而是返回一个包含计算梯度的 GradStore。A 可能反过来被优化器使用。 是一个特征,具有几种不同的实现,因此让我们看一下 SGD 优化器:TensorGradStoreOptimizer

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
            }
        }
        Ok(())
    }

好的,所以这里的梯度被应用于某些实例 - 这些(在这里定义)是什么?Var

pub struct Var(Tensor);

好的,一个包装器。set 方法如何完成它的工作?这一行是关键:Tensor

let (mut dst, layout) = self.storage_mut_and_layout();

这给了我们一个可变变量,它似乎代表了集合操作的目的地。这个 storage_mut_and_layout() 方法有什么作用?

let storage = self.storage.write().unwrap();

啊哈!它调用我们上面看到的 write() 方法,存储位于其中。此方法的文档说:RwLock

使用独占写入访问权限锁定此 RwLock,阻塞电流 线程,直到可以获取它。

当其他编写器或其他读取器时,此函数不会返回 当前可以访问该锁。

所以总而言之:

  • 该方法本身似乎没有修改输入,但它返回了一个包含梯度的数据结构backward()Tensor
  • 渐变将使用 .TensorOptimizer
  • 使用该方法来更改 ,在后台,它使用保护它的方法对 的数据存储进行可变访问。OptimizersetTensorTensorwrite()RwLock