From 78bc964e0ea2cd4ca7ed39c645e07fe2213434d8 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 31 Jan 2024 14:20:37 +0000 Subject: [PATCH] refactor: fix min max adaptive loss --- src/adaptive_losses.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl index 6bfb192194..d6f84c7bbe 100644 --- a/src/adaptive_losses.jl +++ b/src/adaptive_losses.jl @@ -237,13 +237,15 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation, adaloss::MiniMaxAdaptiveLoss, pde_loss_functions, bc_loss_functions) pde_max_optimiser = adaloss.pde_max_optimiser + pde_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(pde_max_optimiser, adaloss.pde_loss_weights) bc_max_optimiser = adaloss.bc_max_optimiser + bc_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(bc_max_optimiser, adaloss.bc_loss_weights) iteration = pinnrep.iteration function run_minimax_adaptive_loss(θ, pde_losses, bc_losses) if iteration[1] % adaloss.reweight_every == 0 - OptimizationOptimisers.Optimisers.update(pde_max_optimiser, adaloss.pde_loss_weights, -pde_losses) - OptimizationOptimisers.Optimisers.update(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses) + OptimizationOptimisers.Optimisers.update!(pde_max_optimiser_setup, adaloss.pde_loss_weights, -pde_losses) + OptimizationOptimisers.Optimisers.update!(bc_max_optimiser_setup, adaloss.bc_loss_weights, -bc_losses) logvector(pinnrep.logger, adaloss.pde_loss_weights, "adaptive_loss/pde_loss_weights", iteration[1]) logvector(pinnrep.logger, adaloss.bc_loss_weights,