diff --git a/src/main/docker/Dockerfile.jvm b/src/main/docker/Dockerfile.jvm
index 3a59a97..0c064d7 100644
--- a/src/main/docker/Dockerfile.jvm
+++ b/src/main/docker/Dockerfile.jvm
@@ -1,4 +1,4 @@
-FROM registry.access.redhat.com/ubi8/openjdk-17:1.18
+FROM registry.access.redhat.com/ubi8/openjdk-17:latest
 
 ENV LANGUAGE='en_US:en'
 
@@ -15,4 +15,3 @@ ENV JAVA_OPTS_APPEND="-Dquarkus.http.host=0.0.0.0 -Djava.util.logging.manager=or
 ENV JAVA_APP_JAR="/deployments/quarkus-run.jar"
 
 ENTRYPOINT [ "/opt/jboss/container/java/run/run-java.sh" ]
-
diff --git a/src/main/java/org/kie/trustyai/ExplainerType.java b/src/main/java/org/kie/trustyai/ExplainerType.java
index 00a555d..4bf59b7 100644
--- a/src/main/java/org/kie/trustyai/ExplainerType.java
+++ b/src/main/java/org/kie/trustyai/ExplainerType.java
@@ -2,5 +2,6 @@
 
 public enum ExplainerType {
     LIME,
-    SHAP
+    SHAP,
+    ALL
 }
diff --git a/src/main/java/org/kie/trustyai/ExplainerV1Endpoint.java b/src/main/java/org/kie/trustyai/ExplainerV1Endpoint.java
index d2a73e3..6a7713f 100644
--- a/src/main/java/org/kie/trustyai/ExplainerV1Endpoint.java
+++ b/src/main/java/org/kie/trustyai/ExplainerV1Endpoint.java
@@ -1,15 +1,12 @@
 package org.kie.trustyai;
 
-import java.util.Collection;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import io.quarkus.logging.Log;
-import io.vertx.core.json.JsonObject;
 import jakarta.enterprise.inject.Default;
 import jakarta.inject.Inject;
 import jakarta.ws.rs.Consumes;
@@ -22,8 +19,6 @@
 import org.apache.commons.math3.linear.RealVector;
 import org.kie.trustyai.connectors.kserve.v1.KServeV1HTTPPredictionProvider;
 import org.kie.trustyai.connectors.kserve.v1.KServeV1RequestPayload;
-import org.kie.trustyai.explainability.local.LocalExplainer;
-import org.kie.trustyai.explainability.model.Feature;
 import org.kie.trustyai.explainability.model.Prediction;
 import org.kie.trustyai.explainability.model.PredictionInput;
 import org.kie.trustyai.explainability.model.PredictionOutput;
@@ -31,14 +26,10 @@
 import org.kie.trustyai.explainability.model.SaliencyResults;
 import org.kie.trustyai.explainability.model.SimplePrediction;
 import org.kie.trustyai.payloads.SaliencyExplanationResponse;
-import org.kie.trustyai.payloads.SaliencyExplanationResponse.FeatureSaliency;
 
 @Path("/v1/models/{modelName}:explain")
 public class ExplainerV1Endpoint {
 
-    @Inject
-    ObjectMapper objectMapper;
-
     @Inject
     @Default
     ConfigService configService;
@@ -67,8 +58,13 @@ public Response explain(@PathParam("modelName") String modelName, KServeV1Reques
         final PredictionOutput output = provider.predictAsync(input).get().get(0);
         final Prediction prediction = new SimplePrediction(input.get(0), output);
         final int dimensions = input.get(0).getFeatures().size();
+        final ExplainerType explainerType = configService.getExplainerType();
+
+        CompletableFuture<SaliencyResults> lime = null;
+        CompletableFuture<SaliencyResults> shap = null;
 
-        if (configService.getExplainerType() == ExplainerType.SHAP || configService.getExplainerType() == null) {
+
+        if (explainerType == ExplainerType.SHAP || explainerType == ExplainerType.ALL) {
             if (Objects.isNull(streamingGeneratorManager.getStreamingGenerator())) {
                 Log.info("Initializing SHAP's Streaming Background Generator with dimension " + dimensions);
                 streamingGeneratorManager.initialize(dimensions);
@@ -79,22 +75,29 @@ public Response explain(@PathParam("modelName") String modelName, KServeV1Reques
             }
             final RealVector vectorData = new ArrayRealVector(numericData);
             streamingGeneratorManager.getStreamingGenerator().update(vectorData);
+            shap = explainerFactory.getExplainer(ExplainerType.SHAP)
+                    .explainAsync(prediction, provider);
         }
 
+        if (explainerType == ExplainerType.LIME || explainerType == ExplainerType.ALL) {
+            Log.info("Sending explaining request to " + predictorURI);
+            lime = explainerFactory.getExplainer(ExplainerType.LIME)
+                    .explainAsync(prediction, provider);
+        }
+
+
         try {
             Log.info("Sending explaining request to " + predictorURI);
-            CompletableFuture<SaliencyResults> lime = explainerFactory.getExplainer(ExplainerType.LIME)
-                    .explainAsync(prediction, provider);
-            CompletableFuture<SaliencyResults> shap = explainerFactory.getExplainer(ExplainerType.SHAP)
-                    .explainAsync(prediction, provider);
-            CompletableFuture<SaliencyExplanationResponse> response = lime.thenCombine(shap,
-                    (limeExplanation, shapExplanation) -> SaliencyExplanationResponse
-                            .fromSaliencyResults(limeExplanation, shapExplanation));
-            try {
+            if (explainerType == ExplainerType.ALL) {
+                CompletableFuture<SaliencyExplanationResponse> response = lime.thenCombine(shap,
+                        SaliencyExplanationResponse::fromSaliencyResults);
                 return Response.ok(response.get(), MediaType.APPLICATION_JSON).build();
-            } catch (Exception e) {
-                return Response.serverError().entity("Error serializing SaliencyResults to JSON: " + e.getMessage())
-                        .build();
+            } else if (explainerType == ExplainerType.SHAP) {
+                return Response.ok(shap.get(), MediaType.APPLICATION_JSON).build();
+            } else if (explainerType == ExplainerType.LIME) {
+                return Response.ok(lime.get(), MediaType.APPLICATION_JSON).build();
+            } else {
+                return Response.serverError().entity("Unsupported explainer type").build();
             }
         } catch (IllegalArgumentException e) {
             return Response.serverError().entity("Error: " + e.getMessage()).build();
diff --git a/src/main/java/org/kie/trustyai/StreamingGeneratorManager.java b/src/main/java/org/kie/trustyai/StreamingGeneratorManager.java
index 02f429f..be94f9a 100644
--- a/src/main/java/org/kie/trustyai/StreamingGeneratorManager.java
+++ b/src/main/java/org/kie/trustyai/StreamingGeneratorManager.java
@@ -16,7 +16,8 @@ public class StreamingGeneratorManager {
     private StreamingGenerator streamingGenerator = null;
 
     public synchronized void initialize(int dimensions) {
-        if (streamingGenerator == null && configService.getExplainerType() == ExplainerType.SHAP) {
+        final ExplainerType explainerType = configService.getExplainerType();
+        if (streamingGenerator == null && (explainerType == ExplainerType.SHAP || explainerType == ExplainerType.ALL)) {
             final MultivariateOnlineEstimator estimator = new WelfordOnlineEstimator(dimensions);
             streamingGenerator = new StreamingGenerator(dimensions, configService.getQueueSize(), configService.getDiversitySize(), estimator);
         }
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index a9fad7d..6244f6e 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -2,7 +2,7 @@ quarkus.banner.path=banner.txt
 quarkus.http.host=0.0.0.0
 quarkus.http.http2=true
 # Default explainer type
-explainer.type=${EXPLAINER_TYPE:LIME}
+explainer.type=${EXPLAINER_TYPE:ALL}
 # Logging
 quarkus.log.level=${QUARKUS_LOG_LEVEL:INFO}
 quarkus.log.console.enable=true
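For reviewers, a minimal standalone sketch of the combination step the new `ALL` branch relies on: two independently running explanation futures are joined with `CompletableFuture.thenCombine` once both complete. This is an illustration of the pattern only, not the endpoint code; `String` stands in for the `SaliencyResults`/`SaliencyExplanationResponse` types and the class name `ThenCombineSketch` is made up for the example.

```java
import java.util.concurrent.CompletableFuture;

// Illustration only: String stands in for SaliencyResults / SaliencyExplanationResponse.
public class ThenCombineSketch {
    public static void main(String[] args) throws Exception {
        // Each explainer runs asynchronously, as explainAsync does in the endpoint.
        CompletableFuture<String> lime = CompletableFuture.supplyAsync(() -> "lime-saliencies");
        CompletableFuture<String> shap = CompletableFuture.supplyAsync(() -> "shap-saliencies");

        // Mirrors lime.thenCombine(shap, SaliencyExplanationResponse::fromSaliencyResults):
        // the combining function runs only after both futures have completed.
        CompletableFuture<String> combined = lime.thenCombine(shap, (l, s) -> l + " + " + s);

        System.out.println(combined.get()); // blocks until both results are available
    }
}
```

When `EXPLAINER_TYPE` is set to `LIME` or `SHAP` instead of the new `ALL` default, the endpoint skips the combination and returns the single future's result, as in the corresponding branches of the diff above.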