当我在 MNIST 数据库上运行我的神经网络时,它无法正确训练,只会猜测 2 个数字

When i run my neural network on the MNIST database, it won't train properly and guesses 2 numbers

提问人:BareVer 提问时间:11/17/2023 更新时间:11/17/2023 访问量:17



我的代码概要:首先,它读取训练集中的 60000 张图像,并把每张图像放入亮度值在 0 到 1 之间的数组中。然后,它使用我从迈克尔·尼尔森(Michael Nielsen)的神经网络一书第 1 章和第 2 章中学到的反向传播方法进行训练。最后,它用测试集中的 10000 张图像测试神经网络。

程序的主要输出是网络在这 10000 张测试图像中正确识别了多少张。我当然希望这个数字尽可能高。但在尝试了不同的学习率、训练轮数(epoch)和小批量大小之后,结果只在 500 到 1700 之间,无法再提高。

下面是我在 Visual Studio 中用 C# 编写的代码:

// 784 inputs (one per pixel of a 28x28 MNIST image), 30 hidden neurons, 10 outputs (digits 0-9).
var net = new Network(input: 784, hidden: 30, output: 10);

// One epoch, mini-batches of 10 images, learning rate 3.0.
net.train(epoch: 1, miniBatch: 10, learningRate: 3.0);


/// <summary>
/// A fully connected sigmoid network with one hidden layer, trained by mini-batch
/// stochastic gradient descent with backpropagation on a quadratic cost
/// (the approach of chapters 1-2 of Michael Nielsen's "Neural Networks and Deep Learning").
/// </summary>
/// <remarks>
/// Naming follows Nielsen: layer 1 is the input (a1), layer 2 the hidden layer
/// (z2, a2, w2, b2) and layer 3 the output layer (z3, a3, w3, b3).
/// Weight matrices are indexed [neuron in this layer, neuron in previous layer],
/// so w2 is [hidden, input] and w3 is [output, hidden].
/// </remarks>
class Network
{
    // A single shared generator. Creating "new Random()" on every call can give two
    // instances the same time-based seed on .NET Framework, so w2/w3/b2/b3 would all
    // be initialised from the same sequence.
    private static readonly Random s_rng = new Random();

    // Number of neurons in the input, hidden and output layers.
    public int inputLength;
    public int hiddenLength;
    public int outputLength;

    // Activations of the input (a1), hidden (a2) and output (a3) layers.
    public double[] a1;
    public double[] a2;
    public double[] a3;

    // Biases of the hidden and output layers.
    public double[] b2;
    public double[] b3;

    // Weights into the hidden ([hidden, input]) and output ([output, hidden]) layers.
    public double[,] w2;
    public double[,] w3;

    // Weighted inputs (before the sigmoid) of the hidden and output layers.
    public double[] z2;
    public double[] z3;

    // Gradient accumulators for the current mini-batch; applyChange() subtracts them.
    public double[] b2Change;
    public double[] b3Change;

    public double[,] w2Change;
    public double[,] w3Change;

    // Error terms dC/dz of the hidden and output layers for the current image.
    public double[] dcdz2;
    public double[] dcdz3;

    // One-hot target vector for the current training image.
    public int[] y;

