提问人:puddles 提问时间:9/5/2023 最后编辑:puddles 更新时间:9/5/2023 访问量:12
pytorch 3D 张量切片与列表仅在列表中有一个值或一个 True 值时才有效
pytorch 3d tensor slicing with list works only if there is one value or one True value in list
问:
我有一个 3d 张量,我想使用 .是否要屏蔽某行取决于布尔列表,其 True 值对应于要屏蔽的行。当列表中有一个 True 时,下面的代码按预期工作,但当有两个 True 时会给出一个 IndexError,当有三个 True 时会给出意外的输出。有人可以指出一个解决方案或解释两个和三个 True 的行为吗?row_mask
x = torch.arange(24).reshape(4,1,6).float()
row_mask = [1,2,5]
app_mask = [False,True,False,False] # apply row_mask in 1st row only
x[app_mask,-1,row_mask] = -float('inf')
print(x[:,-1,:])
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., -inf, -inf, 9., 10., -inf],
[12., 13., 14., 15., 16., 17.],
[18., 19., 20., 21., 22., 23.]])
对于两个 True,
x = torch.arange(24).reshape(4,1,6).float()
row_mask = [1,2,5]
app_mask2 = [False,True,True,False] # apply row_mask in 1st and 2nd rows
x[app_mask2,-1,row_mask] = -float('inf')
我明白了:
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [2], [3]
对于三个 True,
x = torch.arange(24).reshape(4,1,6).float()
row_mask = [1,2,5]
app_mask3 = [False,True,True,True] # apply row_mask to 1st, 2nd, and 3rd rows
x[app_mask3,-1,row_mask] = -float('inf')
我得到了这个意想不到的结果,它改变了第 1、2 和 3 行,但不是我预期的方式。
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., -inf, 8., 9., 10., 11.],
[12., 13., -inf, 15., 16., 17.],
[18., 19., 20., 21., 22., -inf]])
除了布尔值之外,还可以重新分配第 1 行,但不能重新分配第 0 行和第 1 行。那么,对 3d 张量的多行进行索引是什么问题,我该如何解决它呢?x[[1],-1,row_mask] = -float('inf')
x[[0,1],-1,row_mask] = -float('inf')
答: 暂无答案
评论