diff --git a/base-api/src/main/java/com/twosigma/beakerx/evaluator/BaseEvaluator.java b/base-api/src/main/java/com/twosigma/beakerx/evaluator/BaseEvaluator.java index 243bc02..e8a9b00 100644 --- a/base-api/src/main/java/com/twosigma/beakerx/evaluator/BaseEvaluator.java +++ b/base-api/src/main/java/com/twosigma/beakerx/evaluator/BaseEvaluator.java @@ -36,7 +36,6 @@ import org.apache.commons.io.FileUtils; import java.io.File; -import java.net.URISyntaxException; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; @@ -52,7 +51,9 @@ public abstract class BaseEvaluator implements Evaluator { - public static String INTERUPTED_MSG = "interrupted"; + public static String INTERUPTED_MSG = "KeyboardInterrupt"; + + private boolean interrupting = false; protected final String shellId; protected final String sessionId; private final ClasspathScanner classpathScanner; @@ -109,6 +110,29 @@ public void endEvaluation() { classLoaderSwitcher.end(); } + @Override + public void interruptKernel() { + interrupting = true; + killAllThreads(); + } + + @Override + public void interruptKernelDone() { + interrupting = false; + } + + public boolean isInterrupting() { + return interrupting; + } + + @Override + public TryResult processResult(TryResult result) { + if (interrupting) { + return TryResult.createError(INTERUPTED_MSG); + } + return result; + } + CompletableFuture background; protected TryResult evaluate(EvaluationObject seo, Callable callable) { diff --git a/base-api/src/main/java/com/twosigma/beakerx/evaluator/Evaluator.java b/base-api/src/main/java/com/twosigma/beakerx/evaluator/Evaluator.java index fdfbe7e..bfa9626 100644 --- a/base-api/src/main/java/com/twosigma/beakerx/evaluator/Evaluator.java +++ b/base-api/src/main/java/com/twosigma/beakerx/evaluator/Evaluator.java @@ -85,4 +85,12 @@ public interface Evaluator { void startEvaluation(); void endEvaluation(); + + void interruptKernel(); + + void interruptKernelDone(); + + boolean isInterrupting(); + + TryResult processResult(TryResult result); } diff --git a/base-api/src/main/java/com/twosigma/beakerx/kernel/KernelFunctionality.java b/base-api/src/main/java/com/twosigma/beakerx/kernel/KernelFunctionality.java index 384135f..a635d70 100644 --- a/base-api/src/main/java/com/twosigma/beakerx/kernel/KernelFunctionality.java +++ b/base-api/src/main/java/com/twosigma/beakerx/kernel/KernelFunctionality.java @@ -59,7 +59,11 @@ public interface KernelFunctionality { void cancelExecution(GroupName groupName); - void killAllThreads(); + void interruptKernel(); + + void interruptKernelDone(); + + boolean isInterrupting(); Handler getHandler(JupyterMessages type); @@ -120,4 +124,5 @@ public interface KernelFunctionality { void startEvaluation(); void endEvaluation(); + } diff --git a/base-api/src/main/java/com/twosigma/beakerx/kernel/msg/JupyterMessages.java b/base-api/src/main/java/com/twosigma/beakerx/kernel/msg/JupyterMessages.java index 301ccdd..55311b3 100644 --- a/base-api/src/main/java/com/twosigma/beakerx/kernel/msg/JupyterMessages.java +++ b/base-api/src/main/java/com/twosigma/beakerx/kernel/msg/JupyterMessages.java @@ -51,7 +51,9 @@ public enum JupyterMessages { ERROR, IS_COMPLETE_REQUEST, INPUT_REQUEST, - IS_COMPLETE_REPLY; + IS_COMPLETE_REPLY, + INTERRUPT_REQUEST, + INTERRUPT_REPLY; public String getName() { return this.name().toLowerCase(); @@ -70,4 +72,4 @@ public static JupyterMessages getType(final String input){ return ret; } -} \ No newline at end of file +} diff --git a/base-test/src/main/java/com/twosigma/beakerx/KernelTest.java b/base-test/src/main/java/com/twosigma/beakerx/KernelTest.java index 2151b30..369dc4f 100644 --- a/base-test/src/main/java/com/twosigma/beakerx/KernelTest.java +++ b/base-test/src/main/java/com/twosigma/beakerx/KernelTest.java @@ -296,10 +296,20 @@ public void cancelExecution(GroupName groupName) { } @Override - public void killAllThreads() { + public void interruptKernel() { } + @Override + public void interruptKernelDone() { + + } + + @Override + public boolean isInterrupting() { + return false; + } + @Override public Handler getHandler(JupyterMessages type) { return null; @@ -525,13 +535,5 @@ public List getObjectList() { } } -// public static EvaluationObject createSeo(String code) { -// return new SimpleEvaluationObject(code, new SeoConfigurationFactoryMock()); -// } -// -// public static EvaluationObject createSeo(String code, Message message) { -// return new SimpleEvaluationObject(code, new SeoConfigurationFactoryMock(message)); -// } - } diff --git a/base/src/main/java/com/twosigma/beakerx/KernelInfoHandler.java b/base/src/main/java/com/twosigma/beakerx/KernelInfoHandler.java index 7607280..69afc97 100644 --- a/base/src/main/java/com/twosigma/beakerx/KernelInfoHandler.java +++ b/base/src/main/java/com/twosigma/beakerx/KernelInfoHandler.java @@ -15,29 +15,26 @@ */ package com.twosigma.beakerx; -import static com.twosigma.beakerx.kernel.msg.JupyterMessages.KERNEL_INFO_REPLY; -import static com.twosigma.beakerx.handler.KernelHandlerWrapper.wrapBusyIdle; -import static java.util.Arrays.asList; - -import com.twosigma.beakerx.kernel.KernelFunctionality; import com.twosigma.beakerx.handler.KernelHandler; -import com.twosigma.beakerx.kernel.KernelManager; +import com.twosigma.beakerx.kernel.KernelFunctionality; import com.twosigma.beakerx.message.Header; import com.twosigma.beakerx.message.Message; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import static com.twosigma.beakerx.handler.KernelHandlerWrapper.wrapBusyIdle; +import static com.twosigma.beakerx.kernel.msg.JupyterMessages.KERNEL_INFO_REPLY; +import static java.util.Arrays.asList; public abstract class KernelInfoHandler extends KernelHandler { private final static Logger logger = LoggerFactory.getLogger(KernelInfoHandler.class); public static final String PROTOCOL_VERSION = "protocol_version"; public static final String PROTOCOL_VERSION_NUMBER = "5.3"; - public static final String INTERRUPT_KERNEL = "interrupt_kernel"; public KernelInfoHandler(KernelFunctionality kernel) { super(kernel); @@ -49,12 +46,12 @@ public void handle(Message message) { } private void handleMsg(Message message) { - logger.debug("Processing kernel info request"); - Message reply = new Message(new Header(KERNEL_INFO_REPLY, message.getHeader().getSession())); - reply.setContent(content()); - reply.setParentHeader(message.getHeader()); - reply.setIdentities(message.getIdentities()); - send(reply); + logger.debug("Processing kernel info request"); + Message reply = new Message(new Header(KERNEL_INFO_REPLY, message.getHeader().getSession())); + reply.setContent(content()); + reply.setParentHeader(message.getHeader()); + reply.setIdentities(message.getIdentities()); + send(reply); } private HashMap languageInfo() { @@ -70,7 +67,6 @@ private HashMap content() { map.put("help_links", getHelpLinks()); map.put("beakerx", true); map.put("status", "ok"); - map.put("url_to_interrupt", KernelManager.get().getBeakerXServer().getURL() + INTERRUPT_KERNEL); return doContent(map); } diff --git a/base/src/main/java/com/twosigma/beakerx/kernel/Kernel.java b/base/src/main/java/com/twosigma/beakerx/kernel/Kernel.java index 6daf3d1..2cf241c 100644 --- a/base/src/main/java/com/twosigma/beakerx/kernel/Kernel.java +++ b/base/src/main/java/com/twosigma/beakerx/kernel/Kernel.java @@ -20,6 +20,7 @@ import com.twosigma.beakerx.DisplayerDataMapper; import com.twosigma.beakerx.TryResult; import com.twosigma.beakerx.autocomplete.AutocompleteResult; +import com.twosigma.beakerx.evaluator.BaseEvaluator; import com.twosigma.beakerx.evaluator.Evaluator; import com.twosigma.beakerx.evaluator.Hook; import com.twosigma.beakerx.handler.Handler; @@ -28,6 +29,7 @@ import com.twosigma.beakerx.jvm.object.EvaluationObject; import com.twosigma.beakerx.kernel.comm.Comm; import com.twosigma.beakerx.kernel.handler.CommOpenHandler; +import com.twosigma.beakerx.kernel.handler.ExecuteRequestHandler; import com.twosigma.beakerx.kernel.magic.command.MagicCommandConfiguration; import com.twosigma.beakerx.kernel.magic.command.MagicCommandType; import com.twosigma.beakerx.kernel.msg.JupyterMessages; @@ -159,8 +161,20 @@ public void updateEvaluatorParameters(final EvaluatorParameters kernelParameters } @Override - public void killAllThreads() { - evaluator.killAllThreads(); + public void interruptKernel() { + evaluator.interruptKernel(); + ExecuteRequestHandler executeRequestHandler = handlers.getExecuteRequestHandler(); + executeRequestHandler.interruptKernel(); + } + + @Override + public void interruptKernelDone() { + evaluator.interruptKernelDone(); + } + + @Override + public boolean isInterrupting() { + return evaluator.isInterrupting(); } @Override diff --git a/base/src/main/java/com/twosigma/beakerx/kernel/KernelHandlers.java b/base/src/main/java/com/twosigma/beakerx/kernel/KernelHandlers.java index f6bea64..342a19c 100644 --- a/base/src/main/java/com/twosigma/beakerx/kernel/KernelHandlers.java +++ b/base/src/main/java/com/twosigma/beakerx/kernel/KernelHandlers.java @@ -21,6 +21,7 @@ import com.twosigma.beakerx.kernel.handler.CommMsgHandler; import com.twosigma.beakerx.kernel.handler.CommOpenHandler; import com.twosigma.beakerx.kernel.handler.ExecuteRequestHandler; +import com.twosigma.beakerx.kernel.handler.InterruptMsgHandler; import com.twosigma.beakerx.kernel.msg.JupyterMessages; import com.twosigma.beakerx.handler.IsCompleteRequestHandler; import com.twosigma.beakerx.handler.KernelHandler; @@ -60,6 +61,7 @@ private Map> createHandlers(final CommOp handlers.put(JupyterMessages.COMM_CLOSE, new CommCloseHandler(kernel)); handlers.put(JupyterMessages.COMM_MSG, new CommMsgHandler(kernel)); handlers.put(JupyterMessages.IS_COMPLETE_REQUEST, new IsCompleteRequestHandler(kernel)); + handlers.put(JupyterMessages.INTERRUPT_REQUEST, new InterruptMsgHandler(kernel)); return handlers; } diff --git a/base/src/main/java/com/twosigma/beakerx/kernel/handler/ExecuteRequestHandler.java b/base/src/main/java/com/twosigma/beakerx/kernel/handler/ExecuteRequestHandler.java index 0ed3b0f..681bfdd 100644 --- a/base/src/main/java/com/twosigma/beakerx/kernel/handler/ExecuteRequestHandler.java +++ b/base/src/main/java/com/twosigma/beakerx/kernel/handler/ExecuteRequestHandler.java @@ -27,7 +27,11 @@ import java.io.Serializable; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.FutureTask; import static com.twosigma.beakerx.kernel.msg.JupyterMessages.EXECUTE_INPUT; import static java.util.Collections.singletonList; @@ -41,6 +45,7 @@ public class ExecuteRequestHandler extends KernelHandler { private final SimpleEvaluationObjectFactory seof; private int executionCount; + private ExecutorService executorService = Executors.newFixedThreadPool(1); public ExecuteRequestHandler(KernelFunctionality kernel) { super(kernel); @@ -48,27 +53,43 @@ public ExecuteRequestHandler(KernelFunctionality kernel) { this.seof = new SimpleEvaluationObjectFactory(); } + private FutureTask current; + @Override public void handle(Message message) { try { - handleMsg(message); + executorService.execute(() -> handleMsg(message)); } catch (Exception e) { handleException(message, e); } } + private void handleMsg(Message message) { - executionCount += 1; - kernel.sendBusyMessage(message); - String codeString = takeCodeFrom(message); - announceTheCode(message, codeString); - Code code = new CodeFactory(MessageCreator.get(), seof).create(codeString, message, kernel); - code.execute(kernel, executionCount); - finishExecution(message); + current = new FutureTask<>(() -> { + try { + runCode(message); + } catch (Exception e) { + handleException(message, e); + } + return "ok"; + }); + current.run(); } - private void finishExecution(Message message) { - kernel.sendIdleMessage(message); + private void runCode(Message message) { + if (kernel.isInterrupting()) { + Message abortedReply = MessageCreator.buildAbortedReply(message); + kernel.send(abortedReply); + } else { + kernel.sendBusyMessage(message); + executionCount += 1; + String codeString = takeCodeFrom(message); + announceTheCode(message, codeString); + Code code = new CodeFactory(MessageCreator.get(), seof).create(codeString, message, kernel); + code.execute(kernel, executionCount); + kernel.sendIdleMessage(message); + } } private String takeCodeFrom(Message message) { @@ -99,4 +120,20 @@ private void handleException(Message message, Exception e) { public void exit() { } + public void interruptKernel() { + waitForTheEndOfTheCurrentCell(); + List cells = executorService.shutdownNow(); + cells.forEach(Runnable::run); + executorService = Executors.newFixedThreadPool(1); + } + + private void waitForTheEndOfTheCurrentCell() { + if (current != null) { + try { + current.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } } diff --git a/base/src/main/java/com/twosigma/beakerx/kernel/handler/InterruptMsgHandler.java b/base/src/main/java/com/twosigma/beakerx/kernel/handler/InterruptMsgHandler.java new file mode 100644 index 0000000..73b0441 --- /dev/null +++ b/base/src/main/java/com/twosigma/beakerx/kernel/handler/InterruptMsgHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright 2020 TWO SIGMA OPEN SOURCE, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.twosigma.beakerx.kernel.handler; + +import com.twosigma.beakerx.handler.KernelHandler; +import com.twosigma.beakerx.kernel.KernelFunctionality; +import com.twosigma.beakerx.kernel.msg.JupyterMessages; +import com.twosigma.beakerx.message.Header; +import com.twosigma.beakerx.message.Message; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class InterruptMsgHandler extends KernelHandler { + + private final static Logger logger = LoggerFactory.getLogger(InterruptMsgHandler.class); + + public InterruptMsgHandler(final KernelFunctionality kernel) { + super(kernel); + } + + public void handle(Message message) { + Message interruptReply = createInterruptReply(JupyterMessages.INTERRUPT_REPLY, message); + kernel.send(interruptReply); + kernel.interruptKernel(); + kernel.interruptKernelDone(); + logger.info("Interrupting done"); + } + + private static Message createInterruptReply(JupyterMessages type, Message message) { + Message reply = new Message(new Header(type, message.getHeader().getSession())); + reply.setParentHeader(message.getHeader()); + reply.setIdentities(message.getIdentities()); + return reply; + } +} diff --git a/base/src/main/java/com/twosigma/beakerx/kernel/msg/MessageCreator.java b/base/src/main/java/com/twosigma/beakerx/kernel/msg/MessageCreator.java index e9f184f..15bfc0a 100644 --- a/base/src/main/java/com/twosigma/beakerx/kernel/msg/MessageCreator.java +++ b/base/src/main/java/com/twosigma/beakerx/kernel/msg/MessageCreator.java @@ -133,7 +133,13 @@ private static Message buildReply(Message message, EvaluationObject seo) { } return reply; } - + public static Message buildAbortedReply(Message message) { + Message abortedReply = initMessage(EXECUTE_REPLY, message); + Hashtable content = new Hashtable(1); + content.put("status", "aborted"); + abortedReply.setContent(content); + return abortedReply; + } public static Message buildReplyWithoutStatus(Message message, int executionCount) { Message reply = initMessage(EXECUTE_REPLY, message); Hashtable map6 = new Hashtable(3); diff --git a/base/src/main/java/com/twosigma/beakerx/kernel/restserver/impl/BeakerXServerJavalin.java b/base/src/main/java/com/twosigma/beakerx/kernel/restserver/impl/BeakerXServerJavalin.java index 1831817..d97b7e5 100644 --- a/base/src/main/java/com/twosigma/beakerx/kernel/restserver/impl/BeakerXServerJavalin.java +++ b/base/src/main/java/com/twosigma/beakerx/kernel/restserver/impl/BeakerXServerJavalin.java @@ -26,7 +26,6 @@ import static com.twosigma.beakerx.BeakerXClient.CODE_CELL_PATH; import static com.twosigma.beakerx.BeakerXClient.URL_ARG; -import static com.twosigma.beakerx.KernelInfoHandler.INTERRUPT_KERNEL; import static com.twosigma.beakerx.kernel.comm.GetCodeCellsHandler.INSTANCE; import static com.twosigma.beakerx.kernel.magic.command.functionality.AsyncMagicCommand.CANCEL_EXECUTION; @@ -73,9 +72,6 @@ private void mappingsForAllKernels(Javalin server, KernelFunctionality kernel) { String body = ctx.body(); INSTANCE.handle(body); }); - server.post(INTERRUPT_KERNEL, ctx -> { - kernel.killAllThreads(); - }); server.post(CANCEL_EXECUTION+"/:groupname", ctx -> { kernel.cancelExecution(GroupName.of(ctx.param("groupname"))); }); diff --git a/base/src/main/java/com/twosigma/beakerx/socket/KernelSocketsZMQ.java b/base/src/main/java/com/twosigma/beakerx/socket/KernelSocketsZMQ.java index 7172f30..e4e5250 100644 --- a/base/src/main/java/com/twosigma/beakerx/socket/KernelSocketsZMQ.java +++ b/base/src/main/java/com/twosigma/beakerx/socket/KernelSocketsZMQ.java @@ -74,17 +74,16 @@ public KernelSocketsZMQ(KernelFunctionality kernel, Config configuration, Socket private void configureSockets(Config configuration) { final String connection = configuration.getTransport() + "://" + configuration.getHost(); - hearbeatSocket = getNewSocket(ZMQ.REP, configuration.getHeartbeat(), connection, context); iopubSocket = getNewSocket(ZMQ.PUB, configuration.getIopub(), connection, context); + hearbeatSocket = getNewSocket(ZMQ.ROUTER, configuration.getHeartbeat(), connection, context); controlSocket = getNewSocket(ZMQ.ROUTER, configuration.getControl(), connection, context); stdinSocket = getNewSocket(ZMQ.ROUTER, configuration.getStdin(), connection, context); shellSocket = getNewSocket(ZMQ.ROUTER, configuration.getShell(), connection, context); - sockets = new ZMQ.Poller(4); - sockets.register(controlSocket, ZMQ.Poller.POLLIN); + sockets = new ZMQ.Poller(3); sockets.register(hearbeatSocket, ZMQ.Poller.POLLIN); sockets.register(shellSocket, ZMQ.Poller.POLLIN); - sockets.register(stdinSocket, ZMQ.Poller.POLLIN); + sockets.register(controlSocket, ZMQ.Poller.POLLIN); } public void publish(List message) { @@ -168,10 +167,8 @@ public void run() { handleHeartbeat(); } else if (isShellMsg()) { handleShell(); - } else if (isStdinMsg()) { - handleStdIn(); - } else if (this.isShutdown()) { - break; + } else { + logger.error("not handled message from sockets"); } } } catch (Exception e) { @@ -211,6 +208,10 @@ private void handleControlMsg() { sendMsg(controlSocket, Collections.singletonList(reply)); shutdown(); } + Handler handler = kernel.getHandler(message.type()); + if (handler != null) { + handler.handle(message); + } } private ZMQ.Socket getNewSocket(int type, int port, String connection, ZMQ.Context context) { @@ -262,20 +263,16 @@ private String verifyDelim(ZFrame zframe) { return delim; } - private boolean isStdinMsg() { - return sockets.pollin(3); + private boolean isHeartbeatMsg() { + return sockets.pollin(0); } private boolean isShellMsg() { - return sockets.pollin(2); - } - - private boolean isHeartbeatMsg() { return sockets.pollin(1); } private boolean isControlMsg() { - return sockets.pollin(0); + return sockets.pollin(2); } private void shutdown() {