    /// <summary>Allocates all layers and randomly initialises weights and biases.</summary>
    /// <param name="input">Number of input neurons (pixels per image).</param>
    /// <param name="hidden">Number of hidden neurons.</param>
    /// <param name="output">Number of output neurons (classes).</param>
    /// <exception cref="ArgumentOutOfRangeException">Any layer size is not positive.</exception>
    public Network(int input, int hidden, int output)
    {
        if (input <= 0) throw new ArgumentOutOfRangeException(nameof(input));
        if (hidden <= 0) throw new ArgumentOutOfRangeException(nameof(hidden));
        if (output <= 0) throw new ArgumentOutOfRangeException(nameof(output));

        inputLength = input;
        hiddenLength = hidden;
        outputLength = output;

        a1 = new double[input];
        a2 = new double[hidden];
        a3 = new double[output];

        b2 = new double[hidden];
        b3 = new double[output];

        w2 = new double[hidden, input];
        w3 = new double[output, hidden];

        z2 = new double[hidden];
        z3 = new double[output];

        b2Change = new double[hidden];
        b3Change = new double[output];

        w2Change = new double[hidden, input];
        w3Change = new double[output, hidden];

        dcdz2 = new double[hidden];
        dcdz3 = new double[output];

        y = new int[output];

        // Break symmetry. With all-zero weights, every hidden neuron computes the same
        // activation and receives the same gradient, so the hidden layer never
        // learns more than one feature.
        randomWeights(w2);
        randomWeights(w3);
        randomBias(b2);
        randomBias(b3);
    }

    /// <summary>
    /// Trains with mini-batch stochastic gradient descent on the MNIST training set and
    /// reports test-set accuracy (via <see cref="test"/>) after every epoch.
    /// </summary>
    /// <param name="epoch">Number of full passes over the training data.</param>
    /// <param name="miniBatch">Images whose gradients are accumulated before each update.</param>
    /// <param name="learningRate">Step size; each per-image gradient is scaled by learningRate / miniBatch.</param>
    /// <exception cref="ArgumentOutOfRangeException">epoch is negative or miniBatch is not positive.</exception>
    public void train(int epoch, int miniBatch, double learningRate)
    {
        if (epoch < 0) throw new ArgumentOutOfRangeException(nameof(epoch));
        if (miniBatch <= 0) throw new ArgumentOutOfRangeException(nameof(miniBatch));

        // Train on the *training* images. test() evaluates on the t10k images, so
        // those must not be seen during training. The set is materialised once
        // (about 60,000 small byte grids) so it can be reshuffled every epoch.
        List<Image> trainingData = new List<Image>(MnistReader.ReadTrainingData());

        // Averages the gradient over the mini-batch; backpropagation() applies it per image.
        double multiplier = learningRate / miniBatch;

        for (int repeat = 0; repeat < epoch; repeat++)
        {
            shuffle(trainingData);
            resetChanges();
            int count = 0;

            foreach (Image image in trainingData)
            {
                loadInput(image);

                // One-hot target: set for this image only, cleared again after backprop.
                y[image.Label] = 1;
                calcValuesOfNetwork();
                backpropagation(multiplier);
                y[image.Label] = 0;

                count++;
                if (count == miniBatch)
                {
                    applyChange();
                    resetChanges();
                    count = 0;
                }
            }

            // Apply what is left of a final, partial mini-batch.
            if (count > 0)
            {
                applyChange();
                resetChanges();
            }

            Console.Write($"Epoch {repeat + 1}: ");
            test();
        }
    }

    /// <summary>
    /// Runs the network on every MNIST test image and prints "correct / total".
    /// A prediction is the index of the largest output activation.
    /// </summary>
    public void test()
    {
        int count = 0;
        int total = 0;
        foreach (Image image in MnistReader.ReadTestData())
        {
            loadInput(image);
            calcValuesOfNetwork();
            if (maximumOfArray(a3) == image.Label)
            {
                count++;
            }
            total++;
        }
        Console.WriteLine($"{count} / {total}");
    }

