diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java index dd72112ef..431e74c78 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java @@ -14,6 +14,7 @@ import ai.djl.Device; import ai.djl.Model; +import ai.djl.engine.EngineException; import ai.djl.inference.Predictor; import ai.djl.modality.Input; import ai.djl.modality.Output; @@ -63,6 +64,10 @@ public PyPredictor( @Override @SuppressWarnings("unchecked") public List batchPredict(List inputs) throws TranslateException { + if (process.isModelUnrecoverable()) { + throw new EngineException( + "Backend Python process is unrecoverable. Initiating worker termination"); + } if (!process.isReady()) { // TODO: wait for restart throw new TranslateException("Backend Python process is stopped."); diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java index 562ffdbed..d97efb4f5 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java @@ -51,6 +51,7 @@ class PyProcess { private CountDownLatch latch; private volatile boolean started; // NOPMD private volatile boolean modelLoaded; // NOPMD + private volatile boolean modelUnrecoverable; // NOPMD private AtomicInteger restartCount; private CompletableFuture restartFuture; private boolean passiveWorkersMode; @@ -144,6 +145,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) { if (!initialLoad) { logger.info("Restart python process ..."); restartFuture = CompletableFuture.runAsync(this::startPythonProcess); + } else { + modelUnrecoverable = true; } if (e instanceof EngineException) { throw (EngineException) e; @@ -260,6 +263,10 @@ boolean isReady() { return started && modelLoaded; } + boolean isModelUnrecoverable() { + return modelUnrecoverable; + } + private static String[] getHosts(int clusterSize) { String leaderAddr = Utils.getenv("DJL_LEADER_ADDR"); String workerAddrFormat = Utils.getenv("DJL_WORKER_ADDR_FORMAT");