提问人:BareVer 提问时间:11/17/2023 更新时间:11/17/2023 访问量:17
当我在 MNIST 数据库上运行我的神经网络时,它无法正确训练并猜测 2 个数字
When i run my neural network on the MNIST database, it won't train properly and guesses 2 numbers
问:
我正在制作一个神经网络来识别学校项目的MNIST数据库中的数字。它尽可能直观,所以我不使用任何库,而且我的代码很长。
我的代码摘要首先,它使用来自训练集的 60000 张图像,并将它们放置在亮度值为 0 到 1 的数组中。然后,它使用了我从迈克尔·尼尔森(Michael Nielsen)的神经网络一书的第1章og 2中得到的反向传播方法。然后,它使用测试集中剩余的 10000 张图像测试神经网络。
然后,主要输出是网络从 10000 张图像中获得了多少张图像。问题是我当然希望它尽可能高。我使用不同的学习率、纪元值和小批量大小获得了 500 到 1700 之间的值,但不能再高了。
下面是在 Visual Studio 中以 c# 编写的代码
Network net = new Network(784, 30, 10);
net.train(1, 10, 3.00);
net.test();
class Network
{
public int inputLength;
public int hiddenLength;
public int outputLength;
public double[] a1;
public double[] a2;
public double[] a3;
public double[] b2;
public double[] b3;
public double[,] w2;
public double[,] w3;
public double[] z2;
public double[] z3;
public double[] b2Change;
public double[] b3Change;
public double[,] w2Change;
public double[,] w3Change;
public double[] dcdz2;
public double[] dcdz3;
public int[] y;
public Network(int input, int hidden, int output)
{
inputLength = input;
hiddenLength = hidden;
outputLength = output;
a1 = new double[input];
a2 = new double[hidden];
a3 = new double[output];
b2 = new double[hidden];
b3 = new double[output];
w2 = new double[hidden,input];
w3 = new double[output,hidden];
z2 = new double[hidden];
z3 = new double[output];
b2Change = new double[hidden];
b3Change = new double[output];
w2Change = new double[hidden, input];
w3Change = new double[output, hidden];
dcdz2 = new double[hidden];
dcdz3 = new double[output];
y = new int[output];
}
public void train(int epoch, int miniBatch, double learningRate)
{
randomWeights(w2);
randomWeights(w3);
randomBias(b2);
randomBias(b3);
IEnumerable<Image> TrainingData = MnistReader.ReadTestData();
int count = 0;
double multiplier = learningRate / miniBatch;
for (int repeat = 0; repeat < epoch; repeat++)
{
foreach (Image image in TrainingData)
{
int counter = 0;
for (int i = 0; i < image.Data.GetLength(1); i++)
{
for (int j = 0; j < image.Data.GetLength(0); j++)
{
a1[counter] = image.Data[j, i] / 255.00;
counter++;
}
}
y[image.Label] = 1;
backpropagation(multiplier);
y[image.Label] = 0;
count++;
if (count == 10)
{
applyChange();
}
}
Console.WriteLine(repeat);
}
}
public void test()
{
IEnumerable<Image> TestData = MnistReader.ReadTestData();
int count = 0;
foreach (Image image in TestData)
{
int counter = 0;
for (int i = 0; i < image.Data.GetLength(1); i++)
{
for (int j = 0; j < image.Data.GetLength(0); j++)
{
a1[counter] = image.Data[j, i] / 255.00;
counter++;
}
}
calcValuesOfNetwork();
if (maximumOfArray(a3) == image.Label)
{
count++;
}
}
Console.WriteLine($"{count} / 10000");
}
public void backpropagation(double multiplier)
{
calcValuesOfNetwork();
dcdz3 = multiplyArrays(dcda3(a3, y), dadz(z3));
dcdz2 = multiplyArrays(dcda2(dcdz3, w3), dadz(z2));
for (int i = 0; i < w3Change.GetLength(0); i++)
{
for (int j = 0; j < w3Change.GetLength(1); j++)
{
w3Change[i, j] += a2[j] * dcdz3[i] * multiplier;
}
}
for (int i = 0; i < w2Change.GetLength(0); i++)
{
for (int j = 0; j < w2Change.GetLength(1); j++)
{
w2Change[i, j] += a1[j] * dcdz2[i] * multiplier;
}
}
for (int i = 0; i < b3Change.GetLength(0); i++)
{
b3Change[i] += dcdz3[i] * multiplier;
}
for (int i = 0; i < b2Change.GetLength(0); i++)
{
b2Change[i] += dcdz2[i] * multiplier;
}
}
public void applyChange()
{
w2 = sub2dArrays(w2, w2Change);
w3 = sub2dArrays(w3, w3Change);
b2 = subArrays(b2, b2Change);
b3 = subArrays(b3, b3Change);
reset2dArray(w2Change);
reset2dArray(w3Change);
resetArray(b2Change);
resetArray(b3Change);
}
public void calcValuesOfNetwork()
{
z2 = z(a1, w2, b2);
a2 = sigmoid(z2);
z3 = z(a2, w3, b3);
a3 = sigmoid(z3);
}
public double[] sigmoid(double[] z)
{
double[] output = new double[z.GetLength(0)];
for (int i = 0; i < z.GetLength(0); i++)
{
output[i] = 1 / (1 + Math.Exp(-z[i]));
}
return output;
}
public double[] z(double[] a, double[,] w, double[] b)
{
double[] output = new double[b.GetLength(0)];
for (int i = 0; i < w.GetLength(0); i++)
{
for (int j = 0; j < w.GetLength(1); j++)
{
output[i] += a[j] * w[i,j];
}
}
for (int i = 0; i < b.GetLength(0); i++)
{
output[i] += b[i];
}
return output;
}
public double[] dcda3(double[] a, int[] y)
{
double[] output = new double[a.GetLength(0)];
for (int i = 0; i < a.GetLength(0); i++)
{
output[i] = a[i] - y[i];
}
return output;
}
public double[] dadz(double[] z)
{
double[] output = new double[z.GetLength(0)];
double[] sig = sigmoid(z);
for (int i = 0; i < z.GetLength(0); i++)
{
output[i] = sig[i] * (1 - sig[i]);
}
return output;
}
public double[] dcda2(double[] dcdz, double[,] w)
{
double[] output = new double[w.GetLength(1)];
for (int i = 0; i < w.GetLength(1); i++)
{
for (int j = 0; j < w.GetLength(0); j++)
{
output[i] += dcdz[j] * w[j, i];
}
}
return output;
}
public double[] multiplyArrays(double[] ar1, double[] ar2)
{
double[] output = new double[ar1.GetLength(0)];
for (int i = 0; i < ar1.GetLength(0); i++)
{
output[i] = ar1[i] * ar2[i];
}
return output;
}
public double[] subArrays(double[] ar1, double[] ar2)
{
double[] output = new double[ar1.GetLength(0)];
for (int i = 0; i < ar1.GetLength(0); i++)
{
output[i] = ar1[i] - ar2[i];
}
return output;
}
public double[,] sub2dArrays(double[,] ar1, double[,] ar2)
{
double[,] output = new double[ar1.GetLength(0), ar1.GetLength(1)];
for (int i = 0; i < ar1.GetLength(0); i++)
{
for (int j = 0; j < ar1.GetLength(1); j++)
{
output[i,j] = ar1[i,j] - ar2[i,j];
}
}
return output;
}
public double[] resetArray(double[] ar)
{
double[] output = new double[ar.GetLength(0)];
return output;
}
public double[,] reset2dArray(double[,] ar)
{
double[,] output = new double[ar.GetLength(0), ar.GetLength(1)];
return output;
}
public int maximumOfArray(double[] a3)
{
int output = 0;
double max = 0;
for (int i = 0; i < a3.GetLength(0); i++)
{
if (a3[i] > max)
{
max = a3[i];
output = i;
}
}
return output;
}
public void randomWeights(double[,] w)
{
Random rnd = new Random();
for (int i = 0; i < w.GetLength(0); i++)
{
for (int j = 0; j < w.GetLength(1); j++)
{
w[i, j] = (rnd.NextDouble() - 0.5) * 2;
}
}
}
public void randomBias(double[] b)
{
Random rnd = new Random();
for (int i = 0; i < b.GetLength(0); i++)
{
b[i] = (rnd.NextDouble() - 0.5) * 2;
}
}
}
public static class MnistReader
{
private const string TrainImages = "mnist/train-images.idx3-ubyte";
private const string TrainLabels = "mnist/train-labels.idx1-ubyte";
private const string TestImages = "mnist/t10k-images.idx3-ubyte";
private const string TestLabels = "mnist/t10k-labels.idx1-ubyte";
public static IEnumerable<Image> ReadTrainingData()
{
foreach (var item in Read(TrainImages, TrainLabels))
{
yield return item;
}
}
public static IEnumerable<Image> ReadTestData()
{
foreach (var item in Read(TestImages, TestLabels))
{
yield return item;
}
}
private static IEnumerable<Image> Read(string imagesPath, string labelsPath)
{
BinaryReader labels = new BinaryReader(new FileStream(labelsPath, FileMode.Open));
BinaryReader images = new BinaryReader(new FileStream(imagesPath, FileMode.Open));
int magicNumber = images.ReadBigInt32();
int numberOfImages = images.ReadBigInt32();
int width = images.ReadBigInt32();
int height = images.ReadBigInt32();
int magicLabel = labels.ReadBigInt32();
int numberOfLabels = labels.ReadBigInt32();
for (int i = 0; i < numberOfImages; i++)
{
var bytes = images.ReadBytes(width * height);
var arr = new byte[height, width];
arr.ForEach((j, k) => arr[j, k] = bytes[j * height + k]);
yield return new Image()
{
Data = arr,
Label = labels.ReadByte()
};
}
labels.Close();
images.Close();
}
}
public class Image
{
public byte Label { get; set; }
public byte[,] Data { get; set; }
}
public static class Extensions
{
public static int ReadBigInt32(this BinaryReader br)
{
var bytes = br.ReadBytes(sizeof(Int32));
if (BitConverter.IsLittleEndian) Array.Reverse(bytes);
return BitConverter.ToInt32(bytes, 0);
}
public static void ForEach<T>(this T[,] source, Action<int, int> action)
{
for (int w = 0; w < source.GetLength(0); w++)
{
for (int h = 0; h < source.GetLength(1); h++)
{
action(w, h);
}
}
}
}
我添加了四个文件,即目录“程序名称”“程序名称”\bin\Debug\net6.0\mnist中的MNIST数据库文件
我尝试过的主要事情是打印不同的值并寻找这些模式不应该存在的模式。
我尝试打印输出神经元值并与它应该是什么进行比较,看起来它就像在猜测每次启动程序时“选择”的 2 个特定数字。
为了说明这一点,我提供了一些通过网络运行的图像示例,其中输出神经元以及它的猜测和实际数量
输出神经元值 0: 0,8439129508768974 1: 0,9122277563390606 2: 0,017403770398116702 3: 0,011134118269524533 4: 0,13462595373016392 5: 0,036897782500381525 6: 0,04750856319601651 7: 0,15035649081365893 8: 0,01481989084000439 9: 0,15321481202658915 猜:1 答案:6
输出神经元值 0: 0,838089716270575 1: 0,6086701853058746 2: 0,02905089940189534 3: 0,16903075237731124 4: 0,1404228501842471 5: 0,01814462956762896 6: 0,13817735383695717 7: 0,043509857883374116 8: 0,03455730693234839 9: 0,02082949989901883 猜:0 答案:5
输出神经元值 0: 0,9366729671922033 1: 0,7781000823699303 2: 0,022226345496076768 3: 0,007671436030838936 4: 0,0632049260642895 5: 0,004989765778696671 6: 0,03230139825555404 7: 0,15464187502773366 8: 0,018234953784533374 9: 0,007956539392870072 猜:0 答案:4
这是来自不同尝试的一些猜测,它只猜到了 3 个 猜:3 答案:7 猜:3 答案:8 猜:3 答案:9 猜:3 答案:0 猜:3 答案:1 猜:3 答案:2 猜:3 答案:3 猜:3 答案:4 猜:3 答案:5 猜:3 答案:6
答: 暂无答案
评论