    /// <summary>
    /// Computes the error terms for the current image (inputs in a1, forward values from
    /// <see cref="calcValuesOfNetwork"/>, target in y) and adds the resulting weight and
    /// bias gradients, scaled by <paramref name="multiplier"/>, to the *Change accumulators.
    /// </summary>
    public void backpropagation(double multiplier)
    {
        // Output error: dC/da3 (= a3 - y for the quadratic cost) times sigmoid'(z3).
        dcdz3 = multiplyArrays(dcda3(a3, y), dadz(z3));
        // Hidden error: w3^T * dcdz3 times sigmoid'(z2).
        dcdz2 = multiplyArrays(dcda2(dcdz3, w3), dadz(z2));

        for (int i = 0; i < w3Change.GetLength(0); i++)
        {
            for (int j = 0; j < w3Change.GetLength(1); j++)
            {
                w3Change[i, j] += a2[j] * dcdz3[i] * multiplier;
            }
        }

        for (int i = 0; i < w2Change.GetLength(0); i++)
        {
            for (int j = 0; j < w2Change.GetLength(1); j++)
            {
                w2Change[i, j] += a1[j] * dcdz2[i] * multiplier;
            }
        }

        for (int i = 0; i < b3Change.Length; i++)
        {
            b3Change[i] += dcdz3[i] * multiplier;
        }
        for (int i = 0; i < b2Change.Length; i++)
        {
            b2Change[i] += dcdz2[i] * multiplier;
        }
    }

    /// <summary>
    /// Gradient-descent step: subtracts the accumulated changes from the weights and biases.
    /// Does not clear the accumulators.
    /// </summary>
    public void applyChange()
    {
        w2 = sub2dArrays(w2, w2Change);
        w3 = sub2dArrays(w3, w3Change);
        b2 = subArrays(b2, b2Change);
        b3 = subArrays(b3, b3Change);
    }

    /// <summary>Forward pass: fills z2, a2, z3 and a3 from the current input a1.</summary>
    public void calcValuesOfNetwork()
    {
        z2 = z(a1, w2, b2);
        a2 = sigmoid(z2);
        z3 = z(a2, w3, b3);
        a3 = sigmoid(z3);
    }

    /// <summary>Element-wise logistic function 1 / (1 + e^-z); returns a new array.</summary>
    public double[] sigmoid(double[] z)
    {
        double[] output = new double[z.Length];
        for (int i = 0; i < z.Length; i++)
        {
            output[i] = 1 / (1 + Math.Exp(-z[i]));
        }
        return output;
    }

    /// <summary>Weighted input of a layer: w * a + b, with w indexed [this layer, previous layer].</summary>
    public double[] z(double[] a, double[,] w, double[] b)
    {
        double[] output = new double[b.Length];
        for (int i = 0; i < w.GetLength(0); i++)
        {
            for (int j = 0; j < w.GetLength(1); j++)
            {
                output[i] += a[j] * w[i, j];
            }
        }
        for (int i = 0; i < b.Length; i++)
        {
            output[i] += b[i];
        }
        return output;
    }

    /// <summary>Derivative of the quadratic cost 1/2 * |a - y|^2 with respect to a: a - y.</summary>
    public double[] dcda3(double[] a, int[] y)
    {
        double[] output = new double[a.Length];
        for (int i = 0; i < a.Length; i++)
        {
            output[i] = a[i] - y[i];
        }
        return output;
    }

    /// <summary>Element-wise sigmoid derivative: sigmoid(z) * (1 - sigmoid(z)).</summary>
    public double[] dadz(double[] z)
    {
        double[] output = new double[z.Length];
        double[] sig = sigmoid(z);
        for (int i = 0; i < z.Length; i++)
        {
            output[i] = sig[i] * (1 - sig[i]);
        }
        return output;
    }

    /// <summary>
    /// Propagates errors one layer back: returns w^T * dcdz, i.e. dC/da for the
    /// previous layer, where w is indexed [this layer, previous layer].
    /// </summary>
    public double[] dcda2(double[] dcdz, double[,] w)
    {
        double[] output = new double[w.GetLength(1)];
        for (int i = 0; i < w.GetLength(1); i++)
        {
            for (int j = 0; j < w.GetLength(0); j++)
            {
                output[i] += dcdz[j] * w[j, i];
            }
        }
        return output;
    }

