Skip to content

Commit

Permalink
more cnn test (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
JulioJerez committed Nov 18, 2023
1 parent dd61bf0 commit 6395a6b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,10 @@ static void Cifar10TrainingSet()
ndInt32 batches = trainingLabels->GetCount() / m_bashBufferSize;

// so far best training result on the cifar-10 data set
optimizer.SetRegularizer(ndBrainFloat(0.0e-5f)); // test data score
//optimizer.SetRegularizer(ndBrainFloat(1.0e-5f)); // test data score
//optimizer.SetRegularizer(ndBrainFloat(2.0e-5f)); // test data score
//optimizer.SetRegularizer(ndBrainFloat(3.0e-5f)); // test data score
optimizer.SetRegularizer(ndBrainFloat(0.0e-5f)); // test data score (%)
//optimizer.SetRegularizer(ndBrainFloat(1.0e-5f)); // test data score (%)
//optimizer.SetRegularizer(ndBrainFloat(2.0e-5f)); // test data score (%)
//optimizer.SetRegularizer(ndBrainFloat(3.0e-5f)); // test data score (%)

//batches = 1;
ndArray<ndUnsigned32> shuffleBuffer;
Expand All @@ -269,7 +269,7 @@ static void Cifar10TrainingSet()
priorityList.PushBack(ndRandInt() % trainingLabels->GetCount());
}

for (ndInt32 epoch = 0; epoch < 100; ++epoch)
for (ndInt32 epoch = 0; epoch < 500; ++epoch)
{
ndInt32 start = 0;
ndMemSet(failCount, ndUnsigned32(0), D_MAX_THREADS_COUNT);
Expand All @@ -293,9 +293,10 @@ static void Cifar10TrainingSet()

if (fails <= minTrainingFail)
{
const ndInt32 minTestCheck = 500;
ndInt32 actualTraining = fails;
bool traningTest = fails < minTrainingFail;
minTrainingFail = ndMax(fails, ndInt32(200));
minTrainingFail = ndMax(fails, minTestCheck);

auto CrossValidateTest = ndMakeObject::ndFunction([this, testDigits, testLabels, &failCount](ndInt32 threadIndex, ndInt32 threadCount)
{
Expand Down Expand Up @@ -337,7 +338,7 @@ static void Cifar10TrainingSet()
fails += failCount[j];
}

if (traningTest && (minTrainingFail > 200))
if (traningTest && (minTrainingFail > minTestCheck))
{
minTestFail = fails;
bestBrain.CopyFrom(m_brain);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ static void MnistTrainingSet()
//batches = 1;

// so far best training result on the mnist data set
//optimizer.SetRegularizer(ndBrainFloat(0.0e-5f)); // test data score fully(98.070%) conv(99.170%)
//optimizer.SetRegularizer(ndBrainFloat(0.0e-5f)); // test data score fully(98.070%) conv(99.380%)
optimizer.SetRegularizer(ndBrainFloat(1.0e-5f)); // test data score fully(98.200%) conv(99.240%)
//optimizer.SetRegularizer(ndBrainFloat(2.0e-5f)); // test data score fully(97.980%) conv(99.210%)
//optimizer.SetRegularizer(ndBrainFloat(3.0e-5f)); // test data score fully(%) conv(%)
Expand All @@ -274,7 +274,7 @@ static void MnistTrainingSet()
}

//for (ndInt32 epoch = 0; epoch < 1000; ++epoch)
for (ndInt32 epoch = 0; epoch < 100; ++epoch)
for (ndInt32 epoch = 0; epoch < 500; ++epoch)
{
ndInt32 start = 0;
ndMemSet(failCount, ndUnsigned32(0), D_MAX_THREADS_COUNT);
Expand Down Expand Up @@ -304,9 +304,10 @@ static void MnistTrainingSet()

if (fails <= minTrainingFail)
{
const ndInt32 minTestCheck = 500;
ndInt32 actualTraining = fails;
bool traningTest = fails < minTrainingFail;
minTrainingFail = ndMax(fails, ndInt32(500));
minTrainingFail = ndMax(fails, minTestCheck);

auto CrossValidateTest = ndMakeObject::ndFunction([this, testDigits, testLabels, &failCount](ndInt32 threadIndex, ndInt32 threadCount)
{
Expand Down Expand Up @@ -348,7 +349,7 @@ static void MnistTrainingSet()
fails += failCount[j];
}

if (traningTest && (minTrainingFail > 500))
if (traningTest && (minTrainingFail > minTestCheck))
{
minTestFail = fails;
bestBrain.CopyFrom(m_brain);
Expand Down

0 comments on commit 6395a6b

Please sign in to comment.