PyTorch 用于循环优化和加速技术

PyTorch For Loop Optimisations and Speedup techniques

提问人:Ælex 提问时间:11/16/2023 更新时间:11/16/2023 访问量:20

问:

这是我在过去一年中遇到过三次的问题。

我很欣赏在某些情况下,矢量化解决方案会更好,而且速度更快。

然而,恕我直言,在发现矢量化解决方案和使用本质上是 for 循环(或双 for 循环)之间存在权衡。发现矢量化解决方案(如果确实存在的话)可能需要更多的努力、反复试验、研究等等。

最简单的代码形式(在本例中为双 for 循环)几乎总是最终成为我的瓶颈,但实现和测试所需的时间很少。

下面是一个示例:

@torch.jit.script    
def seq_prob(t_samples: torch.Tensor):
    i = 0
    probs = [0] * len(t_samples)    
    for t_i in torch.unbind(t_samples):
        for t_k in torch.unbind(t_samples):
            is_same = torch.all(torch.isclose(t_i, t_k, rtol=1e-05, atol=1e-08, equal_nan=False))
            if is_same is True:
                probs[i] += 1
        i += 1
    return probs

简单地将外部维度视为可迭代的维度。在某些情况下,我花了相当多的时间来推导循环的矢量化形式,这通常会导致屏蔽、cumsum、index select 和各种内置的 pytorch 方法,与 for 循环相比,使逻辑复杂化,但使其更快。torch.unbind

同样,使用 CUDA 有时会有所帮助(但并非总是如此)。@torch.jit

因此,我的问题是:

  • 在 pytorch 中使用某种形式的 for 循环(例如,或类似的东西)时,其目的是遍历维度并执行操作torch.unbindtorch.chunk
  • 有没有一种方法,一个黄金标准,一些选项,可以加快速度(不包括矢量化)?
  • 如果矢量化是唯一的选择,那么什么是好的第一攻击计划?以上面所示的代码为例,该代码计算给定一定容差的样本集中值的观测值。
Python for 循环 PyTorch 火炬

评论

0赞 Klops 12/5/2023
您的代码中存在错误,火炬布尔值无法以这种方式工作:必须更改为 .您可以使用以下代码行来理解火炬行为:versusif same is Trueif same1 if torch.ones(1, dtype=torch.bool)[0] else 01 if torch.ones(1, dtype=torch.bool)[0] is True else 0

答:

1赞 Klops 12/5/2023 #1

我无法想出一个矢量化版本,但我修复了一个错误并将循环运行减少了 50% 以上。此外,这里不需要,你可以开箱即用地在外部维度上迭代数组(或张量)。torch.unbind

我不能为你提供一个黄金标准,除了:尽可能矢量化,不要在两个方向上进行成对比较,当它们可以避免时。

import torch

t = torch.Tensor([[1, 2, 3], [1, 2, 3], [1, 1, 1]])
torch.all(torch.isclose(t[0], t[0]))


t_samples = t
i = 0
probs = torch.zeros(len(t_samples))
for id_i, t_i in enumerate(t_samples):
    # dont do the same calculation twice, start at id_i + 1
    for id_j, t_k in enumerate(t_samples[id_i+1:], start=id_i+1):
        is_same = torch.all(
            torch.isclose(t_i, t_k, rtol=1e-05, atol=1e-08, equal_nan=False)
        )
        if is_same:  # you had an error here, torch booleans dont work your way
            probs[id_i] += 1  # compare A to B
            probs[id_j] += 1  # compare B to A

评论

0赞 Ælex 12/6/2023
感谢您的错误修复。通常比在张量上切片和循环更快(根据 pytorch 论坛)。torch.unbind
1赞 Klops 12/5/2023 #2

除了我另一个答案中的优化版本外,这是针对您的问题的矢量化版本:

import torch

# sample data with a duplicate at index 0 and 1
t = torch.Tensor([[1, 2, 3], [1, 2, 3], [1, 1, 1]])

# indices of all unique pairwise comparisons (triu=upper triangle). Offset 1 since we don't need to compare A to A
indices1, indices2 = torch.triu_indices(row=len(t), col=len(t), offset=1)

# check for each unique pair (e.g. treat A<->B same as B<->A) if values are the same (e.g. difference is zero)
matches = torch.isclose(
    (t[indices1] - t[indices2]).abs().sum(axis=1),
    torch.zeros(len(t)),
    rtol=1e-05, atol=1e-08
)

# get the indices for cases where differences were zero (matches!)
matched_indices = torch.cat([indices1[:, None], indices2[:, None]], 1)[matches]

# container for probs
probs = torch.zeros(len(t))

# increase the probs for those indices that have duplicates
probs[matched_indices] += 1

注意:考虑到矩阵将大致包含 (n**2) / 2 个元素,并且可能会爆炸为巨大的 n。

评论

0赞 Ælex 12/6/2023
这很好,是的,我认为这要好得多。我最终使用并生成了相同的索引,但我怀疑这要快得多,因为它会导致对底层 cuda 代码的一次调用,而必须使用多次调用。是的,内存是一个问题,在某些情况下,我不得不在非常大的嵌入上运行非常大的余弦相似性,最终总是出现 OOM。但这效果非常好,我会牢记这一点以备将来参考。torch.rolltorch.triu_indicestriu_indicesroll