    /// <summary>Element-wise (Hadamard) product; the result has ar1's length.</summary>
    public double[] multiplyArrays(double[] ar1, double[] ar2)
    {
        double[] output = new double[ar1.Length];
        for (int i = 0; i < ar1.Length; i++)
        {
            output[i] = ar1[i] * ar2[i];
        }
        return output;
    }

    /// <summary>Element-wise difference ar1 - ar2 as a new array.</summary>
    public double[] subArrays(double[] ar1, double[] ar2)
    {
        double[] output = new double[ar1.Length];
        for (int i = 0; i < ar1.Length; i++)
        {
            output[i] = ar1[i] - ar2[i];
        }
        return output;
    }

    /// <summary>Element-wise difference ar1 - ar2 of two matrices as a new matrix.</summary>
    public double[,] sub2dArrays(double[,] ar1, double[,] ar2)
    {
        double[,] output = new double[ar1.GetLength(0), ar1.GetLength(1)];
        for (int i = 0; i < ar1.GetLength(0); i++)
        {
            for (int j = 0; j < ar1.GetLength(1); j++)
            {
                output[i, j] = ar1[i, j] - ar2[i, j];
            }
        }
        return output;
    }

    /// <summary>Returns a new all-zero array with the same length as <paramref name="ar"/>.</summary>
    public double[] resetArray(double[] ar)
    {
        return new double[ar.Length];
    }

    /// <summary>Returns a new all-zero matrix with the same dimensions as <paramref name="ar"/>.</summary>
    public double[,] reset2dArray(double[,] ar)
    {
        return new double[ar.GetLength(0), ar.GetLength(1)];
    }

    /// <summary>
    /// Index of the largest element (the first one on ties); 0 for an empty array.
    /// </summary>
    public int maximumOfArray(double[] a3)
    {
        // Compare against the first element rather than a fixed 0, so the result is the
        // true argmax even when every value is <= 0.
        int output = 0;
        for (int i = 1; i < a3.Length; i++)
        {
            if (a3[i] > a3[output])
            {
                output = i;
            }
        }
        return output;
    }

    /// <summary>Fills every weight with a uniform random value in [-1, 1).</summary>
    public void randomWeights(double[,] w)
    {
        for (int i = 0; i < w.GetLength(0); i++)
        {
            for (int j = 0; j < w.GetLength(1); j++)
            {
                w[i, j] = (s_rng.NextDouble() - 0.5) * 2;
            }
        }
    }

    /// <summary>Fills every bias with a uniform random value in [-1, 1).</summary>
    public void randomBias(double[] b)
    {
        for (int i = 0; i < b.Length; i++)
        {
            b[i] = (s_rng.NextDouble() - 0.5) * 2;
        }
    }

    // Clears the gradient accumulators so the next mini-batch starts from zero.
    private void resetChanges()
    {
        b2Change = resetArray(b2Change);
        b3Change = resetArray(b3Change);
        w2Change = reset2dArray(w2Change);
        w3Change = reset2dArray(w3Change);
    }

    // Copies an image into the input layer, scaling each pixel byte (0-255) to 0-1.
    // Training and testing both go through here, so the flattening order always matches.
    private void loadInput(Image image)
    {
        byte[,] pixels = image.Data;
        if (pixels.Length != a1.Length)
        {
            throw new ArgumentException(
                $"Image has {pixels.Length} pixels but the network expects {a1.Length} inputs.",
                nameof(image));
        }

        int counter = 0;
        for (int row = 0; row < pixels.GetLength(0); row++)
        {
            for (int col = 0; col < pixels.GetLength(1); col++)
            {
                a1[counter++] = pixels[row, col] / 255.0;
            }
        }
    }

    // Fisher-Yates shuffle, so the mini-batches differ from epoch to epoch.
    private static void shuffle(List<Image> items)
    {
        for (int i = items.Count - 1; i > 0; i--)
        {
            int j = s_rng.Next(i + 1);
            (items[i], items[j]) = (items[j], items[i]);
        }
    }
}

