Skip to content

Commit

Permalink
Don't reset the objective estimate on the last iteration (#417)
Browse files Browse the repository at this point in the history
  • Loading branch information
rcurtin authored Dec 31, 2024
1 parent b0e3348 commit 7791091
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 97 deletions.
3 changes: 3 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
### ensmallen ?.??.?: "???"
###### ????-??-??
* Fix `exactObjective` output for SGD-like optimizers when the number of
iterations is an even number of epochs
([#417](https://github.com/mlpack/ensmallen/pull/417)).

### ensmallen 2.22.1: "E-Bike Excitement"
###### 2024-12-02
Expand Down
397 changes: 316 additions & 81 deletions doc/optimizers.md

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions include/ensmallen_bits/bigbatch_sgd/bigbatch_sgd_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,13 @@ BigBatchSGD<UpdatePolicyType>::Optimize(
terminate |= Callback::BeginEpoch(*this, f, iterate, epoch,
overallObjective, callbacks...);

// Reset the counter variables.
lastObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
// Reset the counter variables if we will continue.
if (i != actualMaxIterations)
{
lastObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
}

if (shuffle) // Determine order of visitation.
f.Shuffle();
Expand Down
11 changes: 7 additions & 4 deletions include/ensmallen_bits/eve/eve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,13 @@ Eve::Optimize(SeparableFunctionType& function,
terminate |= Callback::BeginEpoch(*this, f, iterate, epoch,
overallObjective, callbacks...);

// Reset the counter variables.
lastOverallObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
// Reset the counter variables if we will continue.
if (i != actualMaxIterations)
{
lastOverallObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
}

if (shuffle) // Determine order of visitation.
f.Shuffle();
Expand Down
11 changes: 7 additions & 4 deletions include/ensmallen_bits/sgd/sgd_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,13 @@ SGD<UpdatePolicyType, DecayPolicyType>::Optimize(
terminate |= Callback::BeginEpoch(*this, f, iterate, epoch,
overallObjective, callbacks...);

// Reset the counter variables.
lastObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
// Reset the counter variables if we will continue.
if (i != actualMaxIterations)
{
lastObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
}

if (shuffle) // Determine order of visitation.
f.Shuffle();
Expand Down
11 changes: 7 additions & 4 deletions include/ensmallen_bits/spalera_sgd/spalera_sgd_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,13 @@ SPALeRASGD<DecayPolicyType>::Optimize(
return overallObjective;
}

// Reset the counter variables.
lastObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
// Reset the counter variables if we will continue.
if (i != actualMaxIterations)
{
lastObjective = overallObjective;
overallObjective = 0;
currentFunction = 0;
}

terminate |= Callback::BeginEpoch(*this, f, iterate, epoch,
overallObjective, callbacks...);
Expand Down

0 comments on commit 7791091

Please sign in to comment.