pytorch 3D 张量切片与列表仅在列表中有一个值或一个 True 值时才有效

pytorch 3d tensor slicing with list works only if there is one value or one True value in list

提问人:puddles 提问时间:9/5/2023 最后编辑:puddles 更新时间:9/5/2023 访问量:12

问:

我有一个 3d 张量,我想使用 .是否要屏蔽某行取决于布尔列表,其 True 值对应于要屏蔽的行。当列表中有一个 True 时,下面的代码按预期工作,但当有两个 True 时会给出一个 IndexError,当有三个 True 时会给出意外的输出。有人可以指出一个解决方案或解释两个和三个 True 的行为吗?row_mask

x = torch.arange(24).reshape(4,1,6).float()
row_mask = [1,2,5]
app_mask = [False,True,False,False]    # apply row_mask in 1st row only
x[app_mask,-1,row_mask] = -float('inf')
print(x[:,-1,:])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6., -inf, -inf,  9., 10., -inf],
        [12., 13., 14., 15., 16., 17.],
        [18., 19., 20., 21., 22., 23.]])

对于两个 True,

x = torch.arange(24).reshape(4,1,6).float()
row_mask = [1,2,5]
app_mask2 = [False,True,True,False]    # apply row_mask in 1st and 2nd rows
x[app_mask2,-1,row_mask] = -float('inf')

我明白了:

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [2], [3]

对于三个 True,

x = torch.arange(24).reshape(4,1,6).float()
row_mask = [1,2,5]
app_mask3 = [False,True,True,True]    # apply row_mask to 1st, 2nd, and 3rd rows
x[app_mask3,-1,row_mask] = -float('inf')

我得到了这个意想不到的结果,它改变了第 1、2 和 3 行,但不是我预期的方式。

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6., -inf,  8.,  9., 10., 11.],
        [12., 13., -inf, 15., 16., 17.],
        [18., 19., 20., 21., 22., -inf]])

除了布尔值之外,还可以重新分配第 1 行,但不能重新分配第 0 行和第 1 行。那么,对 3d 张量的多行进行索引是什么问题,我该如何解决它呢?x[[1],-1,row_mask] = -float('inf')x[[0,1],-1,row_mask] = -float('inf')

PyTorch 切片 张量 掩码

评论


答: 暂无答案