/// <summary>
/// Streams MNIST samples from the IDX files under the relative "mnist/" directory.
/// Each IDX file starts with big-endian 32-bit header fields followed by raw bytes.
/// </summary>
public static class MnistReader
{
    private const string TrainImages = "mnist/train-images.idx3-ubyte";
    private const string TrainLabels = "mnist/train-labels.idx1-ubyte";
    private const string TestImages = "mnist/t10k-images.idx3-ubyte";
    private const string TestLabels = "mnist/t10k-labels.idx1-ubyte";

    // IDX magic numbers: 0x00000803 = unsigned bytes, 3 dimensions (images);
    // 0x00000801 = unsigned bytes, 1 dimension (labels).
    private const int ImagesMagic = 2051;
    private const int LabelsMagic = 2049;

    /// <summary>Lazily reads the training images with their labels.</summary>
    public static IEnumerable<Image> ReadTrainingData() => Read(TrainImages, TrainLabels);

    /// <summary>Lazily reads the test images with their labels.</summary>
    public static IEnumerable<Image> ReadTestData() => Read(TestImages, TestLabels);

    /// <summary>
    /// Yields one <see cref="Image"/> per entry, pairing the i-th image with the i-th label.
    /// The files are opened when enumeration starts and closed when it ends or the
    /// enumerator is disposed (for example by breaking out of a foreach).
    /// </summary>
    /// <exception cref="InvalidDataException">A header has the wrong magic number or the counts differ.</exception>
    /// <exception cref="EndOfStreamException">A file ends before all declared entries were read.</exception>
    private static IEnumerable<Image> Read(string imagesPath, string labelsPath)
    {
        using var labels = new BinaryReader(new FileStream(labelsPath, FileMode.Open, FileAccess.Read, FileShare.Read));
        using var images = new BinaryReader(new FileStream(imagesPath, FileMode.Open, FileAccess.Read, FileShare.Read));

        int magicNumber = images.ReadBigInt32();
        if (magicNumber != ImagesMagic)
        {
            throw new InvalidDataException($"{imagesPath}: magic number {magicNumber}, expected {ImagesMagic}.");
        }
        int numberOfImages = images.ReadBigInt32();
        // The IDX header stores the number of rows first, then the number of columns.
        int rows = images.ReadBigInt32();
        int columns = images.ReadBigInt32();

        int magicLabel = labels.ReadBigInt32();
        if (magicLabel != LabelsMagic)
        {
            throw new InvalidDataException($"{labelsPath}: magic number {magicLabel}, expected {LabelsMagic}.");
        }
        int numberOfLabels = labels.ReadBigInt32();
        if (numberOfLabels != numberOfImages)
        {
            throw new InvalidDataException(
                $"{imagesPath} has {numberOfImages} images but {labelsPath} has {numberOfLabels} labels.");
        }

        int pixelsPerImage = rows * columns;
        for (int i = 0; i < numberOfImages; i++)
        {
            byte[] bytes = images.ReadBytes(pixelsPerImage);
            if (bytes.Length != pixelsPerImage)
            {
                throw new EndOfStreamException($"{imagesPath}: file ends inside image {i}.");
            }

            // Pixels are stored row by row, so pixel (r, c) sits at offset r * columns + c.
            var arr = new byte[rows, columns];
            for (int r = 0; r < rows; r++)
            {
                for (int c = 0; c < columns; c++)
                {
                    arr[r, c] = bytes[r * columns + c];
                }
            }

            yield return new Image
            {
                Data = arr,
                Label = labels.ReadByte(),
            };
        }
    }
}

/// <summary>A single MNIST sample: its pixel grid and its label.</summary>
public class Image
{
    /// <summary>Pixel intensities (one byte each, 0-255) as read from the images file.</summary>
    public byte[,] Data { get; set; }

