当我在 MNIST 数据库上运行我的神经网络时,它无法正确训练并猜测 2 个数字

When i run my neural network on the MNIST database, it won't train properly and guesses 2 numbers

提问人:BareVer 提问时间:11/17/2023 更新时间:11/17/2023 访问量:17

问:

我正在制作一个神经网络来识别学校项目的MNIST数据库中的数字。它尽可能直观,所以我不使用任何库,而且我的代码很长。

我的代码摘要首先,它使用来自训练集的 60000 张图像,并将它们放置在亮度值为 0 到 1 的数组中。然后,它使用了我从迈克尔·尼尔森(Michael Nielsen)的神经网络一书的第1章og 2中得到的反向传播方法。然后,它使用测试集中剩余的 10000 张图像测试神经网络。

然后,主要输出是网络从 10000 张图像中获得了多少张图像。问题是我当然希望它尽可能高。我使用不同的学习率、纪元值和小批量大小获得了 500 到 1700 之间的值,但不能再高了。

下面是在 Visual Studio 中以 c# 编写的代码

Network net = new Network(784, 30, 10);

net.train(1, 10, 3.00);

net.test();

class Network
{
    public int inputLength;
    public int hiddenLength;
    public int outputLength;

    public double[] a1;
    public double[] a2;
    public double[] a3;

    public double[] b2;
    public double[] b3;

    public double[,] w2;
    public double[,] w3;

    public double[] z2;
    public double[] z3;

    public double[] b2Change;
    public double[] b3Change;

    public double[,] w2Change;
    public double[,] w3Change;

    public double[] dcdz2;
    public double[] dcdz3;

    public int[] y;

    public Network(int input, int hidden, int output)
    {
        inputLength = input;
        hiddenLength = hidden;
        outputLength = output;

        a1 = new double[input];
        a2 = new double[hidden];
        a3 = new double[output];

        b2 = new double[hidden];
        b3 = new double[output];

        w2 = new double[hidden,input];
        w3 = new double[output,hidden];

        z2 = new double[hidden];
        z3 = new double[output];

        b2Change = new double[hidden];
        b3Change = new double[output];

        w2Change = new double[hidden, input];
        w3Change = new double[output, hidden];

        dcdz2 = new double[hidden];
        dcdz3 = new double[output];

        y = new int[output];
    }

    public void train(int epoch, int miniBatch, double learningRate)
    {
        randomWeights(w2);
        randomWeights(w3);
        randomBias(b2);
        randomBias(b3);
        IEnumerable<Image> TrainingData = MnistReader.ReadTestData();
        int count = 0;
        double multiplier = learningRate / miniBatch;
        for (int repeat = 0; repeat < epoch; repeat++)
        {
            foreach (Image image in TrainingData)
            {
                int counter = 0;
                for (int i = 0; i < image.Data.GetLength(1); i++)
                {
                    for (int j = 0; j < image.Data.GetLength(0); j++)
                    {
                        a1[counter] = image.Data[j, i] / 255.00;
                        counter++;
                    }
                }

                y[image.Label] = 1;
                backpropagation(multiplier);
                y[image.Label] = 0;
                count++;
                if (count == 10)
                {
                    applyChange();
                }
            }
            Console.WriteLine(repeat);
        }
    }

    public void test()
    {
        IEnumerable<Image> TestData = MnistReader.ReadTestData();
        int count = 0;
        foreach (Image image in TestData)
        {
            int counter = 0;
            for (int i = 0; i < image.Data.GetLength(1); i++)
            {
                for (int j = 0; j < image.Data.GetLength(0); j++)
                {
                    a1[counter] = image.Data[j, i] / 255.00;
                    counter++;
                }
            }

            calcValuesOfNetwork();

            if (maximumOfArray(a3) == image.Label)
            {
                count++;
            }
        }
        Console.WriteLine($"{count} / 10000");
    }

    public void backpropagation(double multiplier)
    {
        calcValuesOfNetwork();
        
        dcdz3 = multiplyArrays(dcda3(a3, y), dadz(z3));
        dcdz2 = multiplyArrays(dcda2(dcdz3, w3), dadz(z2));

        for (int i = 0; i < w3Change.GetLength(0); i++)
        {
            for (int j = 0; j < w3Change.GetLength(1); j++)
            {
                w3Change[i, j] += a2[j] * dcdz3[i] * multiplier;
            }
        }

        for (int i = 0; i < w2Change.GetLength(0); i++)
        {
            for (int j = 0; j < w2Change.GetLength(1); j++)
            {
                w2Change[i, j] += a1[j] * dcdz2[i] * multiplier;
            }
        }

        for (int i = 0; i < b3Change.GetLength(0); i++)
        {
            b3Change[i] += dcdz3[i] * multiplier;
        }
        
        for (int i = 0; i < b2Change.GetLength(0); i++)
        {
            b2Change[i] += dcdz2[i] * multiplier;
        }
    }

