Skip to content

Commit

Permalink
Merge pull request #10 from ruivieira/RHOAIENG-11180
Browse files Browse the repository at this point in the history
fix: Add both LIME and SHAP as default explainers
  • Loading branch information
ruivieira authored Aug 12, 2024
2 parents 2249671 + 2dca42e commit 403bac9
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 26 deletions.
3 changes: 1 addition & 2 deletions src/main/docker/Dockerfile.jvm
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM registry.access.redhat.com/ubi8/openjdk-17:1.18
FROM registry.access.redhat.com/ubi8/openjdk-17:latest

ENV LANGUAGE='en_US:en'

Expand All @@ -15,4 +15,3 @@ ENV JAVA_OPTS_APPEND="-Dquarkus.http.host=0.0.0.0 -Djava.util.logging.manager=or
ENV JAVA_APP_JAR="/deployments/quarkus-run.jar"

ENTRYPOINT [ "/opt/jboss/container/java/run/run-java.sh" ]

3 changes: 2 additions & 1 deletion src/main/java/org/kie/trustyai/ExplainerType.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

public enum ExplainerType {
LIME,
SHAP
SHAP,
ALL
}
45 changes: 24 additions & 21 deletions src/main/java/org/kie/trustyai/ExplainerV1Endpoint.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package org.kie.trustyai;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.quarkus.logging.Log;
import io.vertx.core.json.JsonObject;
import jakarta.enterprise.inject.Default;
import jakarta.inject.Inject;
import jakarta.ws.rs.Consumes;
Expand All @@ -22,23 +19,17 @@
import org.apache.commons.math3.linear.RealVector;
import org.kie.trustyai.connectors.kserve.v1.KServeV1HTTPPredictionProvider;
import org.kie.trustyai.connectors.kserve.v1.KServeV1RequestPayload;
import org.kie.trustyai.explainability.local.LocalExplainer;
import org.kie.trustyai.explainability.model.Feature;
import org.kie.trustyai.explainability.model.Prediction;
import org.kie.trustyai.explainability.model.PredictionInput;
import org.kie.trustyai.explainability.model.PredictionOutput;
import org.kie.trustyai.explainability.model.PredictionProvider;
import org.kie.trustyai.explainability.model.SaliencyResults;
import org.kie.trustyai.explainability.model.SimplePrediction;
import org.kie.trustyai.payloads.SaliencyExplanationResponse;
import org.kie.trustyai.payloads.SaliencyExplanationResponse.FeatureSaliency;

@Path("/v1/models/{modelName}:explain")
public class ExplainerV1Endpoint {

@Inject
ObjectMapper objectMapper;

@Inject
@Default
ConfigService configService;
Expand Down Expand Up @@ -67,8 +58,13 @@ public Response explain(@PathParam("modelName") String modelName, KServeV1Reques
final PredictionOutput output = provider.predictAsync(input).get().get(0);
final Prediction prediction = new SimplePrediction(input.get(0), output);
final int dimensions = input.get(0).getFeatures().size();
final ExplainerType explainerType = configService.getExplainerType();

CompletableFuture<SaliencyResults> lime = null;
CompletableFuture<SaliencyResults> shap = null;

if (configService.getExplainerType() == ExplainerType.SHAP || configService.getExplainerType() == null) {

if (explainerType == ExplainerType.SHAP || explainerType == ExplainerType.ALL) {
if (Objects.isNull(streamingGeneratorManager.getStreamingGenerator())) {
Log.info("Initializing SHAP's Streaming Background Generator with dimension " + dimensions);
streamingGeneratorManager.initialize(dimensions);
Expand All @@ -79,22 +75,29 @@ public Response explain(@PathParam("modelName") String modelName, KServeV1Reques
}
final RealVector vectorData = new ArrayRealVector(numericData);
streamingGeneratorManager.getStreamingGenerator().update(vectorData);
shap = explainerFactory.getExplainer(ExplainerType.SHAP)
.explainAsync(prediction, provider);
}

if (explainerType == ExplainerType.LIME || explainerType == ExplainerType.ALL) {
Log.info("Sending explaining request to " + predictorURI);
lime = explainerFactory.getExplainer(ExplainerType.LIME)
.explainAsync(prediction, provider);
}


try {
Log.info("Sending explaining request to " + predictorURI);
CompletableFuture<SaliencyResults> lime = explainerFactory.getExplainer(ExplainerType.LIME)
.explainAsync(prediction, provider);
CompletableFuture<SaliencyResults> shap = explainerFactory.getExplainer(ExplainerType.SHAP)
.explainAsync(prediction, provider);
CompletableFuture<SaliencyExplanationResponse> response = lime.thenCombine(shap,
(limeExplanation, shapExplanation) -> SaliencyExplanationResponse
.fromSaliencyResults(limeExplanation, shapExplanation));
try {
if (explainerType == ExplainerType.ALL) {
CompletableFuture<SaliencyExplanationResponse> response = lime.thenCombine(shap,
SaliencyExplanationResponse::fromSaliencyResults);
return Response.ok(response.get(), MediaType.APPLICATION_JSON).build();
} catch (Exception e) {
return Response.serverError().entity("Error serializing SaliencyResults to JSON: " + e.getMessage())
.build();
} else if (explainerType == ExplainerType.SHAP) {
return Response.ok(shap.get(), MediaType.APPLICATION_JSON).build();
} else if (explainerType == ExplainerType.LIME) {
return Response.ok(lime.get(), MediaType.APPLICATION_JSON).build();
} else {
return Response.serverError().entity("Unsupported explainer type").build();
}
} catch (IllegalArgumentException e) {
return Response.serverError().entity("Error: " + e.getMessage()).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public class StreamingGeneratorManager {
private StreamingGenerator streamingGenerator = null;

public synchronized void initialize(int dimensions) {
if (streamingGenerator == null && configService.getExplainerType() == ExplainerType.SHAP) {
final ExplainerType explainerType = configService.getExplainerType();
if (streamingGenerator == null && (explainerType == ExplainerType.SHAP || explainerType == ExplainerType.ALL)) {
final MultivariateOnlineEstimator<MultivariateGaussianParameters> estimator = new WelfordOnlineEstimator(dimensions);
streamingGenerator = new StreamingGenerator(dimensions, configService.getQueueSize(), configService.getDiversitySize(), estimator);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ quarkus.banner.path=banner.txt
quarkus.http.host=0.0.0.0
quarkus.http.http2=true
# Default explainer type
explainer.type=${EXPLAINER_TYPE:LIME}
explainer.type=${EXPLAINER_TYPE:ALL}
# Logging
quarkus.log.level=${QUARKUS_LOG_LEVEL:INFO}
quarkus.log.console.enable=true
Expand Down

0 comments on commit 403bac9

Please sign in to comment.