提问人:Ælex 提问时间:11/16/2023 更新时间:11/16/2023 访问量:20
PyTorch 用于循环优化和加速技术
PyTorch For Loop Optimisations and Speedup techniques
问:
这是我在过去一年中遇到过三次的问题。
我很欣赏在某些情况下,矢量化解决方案会更好,而且速度更快。
然而,恕我直言,在发现矢量化解决方案和使用本质上是 for 循环(或双 for 循环)之间存在权衡。发现矢量化解决方案(如果确实存在的话)可能需要更多的努力、反复试验、研究等等。
最简单的代码形式(在本例中为双 for 循环)几乎总是最终成为我的瓶颈,但实现和测试所需的时间很少。
下面是一个示例:
@torch.jit.script
def seq_prob(t_samples: torch.Tensor):
i = 0
probs = [0] * len(t_samples)
for t_i in torch.unbind(t_samples):
for t_k in torch.unbind(t_samples):
is_same = torch.all(torch.isclose(t_i, t_k, rtol=1e-05, atol=1e-08, equal_nan=False))
if is_same is True:
probs[i] += 1
i += 1
return probs
简单地将外部维度视为可迭代的维度。在某些情况下,我花了相当多的时间来推导循环的矢量化形式,这通常会导致屏蔽、cumsum、index select 和各种内置的 pytorch 方法,与 for 循环相比,使逻辑复杂化,但使其更快。torch.unbind
同样,使用 CUDA 有时会有所帮助(但并非总是如此)。@torch.jit
因此,我的问题是:
- 在 pytorch 中使用某种形式的 for 循环(例如,或类似的东西)时,其目的是遍历维度并执行操作
torch.unbind
torch.chunk
- 有没有一种方法,一个黄金标准,一些选项,可以加快速度(不包括矢量化)?
- 如果矢量化是唯一的选择,那么什么是好的第一攻击计划?以上面所示的代码为例,该代码计算给定一定容差的样本集中值的观测值。
答:
1赞
Klops
12/5/2023
#1
我无法想出一个矢量化版本,但我修复了一个错误并将循环运行减少了 50% 以上。此外,这里不需要,你可以开箱即用地在外部维度上迭代数组(或张量)。torch.unbind
我不能为你提供一个黄金标准,除了:尽可能矢量化,不要在两个方向上进行成对比较,当它们可以避免时。
import torch
t = torch.Tensor([[1, 2, 3], [1, 2, 3], [1, 1, 1]])
torch.all(torch.isclose(t[0], t[0]))
t_samples = t
i = 0
probs = torch.zeros(len(t_samples))
for id_i, t_i in enumerate(t_samples):
# dont do the same calculation twice, start at id_i + 1
for id_j, t_k in enumerate(t_samples[id_i+1:], start=id_i+1):
is_same = torch.all(
torch.isclose(t_i, t_k, rtol=1e-05, atol=1e-08, equal_nan=False)
)
if is_same: # you had an error here, torch booleans dont work your way
probs[id_i] += 1 # compare A to B
probs[id_j] += 1 # compare B to A
评论
0赞
Ælex
12/6/2023
感谢您的错误修复。通常比在张量上切片和循环更快(根据 pytorch 论坛)。torch.unbind
1赞
Klops
12/5/2023
#2
除了我另一个答案中的优化版本外,这是针对您的问题的矢量化版本:
import torch
# sample data with a duplicate at index 0 and 1
t = torch.Tensor([[1, 2, 3], [1, 2, 3], [1, 1, 1]])
# indices of all unique pairwise comparisons (triu=upper triangle). Offset 1 since we don't need to compare A to A
indices1, indices2 = torch.triu_indices(row=len(t), col=len(t), offset=1)
# check for each unique pair (e.g. treat A<->B same as B<->A) if values are the same (e.g. difference is zero)
matches = torch.isclose(
(t[indices1] - t[indices2]).abs().sum(axis=1),
torch.zeros(len(t)),
rtol=1e-05, atol=1e-08
)
# get the indices for cases where differences were zero (matches!)
matched_indices = torch.cat([indices1[:, None], indices2[:, None]], 1)[matches]
# container for probs
probs = torch.zeros(len(t))
# increase the probs for those indices that have duplicates
probs[matched_indices] += 1
注意:考虑到矩阵将大致包含 (n**2) / 2 个元素,并且可能会爆炸为巨大的 n。
评论
0赞
Ælex
12/6/2023
这很好,是的,我认为这要好得多。我最终使用并生成了相同的索引,但我怀疑这要快得多,因为它会导致对底层 cuda 代码的一次调用,而必须使用多次调用。是的,内存是一个问题,在某些情况下,我不得不在非常大的嵌入上运行非常大的余弦相似性,最终总是出现 OOM。但这效果非常好,我会牢记这一点以备将来参考。torch.roll
torch.triu_indices
triu_indices
roll
评论
if same is True
if same
1 if torch.ones(1, dtype=torch.bool)[0] else 0
1 if torch.ones(1, dtype=torch.bool)[0] is True else 0