Asked by: Syuuuu  Asked: 8/8/2023  Last edited by: Syuuuu  Updated: 8/8/2023  Views: 27
How to get confusion_matrix y_true in Customdatagenerator
Q:
I want to build a confusion_matrix, but I always get this error:
  File "C:\Labbb\inceptionResnetV2\InceptionResnetV2_1.py", line 216, in <module>
    sns.heatmap(confusion_matrix(y_true, y_pred),
ValueError: Found input variables with inconsistent numbers of samples: [0, 62]
How can I get y_true from the CustomDataGenerator?
I tried appending to y_true in get_data and returning it with a get_y_true method, but it doesn't work.
Here is the CustomDataGenerator code:
import os

import numpy as np
from tensorflow.keras.utils import Sequence, to_categorical  # assumed: imports were not shown in the question


class CustomDataGenerator(Sequence):
    def __init__(self, image_folders, label_folders, dir, dim=(512,512), batch_size=1,n_classes=7,n_channels=8,shuffle=True):
        self.image_folders = image_folders
        ...
        self.image_paths = []
        self.label_paths = []
        self.y_true = []
        self.on_epoch_end()

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        batch_image_paths = self.image_paths[index * self.batch_size: (index + 1) * self.batch_size]
        batch_label_paths = self.label_paths[index * self.batch_size: (index + 1) * self.batch_size]
        batch = zip(batch_image_paths, batch_label_paths)
        return self.get_data(batch)

    def on_epoch_end(self):
        # rebuild the image and label path lists from the folders
        self.image_paths = []
        self.label_paths = []
        for folder in self.image_folders:
            image_folder_path = os.path.join(self.dir, folder)
            image_files = os.listdir(image_folder_path)
            for file_name in image_files:
                self.image_paths.append(os.path.join(image_folder_path, file_name))
        for folder in self.label_folders:
            label_folder_path = os.path.join(self.dir, folder)
            label_files = os.listdir(label_folder_path)
            for file_name in label_files:
                self.label_paths.append(os.path.join(label_folder_path, file_name))
        if self.shuffle:
            np.random.shuffle(self.image_paths)
            np.random.shuffle(self.label_paths)

    def get_data(self, batch):
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, self.n_classes))
        y_true = []
        for i, (image_path, label_path) in enumerate(batch):
            image = np.load(image_path)
            # each label file holds one line: "<filepath> <label>"
            with open(label_path, 'r') as f:
                line = f.readline().strip()
            filepath, label = line.rsplit(' ', 1)
            label = int(label)
            y_true.append(label)
            label_one_hot = to_categorical(label, num_classes=self.n_classes)
            X[i,] = image
            y[i,] = label_one_hot
        return X, y

    def get_y_true(self):
        return self.y_true
This is how I get y_true and y_pred and build the confusion_matrix:
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix

train_datagen = CustomDataGenerator(image_folders, label_folders, train_dir, **params, shuffle=True)
val_datagen = CustomDataGenerator(image_folders, label_folders, valid_dir, **params, shuffle=True)
y_true = CustomDataGenerator.get_y_true(val_datagen)
Y_pred = model.predict(val_datagen)
y_pred = np.argmax(Y_pred, axis=1)
sns.heatmap(confusion_matrix(y_true, y_pred), annot=True, fmt="d", cmap='Greens', ax=ax)
A:
I'd like to comment on a few points.
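As for your original question, `y_true` is empty: in the class, `self.y_true = []` is set in `__init__()` and never filled. In `get_data(..)`, `y_true` is a local list, but it is not `self.y_true`, so it is not stored and is lost when the method returns. The error message confirms this: in `[0, 62]`, the 0 is the length of `self.y_true`, so it is empty.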
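The smallest change is to append to `self.y_true` instead of the local list, but labels collected as a side effect of `get_data` only line up with the predictions if every batch is fetched exactly once and in order. A minimal sketch of a sturdier alternative, assuming `generator[i]` returns `(X, y)` with one-hot `y` (as `CustomDataGenerator` does) and every batch is full (true with `batch_size=1`): take the true labels from the same batches you predict on, so the two arrays stay aligned regardless of how Keras itself iterates the Sequence. `collect_true_and_pred` is a hypothetical helper name, and `FakeModel` plus the hard-coded batches are stand-ins for illustration only.

```python
import numpy as np


def collect_true_and_pred(model, generator):
    """Hypothetical helper: return (y_true, y_pred) as integer class arrays.

    Assumes generator[i] returns (X, y) with one-hot y and that every batch
    is full; assumes model.predict_on_batch(X) returns per-class scores.
    """
    y_true, y_pred = [], []
    for i in range(len(generator)):
        X, y = generator[i]
        scores = np.asarray(model.predict_on_batch(X))
        # Labels and predictions come from the same batch, so they stay aligned
        y_true.extend(np.argmax(y, axis=1))
        y_pred.extend(np.argmax(scores, axis=1))
    return np.array(y_true), np.array(y_pred)


# Stand-ins for illustration only; they are not Keras objects.
class FakeModel:
    def predict_on_batch(self, X):
        # "Predicts" the class stored in the first feature, as one-hot scores
        return np.eye(3)[X[:, 0].astype(int)]


# A list supports len() and indexing, like a Keras Sequence
batches = [
    (np.array([[0.0], [2.0]]), np.eye(3)[[0, 1]]),  # true labels 0, 1
    (np.array([[1.0]]), np.eye(3)[[1]]),            # true label 1
]
y_true, y_pred = collect_true_and_pred(FakeModel(), batches)
print(y_true, y_pred)  # [0 1 1] [0 2 1]
```

With the real objects this would be `y_true, y_pred = collect_true_and_pred(model, val_datagen)`, and both arrays can then go straight into `confusion_matrix(y_true, y_pred)` from `sklearn.metrics`.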
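Here are some tips on code quality. `on_epoch_end(..)` does too much. You don't need to rebuild the image paths every epoch. Do the initialization in another method and only shuffle in `on_epoch_end()`.
You should also be careful with the `dir` parameter in `__init__()`. `dir` is a Python built-in function, and you shouldn't overwrite built-ins unless you know what you are doing. That is why it is highlighted in orange in the code here. In this particular code it does no harm, but be aware of it.
Instead of calling `y_true = CustomDataGenerator.get_y_true(val_datagen)`, do `y_true = val_datagen.get_y_true()`. It works the same way and is (in my opinion) clearer. Frankly, I have never seen your notation before.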
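The code-quality tips above can be combined into one structural sketch. Assumptions: `collect_paths` and `PairedBatches` are hypothetical names; the class returns path pairs instead of loading arrays, and in real code would subclass `tf.keras.utils.Sequence`; matching image and label files are assumed to sort into the same order. It also shuffles a single index array rather than calling `np.random.shuffle` on `image_paths` and `label_paths` separately, since two independent shuffles generally leave images paired with the wrong label files.

```python
import os

import numpy as np


def collect_paths(data_dir, folders):
    """Hypothetical helper, run once: every file under data_dir/<folder>.

    The parameter is named data_dir rather than dir so the built-in dir()
    is not shadowed. Files are sorted because os.listdir() returns them in
    arbitrary order.
    """
    paths = []
    for folder in folders:
        folder_path = os.path.join(data_dir, folder)
        for file_name in sorted(os.listdir(folder_path)):
            paths.append(os.path.join(folder_path, file_name))
    return paths


class PairedBatches:
    """Structure-only sketch; real code would subclass
    tf.keras.utils.Sequence and load the arrays in __getitem__."""

    def __init__(self, image_paths, label_paths, batch_size=1, shuffle=True):
        if len(image_paths) != len(label_paths):
            raise ValueError("need one label file per image")
        # Paths are stored once here, not rebuilt in every on_epoch_end()
        self.image_paths = list(image_paths)
        self.label_paths = list(label_paths)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.image_paths))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.indexes) / self.batch_size))

    def __getitem__(self, index):
        batch_idx = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # Image i and label i are always taken together
        return [(self.image_paths[i], self.label_paths[i]) for i in batch_idx]

    def on_epoch_end(self):
        # Only the shared index array is shuffled, so the pairs stay intact
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def pairs(self):
        # Every (image, label) pair in the current epoch order
        return [pair for i in range(len(self)) for pair in self[i]]


images = ["a.npy", "b.npy", "c.npy"]
labels = ["a.txt", "b.txt", "c.txt"]
seq = PairedBatches(images, labels, batch_size=2)
# Call the method on the instance rather than PairedBatches.pairs(seq)
all_pairs = seq.pairs()
```

Because `on_epoch_end()` only permutes `self.indexes`, it stays cheap to call every epoch, and the path lists are built exactly once.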
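One last point: your example is not reproducible. I tried to run your code, but you seem to have left out some parts; I ran into errors and had to guess to fix them. It really helps when you post the whole (relevant) code and comment it.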
Comments: