diff --git a/test/reactant/loss_tests.jl b/test/reactant/loss_tests.jl
index ad5d3da74..cf2bce1fb 100644
--- a/test/reactant/loss_tests.jl
+++ b/test/reactant/loss_tests.jl
@@ -72,10 +72,12 @@
         @test celoss_ls(ŷ, y) ≈ @jit(celoss_ls(ŷ_ra, y_ra))
 
         celoss_lp = CrossEntropyLoss(; logits=Val(true))
-        @test celoss_lp(log.(ŷ), y) ≈ @jit(celoss_lp(log.(ŷ_ra), y_ra))
+        logit_celoss_lp = (ŷ, y) -> celoss_lp(log.(ŷ), y)
+        @test logit_celoss_lp(ŷ, y) ≈ @jit(logit_celoss_lp(ŷ_ra, y_ra))
 
         celoss_lp_ls = CrossEntropyLoss(; logits=Val(true), label_smoothing=0.1)
-        @test celoss_lp_ls(log.(ŷ), y) ≈ @jit(celoss_lp_ls(log.(ŷ_ra), y_ra))
+        logit_celoss_lp_ls = (ŷ, y) -> celoss_lp_ls(log.(ŷ), y)
+        @test logit_celoss_lp_ls(ŷ, y) ≈ @jit(logit_celoss_lp_ls(ŷ_ra, y_ra))
     end
 
     @testset "Binary CrossEntropyLoss" begin
@@ -86,11 +88,13 @@
         @test bceloss_ls(ŷ, y) ≈ @jit(bceloss_ls(ŷ_ra, y_ra))
 
         bceloss_lp = BinaryCrossEntropyLoss(; logits=Val(true))
-        @test bceloss_lp(log.(ŷ), y) ≈ @jit(bceloss_lp(log.(ŷ_ra), y_ra))
+        logit_bceloss_lp = (ŷ, y) -> bceloss_lp(log.(ŷ), y)
+        @test logit_bceloss_lp(ŷ, y) ≈ @jit(logit_bceloss_lp(ŷ_ra, y_ra))
 
         bceloss_lp_ls = BinaryCrossEntropyLoss(; logits=Val(true), label_smoothing=0.1)
-        @test bceloss_lp_ls(log.(ŷ), y) ≈ @jit(bceloss_lp_ls(log.(ŷ_ra), y_ra))
+        logit_bceloss_lp_ls = (ŷ, y) -> bceloss_lp_ls(log.(ŷ), y)
+        @test logit_bceloss_lp_ls(ŷ, y) ≈ @jit(logit_bceloss_lp_ls(ŷ_ra, y_ra))
     end
 
     @testset "BinaryFocalLoss" begin
diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl
index 73926a7ee..e37500151 100644
--- a/test/reactant/training_tests.jl
+++ b/test/reactant/training_tests.jl
@@ -24,18 +24,24 @@
         ps, st = Lux.setup(StableRNG(1234), model) |> xdev
 
         x_ra = randn(Float32, 2, 32) |> xdev
+        y_ra = rand(Float32, 2, 32) |> xdev
 
-        inference_fn = @compile model(x_ra, ps, Lux.testmode(st))
+        inference_loss_fn = (xᵢ, yᵢ, model, ps, st) -> begin
+            ŷᵢ, _ = model(xᵢ, ps, Lux.testmode(st))
+            return MSELoss()(ŷᵢ, yᵢ)
+        end
+        inference_loss_fn_compiled = @compile inference_loss_fn(
+            x_ra, y_ra, model, ps, st
+        )
 
         x = [rand(Float32, 2, 32) for _ in 1:32]
         y = [xᵢ .^ 2 for xᵢ in x]
         dataloader = DeviceIterator(xdev, zip(x, y))
 
-        total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
-            ŷᵢ, _ = inference_fn(xᵢ, ps, Lux.testmode(st))
-            return MSELoss()(ŷᵢ, yᵢ)
-        end
+        total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+            inference_loss_fn_compiled(xᵢ, yᵢ, model, ps, st)
+        end
 
         train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
 
@@ -51,10 +57,9 @@
             end
         end
 
-        total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
-            ŷᵢ, _ = inference_fn(xᵢ, train_state.parameters, Lux.testmode(st))
-            return MSELoss()(ŷᵢ, yᵢ)
-        end
+        total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+            inference_loss_fn_compiled(xᵢ, yᵢ, model, train_state.parameters, st)
+        end
 
         @test total_final_loss < 100 * total_initial_loss
     end