如何从 pytorch 中的图像中提取补丁?

How to extract patches from an image in pytorch?

提问人:John Sall 提问时间:5/7/2020 更新时间:5/7/2020 访问量:3595

问:

我想从补丁大小为 128 且步幅为 32 的图像中提取图像补丁,所以我有此代码,但它给了我一个错误:

from PIL import Image 
img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)

x = x.unsqueeze(0)

size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)

我得到的错误是:

RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128

这是我迄今为止找到的唯一方法。但它给了我这个错误

Python 图像处理 pytorch

评论


答:

11赞 Michael Jungo 5/7/2020 #1

your 的大小是 .调用尝试从大小为 3 的维度 1 创建大小为 128 的切片,因此它太小而无法创建任何切片。x[1, 3, height, width]x.unfold(1, size, stride)

您不希望跨维度 1 创建切片,因为这些是图像的通道(在本例中为 RGB),并且需要保持所有色块的通道。仅在图像的高度和宽度上创建色块。

patches = x.unfold(2, size, stride).unfold(3, size, stride)

生成的张量将具有大小。您可以对其进行调整以组合切片以获得补丁列表,即大小:[1, 3, num_vertical_slices, num_horizontal_slices, 128, 128][1, 3, num_patches, 128, 128]

patches = patches.reshape(1, 3, -1, size, size)