    /// <summary>The label byte read from the labels file (the digit shown).</summary>
    public byte Label { get; set; }
}

/// <summary>Small helpers used when reading MNIST data.</summary>
public static class Extensions
{
    /// <summary>
    /// Reads a 4-byte big-endian signed integer, the byte order used by IDX file headers.
    /// </summary>
    /// <exception cref="EndOfStreamException">Fewer than 4 bytes remain in the stream.</exception>
    public static int ReadBigInt32(this BinaryReader br)
    {
        if (br is null) throw new ArgumentNullException(nameof(br));

        byte[] bytes = br.ReadBytes(sizeof(int));
        // ReadBytes returns a short array at end of stream instead of throwing. Report
        // that the way BinaryReader.ReadInt32 does, not as a BitConverter ArgumentException.
        if (bytes.Length < sizeof(int))
        {
            throw new EndOfStreamException("Unable to read beyond the end of the stream.");
        }

        // Compose most-significant byte first; independent of the host's endianness.
        return (bytes[0] << 24) | (bytes[1] << 16) | (bytes[2] << 8) | bytes[3];
    }

    /// <summary>
    /// Invokes <paramref name="action"/> with every (first index, second index) pair of a
    /// 2-D array. The first dimension is the outer loop.
    /// </summary>
    public static void ForEach<T>(this T[,] source, Action<int, int> action)
    {
        if (source is null) throw new ArgumentNullException(nameof(source));
        if (action is null) throw new ArgumentNullException(nameof(action));

        for (int i = 0; i < source.GetLength(0); i++)
        {
            for (int j = 0; j < source.GetLength(1); j++)
            {
                action(i, j);
            }
        }
    }
}



我尝试打印输出神经元的值,并与期望值进行比较。看起来每次启动程序时,网络都会“选定” 2 个特定数字,然后只在这两个数字之间猜测。


输出神经元值 0: 0,8439129508768974 1: 0,9122277563390606 2: 0,017403770398116702 3: 0,011134118269524533 4: 0,13462595373016392 5: 0,036897782500381525 6: 0,04750856319601651 7: 0,15035649081365893 8: 0,01481989084000439 9: 0,15321481202658915 猜:1 答案:6

输出神经元值 0: 0,838089716270575 1: 0,6086701853058746 2: 0,02905089940189534 3: 0,16903075237731124 4: 0,1404228501842471 5: 0,01814462956762896 6: 0,13817735383695717 7: 0,043509857883374116 8: 0,03455730693234839 9: 0,02082949989901883 猜:0 答案:5

输出神经元值 0: 0,9366729671922033 1: 0,7781000823699303 2: 0,022226345496076768 3: 0,007671436030838936 4: 0,0632049260642895 5: 0,004989765778696671 6: 0,03230139825555404 7: 0,15464187502773366 8: 0,018234953784533374 9: 0,007956539392870072 猜:0 答案:4

下面是另一次运行中的一些猜测,这次它只会猜 3: 猜:3 答案:7 猜:3 答案:8 猜:3 答案:9 猜:3 答案:0 猜:3 答案:1 猜:3 答案:2 猜:3 答案:3 猜:3 答案:4 猜:3 答案:5 猜:3 答案:6

C# 神经网络 MNIST


1赞 AKX 11/17/2023
抱歉,我认为不会有人去翻阅 400 行左右的代码来找出哪里可能有错误。
0赞 BareVer 11/17/2023
我已经修复了用测试数据进行训练的问题。其中约 150 行只是用来声明变量,以及用图像数据创建数组,我已确认这部分工作正常。因此,MnistReader 类及其后面的所有内容都不需要阅读。
0赞 Community 11/17/2023
0赞 AKX 11/17/2023
@BareVer 好吧,换个说法:我认为不会有人去翻阅 320 行左右的代码来找出哪里可能有错误。
0赞 BareVer 11/20/2023

答: 暂无答案