Skip to content

Commit

Permalink
Create a single shareable ClusterRequestHandler for all channels to p…
Browse files Browse the repository at this point in the history
…revent ssh-keygen race conditions. Add unit tests for the ClusterRequestHandler (#2547)
  • Loading branch information
si2d authored Nov 13, 2024
1 parent 62a63be commit 567ae86
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 9 deletions.
1 change: 1 addition & 0 deletions serving/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies {
testImplementation(libs.testng) {
exclude(group = "junit", module = "junit")
}
testImplementation(libs.mockito.core)
}

tasks {
Expand Down
4 changes: 3 additions & 1 deletion serving/src/main/java/ai/djl/serving/ServerInitializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class ServerInitializer extends ChannelInitializer<Channel> {
private Connector.ConnectorType connectorType;
private SslContext sslCtx;
private FolderScanPluginManager pluginManager;
private final ClusterRequestHandler clusterRequestHandler;

/**
* Creates a new {@code HttpRequestHandler} instance.
Expand All @@ -54,6 +55,7 @@ public ServerInitializer(
this.sslCtx = sslCtx;
this.connectorType = connectorType;
this.pluginManager = pluginManager;
this.clusterRequestHandler = ClusterRequestHandler.getInstance();
}

/** {@inheritDoc} */
Expand All @@ -76,7 +78,7 @@ public void initChannel(Channel ch) {
pipeline.addLast("inference", new InferenceRequestHandler());
break;
case CLUSTER:
pipeline.addLast("cluster", new ClusterRequestHandler());
pipeline.addLast("cluster", clusterRequestHandler);
break;
case BOTH:
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai.djl.serving.workflow.Workflow;
import ai.djl.util.Utils;

import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
Expand All @@ -34,17 +35,50 @@
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/** A class handling inbound HTTP requests for the cluster management API. */
public class ClusterRequestHandler extends HttpRequestHandler {
@Sharable
public final class ClusterRequestHandler extends HttpRequestHandler {

private static final ClusterRequestHandler INSTANCE = new ClusterRequestHandler();
private static final Logger logger = LoggerFactory.getLogger(ClusterRequestHandler.class);

private ClusterConfig config = ClusterConfig.getInstance();
private Function<String[], ProcessBuilder> processBuilderFunction;
private Path sshDirPath;

/** Create a real process builder for the default instance. */
private ClusterRequestHandler() {
processBuilderFunction = (cmds -> new ProcessBuilder(cmds).redirectErrorStream(true));
}

/** A method to allow mocking the process builder. */
protected void setProcessBuilderFunction(
Function<String[], ProcessBuilder> processBuilderFunction) {
this.processBuilderFunction = processBuilderFunction;
}

/**
* Override the path used for the id_rsa keys. Defaults to `user.home` property. This path can
* only be set once and cannot be set if the keygen has been run once.
*/
public synchronized void setSshGenDir(Path sshDirPath) {
if (this.sshDirPath != null && !this.sshDirPath.equals(sshDirPath)) {
logger.error("Attempt to set ssh path after running keygen");
throw new IllegalStateException("Attempt to set ssh path after running keygen");
}
this.sshDirPath = sshDirPath;
}

/** Get the singleton instance that can be reused across channels. */
public static ClusterRequestHandler getInstance() {
return INSTANCE;
}

/** {@inheritDoc} */
@Override
Expand All @@ -64,12 +98,16 @@ protected void handleRequest(
QueryStringDecoder decoder,
String[] segments)
throws ModelException {
Path sshDir = Paths.get(System.getProperty("user.home")).resolve(".ssh");
switch (segments[2]) {
case "sshpublickey":
Path publicKeyFile = sshDir.resolve("id_rsa.pub");
if (Files.notExists(publicKeyFile)) {
sshkeygen(sshDir.resolve("id_rsa").toString());
if (this.sshDirPath == null) {
setSshGenDir(Path.of(System.getProperty("user.home")).resolve(".ssh"));
}
Path publicKeyFile = sshDirPath.resolve("id_rsa.pub");
synchronized (this) {
if (Files.notExists(publicKeyFile)) {
sshkeygen(sshDirPath.resolve("id_rsa").toString());
}
}
NettyUtils.sendFile(ctx, publicKeyFile, false);
return;
Expand Down Expand Up @@ -104,12 +142,16 @@ protected void handleRequest(
private void sshkeygen(String rsaFile) {
try {
String[] commands = {"ssh-keygen", "-q", "-t", "rsa", "-N", "", "-f", rsaFile};
Process exec = new ProcessBuilder(commands).redirectErrorStream(true).start();
Process exec = processBuilderFunction.apply(commands).start();
String logOutput;
try (InputStream is = exec.getInputStream()) {
logOutput = Utils.toString(is);
}
int exitCode = exec.waitFor();
if (!exec.waitFor(60, TimeUnit.SECONDS)) {
exec.destroy();
throw new IllegalStateException("Generate ssh key timeout");
}
int exitCode = exec.exitValue();
if (0 != exitCode) {
logger.error("Generate ssh key failed: {}", logOutput);
config.setError(logOutput);
Expand Down
Loading

0 comments on commit 567ae86

Please sign in to comment.