如何在给定索引下将一个 pytorch 张量的元素复制到另一个张量中,而无需中间分配或循环

How to copy elements of one pytorch tensor at given indices into another tensor without intermediate allocation or looping

提问人:yuri kilochek 提问时间:8/5/2023 最后编辑:yuri kilochek 更新时间:8/5/2023 访问量:21

问:

鉴于

import torch

a: torch.Tensor
b: torch.Tensor
assert a.shape[1:] == b.shape[1:]
idx = torch.randint(b.shape[0], [a.shape[0]])

我想做

b[...] = a[idx]

但是没有中间缓冲区产生或循环。我该怎么做?a[idx]idx

python pytorch 切片 张量

评论


答:

0赞 yuri kilochek 8/5/2023 #1

您可以使用torch.index_select

torch.index_select(a, 0, idx, out = b)