Rust 并行获取对 ndarray 的每个元素的可变引用

Rust get mutable reference to each element of an ndarray in parallel

提问人:blomp 提问时间:7/18/2023 最后编辑:tadmanblomp 更新时间:7/19/2023 访问量:44

问:

我正在 Rust 中处理并行矩阵乘法代码,我想在其中并行计算乘积的每个元素。我使用 s 来存储我的数据。因此,我的代码将单独是行ndarray

fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
   let N = lhs.raw_size()[0];
   let M = rhs.raw_size()[1];
   let mut result = Array2::zeros((N,M));
   
   range_2d(0..N,0..M).par_iter().map(|(i, j)| {
      // load the result for the (i,j) element into 'result'
   }).count();

   result
}

有什么办法可以做到这一点吗?

不安全 人造丝 锈蚀-ND阵列

评论

0赞 tadman 7/18/2023
A) 这个矩阵有多大?B) 你的线程策略是什么?C) 为什么不使用现有的数学库?D) 您使用的是 SIMD 吗?E) 使用 GPU 会不会更好?这是一种选择吗?
0赞 blomp 7/18/2023
A) rhs 将是一个具有稀疏行的矩阵,可能有大约 1000 个条目。LHS将是一个密集的随机高斯矩阵,可能非常大。B) 我不确定你的意思 C) 部分是为了练习,也因为用例非常专业,我找不到合适的库 E) 是的,使用 GPU 肯定会有所帮助
0赞 tadman 7/18/2023
矩阵乘法是一个非常完善的领域,甚至还有适合您特定用例的板条箱。如果你出于学术原因想自己做,这没有错,但值得研究别人的工作。
1赞 Chayim Friedman 7/18/2023
ndarray确实支持,但我没有看到对提取索引的支持。rayon
1赞 blomp 7/18/2023
“par_iter”是 Rayon 的迭代器,是的。不过,我主要关心的是获取所有可变的引用

答:

0赞 Chayim Friedman 7/18/2023 #1

您可以通过以下方式创建并行迭代器:

use rayon::prelude::*;

pub fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    let n = lhs.raw_dim()[0];
    let m = rhs.raw_dim()[1];
    let mut result = Array2::zeros((n, m));

    result
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .flat_map(|(n, axis)| {
            axis.into_slice()
                .unwrap()
                .par_iter_mut()
                .enumerate()
                .map(move |(m, item)| (n, m, item))
        })
        .for_each(|(n, m, item)| {
            // Do the multiplication.
            *item = n as f32 * m as f32;
        });

    result
}

评论

0赞 blomp 7/19/2023
据我所知par_bridge效率真的很低。从本质上讲,我认为瓶颈是 Array2,因为将其切成可变块是可能的,但只是一种痛苦。在提出这个问题时,我虽然可能有一种方法可以生成指向每个项目的原始指针,就像split_at_mut如何做到这一点一样,但似乎并非如此。
0赞 Chayim Friedman 7/19/2023
@stomfaig 经过进一步的思考,我得出了没有 .编辑了上面的答案。par_bridge()