提问人:John Sall 提问时间:5/7/2020 更新时间:5/7/2020 访问量:3595
如何从 pytorch 中的图像中提取补丁?
How to extract patches from an image in pytorch?
问:
我想从补丁大小为 128 且步幅为 32 的图像中提取图像补丁,所以我有此代码,但它给了我一个错误:
from PIL import Image
img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)
x = x.unsqueeze(0)
size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)
我得到的错误是:
RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128
这是我迄今为止找到的唯一方法。但它给了我这个错误
答:
11赞
Michael Jungo
5/7/2020
#1
your 的大小是 .调用尝试从大小为 3 的维度 1 创建大小为 128 的切片,因此它太小而无法创建任何切片。x
[1, 3, height, width]
x.unfold(1, size, stride)
您不希望跨维度 1 创建切片,因为这些是图像的通道(在本例中为 RGB),并且需要保持所有色块的通道。仅在图像的高度和宽度上创建色块。
patches = x.unfold(2, size, stride).unfold(3, size, stride)
生成的张量将具有大小。您可以对其进行调整以组合切片以获得补丁列表,即大小:[1, 3, num_vertical_slices, num_horizontal_slices, 128, 128]
[1, 3, num_patches, 128, 128]
patches = patches.reshape(1, 3, -1, size, size)
评论