Skip to content

Commit

Permalink
Always raise, don't swallow exception, if exception is `KeyboardInter…
Browse files Browse the repository at this point in the history
…rupt` in `ProtocolUnit.execute` (#215)

* Always raise, don't swallow exception, if exception is `KeyboardInterrupt` in `ProtocolUnit.execute`

* Added explicit test for KeyboardInterrupt raised in ProtocolUnit._execute

---------

Co-authored-by: Mike Henry <[email protected]>
  • Loading branch information
dotsdl and mikemhenry authored Sep 1, 2023
1 parent 2028a63 commit 0b611a5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
3 changes: 3 additions & 0 deletions gufe/protocols/protocolunit.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ def execute(self, *,
start_time=start, end_time=datetime.datetime.now(),
)

except KeyboardInterrupt:
# if we "fail" due to a KeyboardInterrupt, we always want to raise
raise
except Exception as e:
if raise_error:
raise
Expand Down
32 changes: 31 additions & 1 deletion gufe/tests/test_protocolunit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from pathlib import Path

from gufe.protocols.protocolunit import ProtocolUnit, Context, ProtocolUnitFailure
from gufe.protocols.protocolunit import ProtocolUnit, Context, ProtocolUnitFailure, ProtocolUnitResult
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin


Expand All @@ -16,6 +16,16 @@ def _execute(ctx: Context, an_input=2, **inputs):
return {"foo": "bar"}


class DummyKeyboardInterruptUnit(ProtocolUnit):
@staticmethod
def _execute(ctx: Context, an_input=2, **inputs):

if an_input != 2:
raise KeyboardInterrupt

return {"foo": "bar"}


@pytest.fixture
def dummy_unit():
return DummyUnit(name="qux")
Expand Down Expand Up @@ -66,6 +76,26 @@ def test_execute(self, tmpdir):
with pytest.raises(ValueError, match="should always be 2"):
unit.execute(context=ctx, raise_error=True, an_input=3)

def test_execute_KeyboardInterrupt(self, tmpdir):
with tmpdir.as_cwd():

unit = DummyKeyboardInterruptUnit()

shared = Path('shared') / str(unit.key)
shared.mkdir(parents=True)

scratch = Path('scratch') / str(unit.key)
scratch.mkdir(parents=True)

ctx = Context(shared=shared, scratch=scratch)

with pytest.raises(KeyboardInterrupt):
unit.execute(context=ctx, an_input=3)

u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2)

assert u.outputs == {'foo': 'bar'}

def test_normalize(self, dummy_unit):
thingy = dummy_unit.key

Expand Down

0 comments on commit 0b611a5

Please sign in to comment.