Skip to content

Commit

Permalink
Merge pull request #19 from twosigma/jarek/12_interrupt
Browse files Browse the repository at this point in the history
jarek/12: switch interrupting mode to message
  • Loading branch information
ildipo authored Aug 3, 2020
2 parents a52b74a + afdb5bc commit aa69650
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.apache.commons.io.FileUtils;

import java.io.File;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
Expand All @@ -52,7 +51,9 @@

public abstract class BaseEvaluator implements Evaluator {

public static String INTERUPTED_MSG = "interrupted";
public static String INTERUPTED_MSG = "KeyboardInterrupt";

private boolean interrupting = false;
protected final String shellId;
protected final String sessionId;
private final ClasspathScanner classpathScanner;
Expand Down Expand Up @@ -109,6 +110,29 @@ public void endEvaluation() {
classLoaderSwitcher.end();
}

@Override
public void interruptKernel() {
interrupting = true;
killAllThreads();
}

@Override
public void interruptKernelDone() {
interrupting = false;
}

public boolean isInterrupting() {
return interrupting;
}

@Override
public TryResult processResult(TryResult result) {
if (interrupting) {
return TryResult.createError(INTERUPTED_MSG);
}
return result;
}

CompletableFuture<TryResult> background;

protected TryResult evaluate(EvaluationObject seo, Callable<TryResult> callable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,12 @@ public interface Evaluator {
void startEvaluation();

void endEvaluation();

void interruptKernel();

void interruptKernelDone();

boolean isInterrupting();

TryResult processResult(TryResult result);
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ public interface KernelFunctionality {

void cancelExecution(GroupName groupName);

void killAllThreads();
void interruptKernel();

void interruptKernelDone();

boolean isInterrupting();

Handler<Message> getHandler(JupyterMessages type);

Expand Down Expand Up @@ -120,4 +124,5 @@ public interface KernelFunctionality {
void startEvaluation();

void endEvaluation();

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ public enum JupyterMessages {
ERROR,
IS_COMPLETE_REQUEST,
INPUT_REQUEST,
IS_COMPLETE_REPLY;
IS_COMPLETE_REPLY,
INTERRUPT_REQUEST,
INTERRUPT_REPLY;

public String getName() {
return this.name().toLowerCase();
Expand All @@ -70,4 +72,4 @@ public static JupyterMessages getType(final String input){
return ret;
}

}
}
20 changes: 11 additions & 9 deletions base-test/src/main/java/com/twosigma/beakerx/KernelTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,20 @@ public void cancelExecution(GroupName groupName) {
}

@Override
public void killAllThreads() {
public void interruptKernel() {

}

@Override
public void interruptKernelDone() {

}

@Override
public boolean isInterrupting() {
return false;
}

@Override
public Handler<Message> getHandler(JupyterMessages type) {
return null;
Expand Down Expand Up @@ -525,13 +535,5 @@ public List<EvaluationObject> getObjectList() {
}
}

// public static EvaluationObject createSeo(String code) {
// return new SimpleEvaluationObject(code, new SeoConfigurationFactoryMock());
// }
//
// public static EvaluationObject createSeo(String code, Message message) {
// return new SimpleEvaluationObject(code, new SeoConfigurationFactoryMock(message));
// }

}

28 changes: 12 additions & 16 deletions base/src/main/java/com/twosigma/beakerx/KernelInfoHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,26 @@
*/
package com.twosigma.beakerx;

import static com.twosigma.beakerx.kernel.msg.JupyterMessages.KERNEL_INFO_REPLY;
import static com.twosigma.beakerx.handler.KernelHandlerWrapper.wrapBusyIdle;
import static java.util.Arrays.asList;

import com.twosigma.beakerx.kernel.KernelFunctionality;
import com.twosigma.beakerx.handler.KernelHandler;
import com.twosigma.beakerx.kernel.KernelManager;
import com.twosigma.beakerx.kernel.KernelFunctionality;
import com.twosigma.beakerx.message.Header;
import com.twosigma.beakerx.message.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static com.twosigma.beakerx.handler.KernelHandlerWrapper.wrapBusyIdle;
import static com.twosigma.beakerx.kernel.msg.JupyterMessages.KERNEL_INFO_REPLY;
import static java.util.Arrays.asList;

public abstract class KernelInfoHandler extends KernelHandler<Message> {

private final static Logger logger = LoggerFactory.getLogger(KernelInfoHandler.class);
public static final String PROTOCOL_VERSION = "protocol_version";
public static final String PROTOCOL_VERSION_NUMBER = "5.3";
public static final String INTERRUPT_KERNEL = "interrupt_kernel";

public KernelInfoHandler(KernelFunctionality kernel) {
super(kernel);
Expand All @@ -49,12 +46,12 @@ public void handle(Message message) {
}

private void handleMsg(Message message) {
logger.debug("Processing kernel info request");
Message reply = new Message(new Header(KERNEL_INFO_REPLY, message.getHeader().getSession()));
reply.setContent(content());
reply.setParentHeader(message.getHeader());
reply.setIdentities(message.getIdentities());
send(reply);
logger.debug("Processing kernel info request");
Message reply = new Message(new Header(KERNEL_INFO_REPLY, message.getHeader().getSession()));
reply.setContent(content());
reply.setParentHeader(message.getHeader());
reply.setIdentities(message.getIdentities());
send(reply);
}

private HashMap<String, Serializable> languageInfo() {
Expand All @@ -70,7 +67,6 @@ private HashMap<String, Serializable> content() {
map.put("help_links", getHelpLinks());
map.put("beakerx", true);
map.put("status", "ok");
map.put("url_to_interrupt", KernelManager.get().getBeakerXServer().getURL() + INTERRUPT_KERNEL);
return doContent(map);
}

Expand Down
18 changes: 16 additions & 2 deletions base/src/main/java/com/twosigma/beakerx/kernel/Kernel.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.twosigma.beakerx.DisplayerDataMapper;
import com.twosigma.beakerx.TryResult;
import com.twosigma.beakerx.autocomplete.AutocompleteResult;
import com.twosigma.beakerx.evaluator.BaseEvaluator;
import com.twosigma.beakerx.evaluator.Evaluator;
import com.twosigma.beakerx.evaluator.Hook;
import com.twosigma.beakerx.handler.Handler;
Expand All @@ -28,6 +29,7 @@
import com.twosigma.beakerx.jvm.object.EvaluationObject;
import com.twosigma.beakerx.kernel.comm.Comm;
import com.twosigma.beakerx.kernel.handler.CommOpenHandler;
import com.twosigma.beakerx.kernel.handler.ExecuteRequestHandler;
import com.twosigma.beakerx.kernel.magic.command.MagicCommandConfiguration;
import com.twosigma.beakerx.kernel.magic.command.MagicCommandType;
import com.twosigma.beakerx.kernel.msg.JupyterMessages;
Expand Down Expand Up @@ -159,8 +161,20 @@ public void updateEvaluatorParameters(final EvaluatorParameters kernelParameters
}

@Override
public void killAllThreads() {
evaluator.killAllThreads();
public void interruptKernel() {
evaluator.interruptKernel();
ExecuteRequestHandler executeRequestHandler = handlers.getExecuteRequestHandler();
executeRequestHandler.interruptKernel();
}

@Override
public void interruptKernelDone() {
evaluator.interruptKernelDone();
}

@Override
public boolean isInterrupting() {
return evaluator.isInterrupting();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.twosigma.beakerx.kernel.handler.CommMsgHandler;
import com.twosigma.beakerx.kernel.handler.CommOpenHandler;
import com.twosigma.beakerx.kernel.handler.ExecuteRequestHandler;
import com.twosigma.beakerx.kernel.handler.InterruptMsgHandler;
import com.twosigma.beakerx.kernel.msg.JupyterMessages;
import com.twosigma.beakerx.handler.IsCompleteRequestHandler;
import com.twosigma.beakerx.handler.KernelHandler;
Expand Down Expand Up @@ -60,6 +61,7 @@ private Map<JupyterMessages, KernelHandler<Message>> createHandlers(final CommOp
handlers.put(JupyterMessages.COMM_CLOSE, new CommCloseHandler(kernel));
handlers.put(JupyterMessages.COMM_MSG, new CommMsgHandler(kernel));
handlers.put(JupyterMessages.IS_COMPLETE_REQUEST, new IsCompleteRequestHandler(kernel));
handlers.put(JupyterMessages.INTERRUPT_REQUEST, new InterruptMsgHandler(kernel));
return handlers;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;

import static com.twosigma.beakerx.kernel.msg.JupyterMessages.EXECUTE_INPUT;
import static java.util.Collections.singletonList;
Expand All @@ -41,34 +45,51 @@ public class ExecuteRequestHandler extends KernelHandler<Message> {

private final SimpleEvaluationObjectFactory seof;
private int executionCount;
private ExecutorService executorService = Executors.newFixedThreadPool(1);

public ExecuteRequestHandler(KernelFunctionality kernel) {
super(kernel);
this.executionCount = 0;
this.seof = new SimpleEvaluationObjectFactory();
}

private FutureTask<String> current;

@Override
public void handle(Message message) {
try {
handleMsg(message);
executorService.execute(() -> handleMsg(message));
} catch (Exception e) {
handleException(message, e);
}
}


private void handleMsg(Message message) {
executionCount += 1;
kernel.sendBusyMessage(message);
String codeString = takeCodeFrom(message);
announceTheCode(message, codeString);
Code code = new CodeFactory(MessageCreator.get(), seof).create(codeString, message, kernel);
code.execute(kernel, executionCount);
finishExecution(message);
current = new FutureTask<>(() -> {
try {
runCode(message);
} catch (Exception e) {
handleException(message, e);
}
return "ok";
});
current.run();
}

private void finishExecution(Message message) {
kernel.sendIdleMessage(message);
private void runCode(Message message) {
if (kernel.isInterrupting()) {
Message abortedReply = MessageCreator.buildAbortedReply(message);
kernel.send(abortedReply);
} else {
kernel.sendBusyMessage(message);
executionCount += 1;
String codeString = takeCodeFrom(message);
announceTheCode(message, codeString);
Code code = new CodeFactory(MessageCreator.get(), seof).create(codeString, message, kernel);
code.execute(kernel, executionCount);
kernel.sendIdleMessage(message);
}
}

private String takeCodeFrom(Message message) {
Expand Down Expand Up @@ -99,4 +120,20 @@ private void handleException(Message message, Exception e) {
public void exit() {
}

public void interruptKernel() {
waitForTheEndOfTheCurrentCell();
List<Runnable> cells = executorService.shutdownNow();
cells.forEach(Runnable::run);
executorService = Executors.newFixedThreadPool(1);
}

private void waitForTheEndOfTheCurrentCell() {
if (current != null) {
try {
current.get();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2020 TWO SIGMA OPEN SOURCE, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.twosigma.beakerx.kernel.handler;

import com.twosigma.beakerx.handler.KernelHandler;
import com.twosigma.beakerx.kernel.KernelFunctionality;
import com.twosigma.beakerx.kernel.msg.JupyterMessages;
import com.twosigma.beakerx.message.Header;
import com.twosigma.beakerx.message.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InterruptMsgHandler extends KernelHandler<Message> {

private final static Logger logger = LoggerFactory.getLogger(InterruptMsgHandler.class);

public InterruptMsgHandler(final KernelFunctionality kernel) {
super(kernel);
}

public void handle(Message message) {
Message interruptReply = createInterruptReply(JupyterMessages.INTERRUPT_REPLY, message);
kernel.send(interruptReply);
kernel.interruptKernel();
kernel.interruptKernelDone();
logger.info("Interrupting done");
}

private static Message createInterruptReply(JupyterMessages type, Message message) {
Message reply = new Message(new Header(type, message.getHeader().getSession()));
reply.setParentHeader(message.getHeader());
reply.setIdentities(message.getIdentities());
return reply;
}
}
Loading

0 comments on commit aa69650

Please sign in to comment.