Filter Keras image_dataset_from_directory classes

Asked by: basit khan | Asked: 10/24/2023 | Last edited by: Nicolas Gervais | Updated: 10/25/2023 | Views: 27

Q:

I am importing a dataset from Kaggle that has 15 classes, but I only need 10 of them. How can I restrict my dataset to those classes?

This is the code I am trying:


import tensorflow as tf

image_size = 256
batch_size = 8
channels = 3
epochs = 50

dataset = tf.keras.preprocessing.image_dataset_from_directory('/kaggle/input/plant-village/PlantVillage',
                                                              seed=123,
                                                              shuffle=True,
                                                              image_size=(image_size,image_size),
                                                              batch_size=batch_size)
dataset.class_names

The result is:

Found 20638 files belonging to 15 classes.
['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']

I only want these classes:

desired_classes = [
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy'
]
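
To make the goal concrete: with the default label_mode, the integer label of each image is the index of its folder name in dataset.class_names. The following is a minimal sketch, using only the two lists shown above, of which label indices the wanted classes occupy and how they could be renumbered to 0..9. The names keep_indices and old_to_new are illustrative, not part of any library.

```python
# Class names exactly as printed by dataset.class_names above.
class_names = [
    'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy',
    'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy',
    'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight',
    'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]

# The ten classes to keep (same list as in the question).
desired_classes = [
    'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight',
    'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]

# Original integer label of each wanted class (raises ValueError on a typo).
keep_indices = [class_names.index(name) for name in desired_classes]

# Map each original label to a new contiguous label starting at 0.
old_to_new = {old: new for new, old in enumerate(keep_indices)}

print(keep_indices)
print(old_to_new)
```

Any filtering done on the already-built dataset would have to drop samples whose label is outside keep_indices and then renumber the rest with old_to_new, so that a 10-unit output layer lines up with the labels.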

python tensorflow keras tensorflow-datasets kaggle

Comments

0 votes | basit khan | 10/26/2023
@Nicholas It did work. You saved me a lot of time; I had even searched GPT for this but could not find the right solution. Thanks.

A:

0 votes | Nicolas Gervais | 10/25/2023 | #1

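tensorflow.keras.preprocessing.image.ImageDataGenerator may be easier to use here, because it lets you filter the classes directly (desired_classes below is the list from the question):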

from tensorflow.keras.preprocessing.image import (
    DirectoryIterator, ImageDataGenerator
)

directory = r'path/to/image/directory'
batch_size = 10
image_size = 256

# Scale pixel values from [0, 255] to [0, 1].
img_iterator = ImageDataGenerator(
    rescale=1./255.
)

# `classes` restricts loading to the listed subdirectories.
iterator = DirectoryIterator(
    directory=directory,
    image_data_generator=img_iterator,
    classes=desired_classes,
    target_size=(image_size, image_size),
    batch_size=batch_size
)
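
If you would rather keep image_dataset_from_directory, one possible workaround (a different technique from the one above, offered as a sketch) is to stage only the wanted class folders in a writable directory and point the loader there. The helper name stage_selected_classes and the destination path used below are hypothetical. Whether the loader reads symlinked files correctly may depend on your TensorFlow version; passing copy=True sidesteps that at the cost of disk space.

```python
import os
import shutil
from pathlib import Path


def stage_selected_classes(src_root, dst_root, classes, copy=False):
    """Mirror only the chosen class folders from src_root into dst_root.

    Each file is symlinked (or copied when copy=True), so the source tree
    is left untouched. Returns the number of files staged by this call.
    """
    src_root, dst_root = Path(src_root), Path(dst_root)
    staged = 0
    for name in classes:
        src_dir = src_root / name
        if not src_dir.is_dir():
            raise FileNotFoundError(f"class folder not found: {src_dir}")
        dst_dir = dst_root / name
        dst_dir.mkdir(parents=True, exist_ok=True)
        for src_file in sorted(src_dir.iterdir()):
            if not src_file.is_file():
                continue  # this sketch only handles flat class folders
            dst_file = dst_dir / src_file.name
            if dst_file.exists() or dst_file.is_symlink():
                continue  # already staged on an earlier run
            if copy:
                shutil.copy2(src_file, dst_file)
            else:
                os.symlink(src_file.resolve(), dst_file)
            staged += 1
    return staged
```

For example, stage_selected_classes('/kaggle/input/plant-village/PlantVillage', '/kaggle/working/tomato_only', desired_classes) would build the filtered tree (the /kaggle/working/tomato_only path is an assumption), and passing that directory to image_dataset_from_directory should then report only the 10 staged classes.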