    public void applyChange()
    {
        w2 = sub2dArrays(w2, w2Change);
        w3 = sub2dArrays(w3, w3Change);
        b2 = subArrays(b2, b2Change);
        b3 = subArrays(b3, b3Change);

        reset2dArray(w2Change);
        reset2dArray(w3Change);
        resetArray(b2Change);
        resetArray(b3Change);
    }

    public void calcValuesOfNetwork()
    {
        z2 = z(a1, w2, b2);
        a2 = sigmoid(z2);
        z3 = z(a2, w3, b3);
        a3 = sigmoid(z3);
    }

    public double[] sigmoid(double[] z)
    {
        double[] output = new double[z.GetLength(0)];
        for (int i = 0; i < z.GetLength(0); i++)
        {
            output[i] = 1 / (1 + Math.Exp(-z[i]));
        }
        return output;
    }

    public double[] z(double[] a, double[,] w, double[] b)
    {
        double[] output = new double[b.GetLength(0)];
        for (int i = 0; i < w.GetLength(0); i++)
        {
            for (int j = 0; j < w.GetLength(1); j++)
            {
                output[i] += a[j] * w[i,j];
            }
        }
        for (int i = 0; i < b.GetLength(0); i++)
        {
            output[i] += b[i];
        }
        return output;

    }

    public double[] dcda3(double[] a, int[] y)
    {
        double[] output = new double[a.GetLength(0)];
        for (int i = 0; i < a.GetLength(0); i++)
        {
            output[i] = a[i] - y[i];
        }
        return output;
    }

    public double[] dadz(double[] z)
    {
        double[] output = new double[z.GetLength(0)];
        double[] sig = sigmoid(z);
        for (int i = 0; i < z.GetLength(0); i++)
        {
            output[i] = sig[i] * (1 - sig[i]);
        }
        return output;
    }

    public double[] dcda2(double[] dcdz, double[,] w)
    {
        double[] output = new double[w.GetLength(1)];
        for (int i = 0; i < w.GetLength(1); i++)
        {
            for (int j = 0; j < w.GetLength(0); j++)
            {
                output[i] += dcdz[j] * w[j, i];
            }
        }
        return output;
    }

    public double[] multiplyArrays(double[] ar1, double[] ar2)
    {
        double[] output = new double[ar1.GetLength(0)];
        for (int i = 0; i < ar1.GetLength(0); i++)
        {
            output[i] = ar1[i] * ar2[i];
        }
        return output;
    }

    public double[] subArrays(double[] ar1, double[] ar2)
    {
        double[] output = new double[ar1.GetLength(0)];
        for (int i = 0; i < ar1.GetLength(0); i++)
        {
            output[i] = ar1[i] - ar2[i];
        }
        return output;
    }

    public double[,] sub2dArrays(double[,] ar1, double[,] ar2)
    {
        double[,] output = new double[ar1.GetLength(0), ar1.GetLength(1)];
        for (int i = 0; i < ar1.GetLength(0); i++)
        {
            for (int j = 0; j < ar1.GetLength(1); j++)
            {
                output[i,j] = ar1[i,j] - ar2[i,j];
            }
        }
        return output;
    }

    public double[] resetArray(double[] ar)
    {
        double[] output = new double[ar.GetLength(0)];
        return output;
    }

    public double[,] reset2dArray(double[,] ar)
    {
        double[,] output = new double[ar.GetLength(0), ar.GetLength(1)];
        return output;
    }

    public int maximumOfArray(double[] a3)
    {
        int output = 0;
        double max = 0;
        for (int i = 0; i < a3.GetLength(0); i++)
        {
            if (a3[i] > max)
            {
                max = a3[i];
                output = i;
            }
        }
        return output;
    }

    public void randomWeights(double[,] w)
    {
        Random rnd = new Random();
        for (int i = 0; i < w.GetLength(0); i++)
        {
            for (int j = 0; j < w.GetLength(1); j++)
            {
                w[i, j] = (rnd.NextDouble() - 0.5) * 2;
            }
        }
    }

    public void randomBias(double[] b)
    {
        Random rnd = new Random();
        for (int i = 0; i < b.GetLength(0); i++)
        {
            b[i] = (rnd.NextDouble() - 0.5) * 2;
        }
    }
}

public static class MnistReader
{
    private const string TrainImages = "mnist/train-images.idx3-ubyte";
    private const string TrainLabels = "mnist/train-labels.idx1-ubyte";
    private const string TestImages = "mnist/t10k-images.idx3-ubyte";
    private const string TestLabels = "mnist/t10k-labels.idx1-ubyte";

