diff --git a/R/metrics-binary.R b/R/metrics-binary.R index 4a0abed49..104a8bcbe 100644 --- a/R/metrics-binary.R +++ b/R/metrics-binary.R @@ -83,6 +83,6 @@ brier_score <- function(observed, predicted) { logs_binary <- function(observed, predicted) { assert_input_binary(observed, predicted) observed <- as.numeric(observed) - 1 - logs <- -log(ifelse(observed == 1, predicted, 1 - predicted)) + logs <- -log(1 - abs(observed - predicted)) return(logs) } diff --git a/tests/testthat/test-metrics-binary.R b/tests/testthat/test-metrics-binary.R index 79f56c535..fae7476c5 100644 --- a/tests/testthat/test-metrics-binary.R +++ b/tests/testthat/test-metrics-binary.R @@ -164,3 +164,24 @@ test_that("Binary metrics work within and outside of `score()`", { result$log_score ) }) + +test_that("`logs_binary()` works as expected", { + # check against the function Metrics::ll + obs2 <- as.numeric(as.character(observed)) + expect_equal( + logs_binary(observed, predicted), + Metrics::ll(obs2, predicted) + ) + + # check this works for a single observed value + expect_equal( + logs_binary(observed[1], predicted), + Metrics::ll(obs2[1], predicted) + ) + + # check this works for a single predicted value + expect_equal( + logs_binary(observed, predicted[1]), + Metrics::ll(obs2, predicted[1]) + ) +})