Skip to content

Commit

Permalink
tweak cifar
Browse files Browse the repository at this point in the history
  • Loading branch information
JulioJerez committed Nov 19, 2023
1 parent 71ceddb commit 229e8c0
Showing 1 changed file with 19 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@

static void LoadTrainingData(ndSharedPtr<ndBrainMatrix>& trainingImages, ndSharedPtr<ndBrainMatrix>& trainingLabels)
{
ndInt32 batches = 5;
const ndInt32 batches = 5;
const ndInt32 pixelSize = 32 * 32;
char filename[1024];
char outPathName[1024];
ndUnsigned8 data[pixelSize * 3];

trainingLabels = new ndBrainMatrix(ndInt32(batches * 10000), ndInt32(10));
trainingImages = new ndBrainMatrix(ndInt32(batches * 10000), ndInt32(32 * 32 * 3));
trainingImages = new ndBrainMatrix(ndInt32(batches * 10000), ndInt32(pixelSize * 3));

ndBrainMatrix& labelMatrix = *(*trainingLabels);
ndBrainMatrix& imageMatrix = *(*trainingImages);

char filename[1024];
char outPathName[1024];
ndUnsigned8 data[32 * 32 * 3];

ndInt32 base = 0;
labelMatrix.Set(ndBrainFloat(0.0f));
for (ndInt32 i = 0; i < batches; ++i)
Expand All @@ -42,15 +43,15 @@ static void LoadTrainingData(ndSharedPtr<ndBrainMatrix>& trainingImages, ndShare

labelMatrix[j + base][label] = ndBrainFloat(1.0f);
ndBrainVector& image = imageMatrix[j + base];
ret = fread(data, 1, 32 * 32 * 3, fp);
for (ndInt32 k = 0; k < 32 * 32 * 3; ++k)
ret = fread(data, 1, pixelSize * 3, fp);
for (ndInt32 k = 0; k < pixelSize * 3; ++k)
{
image[k] = ndBrainFloat(data[k]) / ndBrainFloat(255.0f);
}

for (ndInt32 k = 0; k < 3; ++k)
{
ndBrainMemVector imageChannel (&image[k * 32 * 32], 32 * 32);
ndBrainMemVector imageChannel (&image[k * pixelSize], pixelSize);
imageChannel.GaussianNormalize();
}
}
Expand All @@ -62,15 +63,16 @@ static void LoadTrainingData(ndSharedPtr<ndBrainMatrix>& trainingImages, ndShare

static void LoadTestData(ndSharedPtr<ndBrainMatrix>& images, ndSharedPtr<ndBrainMatrix>& labels)
{
const ndInt32 pixelSize = 32 * 32;
char outPathName[1024];
ndUnsigned8 data[pixelSize * 3];

labels = new ndBrainMatrix(ndInt32(10000), ndInt32(10));
images = new ndBrainMatrix(ndInt32(10000), ndInt32(32 * 32 * 3));
images = new ndBrainMatrix(ndInt32(10000), ndInt32(pixelSize * 3));

ndBrainMatrix& labelMatrix = *(*labels);
ndBrainMatrix& imageMatrix = *(*images);

char outPathName[1024];
ndUnsigned8 data[32 * 32 * 3];

labelMatrix.Set(ndBrainFloat(0.0f));
ndGetWorkingFileName("cifar-10-batches-bin/test_batch.bin", outPathName);
FILE* const fp = fopen(outPathName, "rb");
Expand All @@ -84,15 +86,15 @@ static void LoadTestData(ndSharedPtr<ndBrainMatrix>& images, ndSharedPtr<ndBrain

labelMatrix[j][label] = ndBrainFloat(1.0f);
ndBrainVector& image = imageMatrix[ndInt32(j)];
ret = fread(data, 1, 32 * 32 * 3, fp);
for (ndInt32 k = 0; k < 32 * 32 * 3; ++k)
ret = fread(data, 1, pixelSize * 3, fp);
for (ndInt32 k = 0; k < pixelSize * 3; ++k)
{
image[k] = ndBrainFloat(data[k]) / ndBrainFloat(255.0f);
}

for (ndInt32 k = 0; k < 3; ++k)
{
ndBrainMemVector imageChannel(&image[k * 32 * 32], 32 * 32);
ndBrainMemVector imageChannel(&image[k * pixelSize], pixelSize);
imageChannel.GaussianNormalize();
}
}
Expand Down Expand Up @@ -267,7 +269,7 @@ static void Cifar10TrainingSet()
priorityList.PushBack(ndRandInt() % trainingLabels->GetCount());
}

for (ndInt32 epoch = 0; epoch < 500; ++epoch)
for (ndInt32 epoch = 0; epoch < 1000; ++epoch)
{
ndInt32 start = 0;
ndMemSet(failCount, ndUnsigned32(0), D_MAX_THREADS_COUNT);
Expand Down

0 comments on commit 229e8c0

Please sign in to comment.