    public static IEnumerable<Image> ReadTrainingData()
    {
        foreach (var item in Read(TrainImages, TrainLabels))
        {
            yield return item;
        }
    }

    public static IEnumerable<Image> ReadTestData()
    {
        foreach (var item in Read(TestImages, TestLabels))
        {
            yield return item;
        }
    }

    private static IEnumerable<Image> Read(string imagesPath, string labelsPath)
    {
        BinaryReader labels = new BinaryReader(new FileStream(labelsPath, FileMode.Open));
        BinaryReader images = new BinaryReader(new FileStream(imagesPath, FileMode.Open));

        int magicNumber = images.ReadBigInt32();
        int numberOfImages = images.ReadBigInt32();
        int width = images.ReadBigInt32();
        int height = images.ReadBigInt32();

        int magicLabel = labels.ReadBigInt32();
        int numberOfLabels = labels.ReadBigInt32();

        for (int i = 0; i < numberOfImages; i++)
        {
            var bytes = images.ReadBytes(width * height);
            var arr = new byte[height, width];

            arr.ForEach((j, k) => arr[j, k] = bytes[j * height + k]);

            yield return new Image()
            {
                Data = arr,
                Label = labels.ReadByte()
            };
        }
        labels.Close();
        images.Close();
    }
}

public class Image
{
    public byte Label { get; set; }
    public byte[,] Data { get; set; }
}

public static class Extensions
{
    public static int ReadBigInt32(this BinaryReader br)
    {
        var bytes = br.ReadBytes(sizeof(Int32));
        if (BitConverter.IsLittleEndian) Array.Reverse(bytes);
        return BitConverter.ToInt32(bytes, 0);
    }

    public static void ForEach<T>(this T[,] source, Action<int, int> action)
    {
        for (int w = 0; w < source.GetLength(0); w++)
        {
            for (int h = 0; h < source.GetLength(1); h++)
            {
                action(w, h);
            }
        }
    }
}

我添加了四个文件,即目录“程序名称”“程序名称”\bin\Debug\net6.0\mnist中的MNIST数据库文件

我尝试过的主要事情是打印不同的值并寻找这些模式不应该存在的模式。

我尝试打印输出神经元值并与它应该是什么进行比较,看起来它就像在猜测每次启动程序时“选择”的 2 个特定数字。

为了说明这一点,我提供了一些通过网络运行的图像示例,其中输出神经元以及它的猜测和实际数量

输出神经元值 0: 0,8439129508768974 1: 0,9122277563390606 2: 0,017403770398116702 3: 0,011134118269524533 4: 0,13462595373016392 5: 0,036897782500381525 6: 0,04750856319601651 7: 0,15035649081365893 8: 0,01481989084000439 9: 0,15321481202658915 猜:1 答案:6

输出神经元值 0: 0,838089716270575 1: 0,6086701853058746 2: 0,02905089940189534 3: 0,16903075237731124 4: 0,1404228501842471 5: 0,01814462956762896 6: 0,13817735383695717 7: 0,043509857883374116 8: 0,03455730693234839 9: 0,02082949989901883 猜:0 答案:5

输出神经元值 0: 0,9366729671922033 1: 0,7781000823699303 2: 0,022226345496076768 3: 0,007671436030838936 4: 0,0632049260642895 5: 0,004989765778696671 6: 0,03230139825555404 7: 0,15464187502773366 8: 0,018234953784533374 9: 0,007956539392870072 猜:0 答案:4

这是来自不同尝试的一些猜测,它只猜到了 3 个 猜:3 答案:7 猜:3 答案:8 猜:3 答案:9 猜:3 答案:0 猜:3 答案:1 猜:3 答案:2 猜:3 答案:3 猜:3 答案:4 猜:3 答案:5 猜:3 答案:6

C# 神经网络 MNIST

评论

1赞 AKX 11/17/2023
对不起,我不认为人们会挖掘 400 行左右的代码来找出可能存在错误的地方。
0赞 BareVer 11/17/2023
我已经使用测试数据解决了训练问题。150 行用于创建可变变量并使用我已确认它正在工作的图像创建数组。因此,类 MnistReader 下的所有内容都不需要通读
0赞 Community 11/17/2023
请修剪您的代码,以便更轻松地找到您的问题。请遵循这些准则,以创建最小的可重现示例
0赞 AKX 11/17/2023
@BareVer 好吧,让我换个说法:我不认为人们会挖掘 320 行左右的代码来找出可能存在错误的地方。
0赞 BareVer 11/20/2023
好吧,我明白了。对不起,我似乎很不顾一切地送了这么多东西。我的朋友在看到我为我的问题发疯后说服我寄出它。我将尝试表示同情并遵循指南,因为我刚刚匆忙完成了程序。谢谢建议。

答: 暂无答案