Skip to content

Commit

Permalink
forward port of patch in 0.28.0 that terminates a python worker when … (
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Jan 21, 2025
1 parent ab53670 commit ae93814
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
Expand Down Expand Up @@ -63,6 +64,10 @@ public PyPredictor(
@Override
@SuppressWarnings("unchecked")
public List<O> batchPredict(List<I> inputs) throws TranslateException {
if (process.isModelUnrecoverable()) {
throw new EngineException(
"Backend Python process is unrecoverable. Initiating worker termination");
}
if (!process.isReady()) {
// TODO: wait for restart
throw new TranslateException("Backend Python process is stopped.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class PyProcess {
private CountDownLatch latch;
private volatile boolean started; // NOPMD
private volatile boolean modelLoaded; // NOPMD
private volatile boolean modelUnrecoverable; // NOPMD
private AtomicInteger restartCount;
private CompletableFuture<Void> restartFuture;
private boolean passiveWorkersMode;
Expand Down Expand Up @@ -144,6 +145,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {
if (!initialLoad) {
logger.info("Restart python process ...");
restartFuture = CompletableFuture.runAsync(this::startPythonProcess);
} else {
modelUnrecoverable = true;
}
if (e instanceof EngineException) {
throw (EngineException) e;
Expand Down Expand Up @@ -260,6 +263,10 @@ boolean isReady() {
return started && modelLoaded;
}

/**
 * Reports whether this process has been flagged as unrecoverable.
 *
 * <p>Simply returns the current value of the {@code volatile} {@code modelUnrecoverable}
 * field; the method performs no other work and has no side effects.
 *
 * @return {@code true} if the {@code modelUnrecoverable} flag is set, {@code false} otherwise
 */
boolean isModelUnrecoverable() {
return modelUnrecoverable;
}

private static String[] getHosts(int clusterSize) {
String leaderAddr = Utils.getenv("DJL_LEADER_ADDR");
String workerAddrFormat = Utils.getenv("DJL_WORKER_ADDR_FORMAT");
Expand Down

0 comments on commit ae93814

Please sign in to comment.