AxisError:轴 1 超出了维度 1 数组的边界 (OneHotEncoded python)

AxisError: axis 1 is out of bounds for array of dimension 1 (OneHotEncoded python)

提问人:Jeremy Jin 提问时间:11/15/2023 更新时间:11/15/2023 访问量:46

问:

我正在研究犬种的分类模型,我试图显示标签的示例及其各自的一个热编码标签,但我收到一个错误,说 AxisError:轴 1 超出了维度 1 数组的边界。

错误信息

这是我当前的代码:

    def decode_one_hot(one_hot_encoded, labels):
        # Use numpy's argmax to get the index of the '1' in each encoded list
        indices = np.argmax(one_hot_encoded, axis=1)
    
        # Convert indices back to original labels
        decoded_labels = [labels[index] for index in indices]
        return decoded_labels

    print('Example of label and one_hot_encoded label')
    train_labels = decode_one_hot(y_train, breeds)
    show_images(image_array= X_train, labels=train_labels, encoded_labels=y_train)

这是我期望实现的目标: 预期结果

python 图像处理 计算机视觉 一热编码

评论

0赞 mmonti 11/15/2023
是什么形状,是什么?y_trainbreeds
0赞 Jeremy Jin 11/15/2023
y_train的形状是 526,品种是我目录中的犬种类别 ''' path = “C:\\Users\\E\\Downloads\\Dog_Breed_Dataset\\data” breeds = os.listdir(path) os.listdir(path) breeds.sort()# 由于标签更正,需要排序 print() print('after sorted') print(breeds) '''
0赞 mmonti 11/15/2023
等等,如果是一维的,你为什么要调用它的第二维度?y_trainnp.argmax()
0赞 Jeremy Jin 11/15/2023
我已经将列表转换为 numpy 数组的方式从 y = np.array(data['label']) 更改为 y = np.array(y),现在y_train是 2D 的
0赞 mmonti 11/15/2023
不过,这仍然是 1D,如果是 1D,它将保持 1D,您可以检查data['label']print(np.shape(y))

答:

0赞 mmonti 11/15/2023 #1

我的猜测是你的数组是一维的,在这种情况下,语法是:

indices = np.argmax(one_hot_encoded, axis=0)