Skip to content

Commit

Permalink
save last training hyper param.
Browse files Browse the repository at this point in the history
  • Loading branch information
JulioJerez committed Sep 20, 2024
1 parent dcfe189 commit b097892
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -560,23 +560,35 @@ namespace ndAdvancedRobot
ndFloat32 rewardWeigh = 1.0f / 6.0f;
ndFloat32 azimuthReward = ScalarReward(positError2.m_z);

ndFloat32 reward = rewardWeigh * azimuthReward;
if (azimuthReward > 0.5f)
{
const ndVector rotationError(CalculateDeltaTargetRotation(currentEffectorMatrix));
const ndVector rotationError2 = rotationError * rotationError;

ndFloat32 omega_xReward = rewardWeigh * GaussianReward(rotationError2.m_x);
ndFloat32 omega_yReward = rewardWeigh * GaussianReward(rotationError2.m_y);
ndFloat32 omega_zReward = rewardWeigh * GaussianReward(rotationError2.m_z);
reward += (omega_xReward + omega_yReward + omega_zReward);
if ((omega_xReward > 1.0e-3f) || (omega_yReward > 1.0e-3f) || (omega_zReward > 1.0e-3f))
{
ndFloat32 posit_xReward = rewardWeigh * GaussianReward(positError2.m_x);
ndFloat32 posit_yReward = rewardWeigh * GaussianReward(positError2.m_y);
reward += (posit_xReward + posit_yReward);
}
}
//ndFloat32 reward = rewardWeigh * azimuthReward;
//if (azimuthReward > 0.5f)
//{
// const ndVector rotationError(CalculateDeltaTargetRotation(currentEffectorMatrix));
// const ndVector rotationError2 = rotationError * rotationError;
//
// ndFloat32 omega_xReward = rewardWeigh * GaussianReward(rotationError2.m_x);
// ndFloat32 omega_yReward = rewardWeigh * GaussianReward(rotationError2.m_y);
// ndFloat32 omega_zReward = rewardWeigh * GaussianReward(rotationError2.m_z);
// reward += (omega_xReward + omega_yReward + omega_zReward);
// if ((omega_xReward > 1.0e-3f) || (omega_yReward > 1.0e-3f) || (omega_zReward > 1.0e-3f))
// {
// ndFloat32 posit_xReward = rewardWeigh * GaussianReward(positError2.m_x);
// ndFloat32 posit_yReward = rewardWeigh * GaussianReward(positError2.m_y);
// reward += (posit_xReward + posit_yReward);
// }
//}

ndFloat32 posit_xReward = GaussianReward(positError2.m_x);
ndFloat32 posit_yReward = GaussianReward(positError2.m_y);

const ndVector rotationError(CalculateDeltaTargetRotation(currentEffectorMatrix));
const ndVector rotationError2 = rotationError * rotationError;

ndFloat32 omega_xReward = GaussianReward(rotationError2.m_x);
ndFloat32 omega_yReward = GaussianReward(rotationError2.m_y);
ndFloat32 omega_zReward = GaussianReward(rotationError2.m_z);

ndFloat32 reward = (azimuthReward + posit_xReward + posit_yReward +
omega_xReward + omega_yReward + omega_zReward) * rewardWeigh;
return reward;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ ndBrainAgentContinuePolicyGradient_TrainerMaster::HyperParameters::HyperParamete

//m_criticLearnRate = ndBrainFloat(0.0004f);
//m_policyLearnRate = ndBrainFloat(0.0002f);
//m_criticLearnRate = ndBrainFloat(0.0002f);
//m_policyLearnRate = ndBrainFloat(0.0001f);
m_criticLearnRate = ndBrainFloat(0.0005f);
m_policyLearnRate = ndBrainFloat(0.001f);
m_criticLearnRate = ndBrainFloat(0.0001f);
m_policyLearnRate = ndBrainFloat(0.0002f);
//m_criticLearnRate = ndBrainFloat(0.0005f);
//m_policyLearnRate = ndBrainFloat(0.001f);

m_regularizer = ndBrainFloat(1.0e-6f);
m_discountFactor = ndBrainFloat(0.99f);
Expand Down

0 comments on commit b097892

Please sign in to comment.