Skip to content

Commit

Permalink
Make failure to close a stream an error, as it would be by default.
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Prescod committed Nov 26, 2021
1 parent ca96441 commit f60cd1c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 26 deletions.
13 changes: 3 additions & 10 deletions snowfakery/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,9 @@ def configure_output_stream(
try:
yield output_stream
finally:
try:
messages = output_stream.close()
except Exception as e:
messages = None
parent_application.echo(
f"Could not close {output_stream}: {str(e)}", err=True
)
if messages:
for message in messages:
parent_application.echo(message)
messages = output_stream.close() or []
for message in messages:
parent_application.echo(message)


@contextmanager
Expand Down
18 changes: 15 additions & 3 deletions snowfakery/output_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def close(self) -> Optional[Sequence[str]]:
Return a list of messages to print out.
"""
return super().close()
raise NotImplementedError()

def __enter__(self, *args):
return self
Expand Down Expand Up @@ -550,7 +550,7 @@ def _render(self, dotfile, outfile):
assert dotfile.exists()
try:
out = subprocess.Popen(
["dot", "-T" + self.format, dotfile, "-o" + str(outfile)],
["dot", "-T" + self.format, str(dotfile), "-o" + str(outfile)],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
Expand Down Expand Up @@ -578,8 +578,20 @@ def write_row(self, tablename: str, row_with_references: Dict) -> None:
stream.write_row(tablename, row_with_references)

def close(self) -> Optional[Sequence[str]]:
all_messages = []
closing_errors = []
for stream in self.outputstreams:
stream.close()
try:
messages = stream.close() or []
all_messages.extend(messages)
except Exception as e:
closing_errors.append(e)

if len(closing_errors) == 1:
raise closing_errors[0]
elif closing_errors:
raise IOError(f"Could not close streams: {closing_errors}")
return all_messages

def write_single_row(self, tablename: str, row: Dict) -> None:
return super().write_single_row(tablename, row)
21 changes: 8 additions & 13 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_continuation_as_open_file(self):
with mapping_file.open() as f:
assert yaml.safe_load(f)

def test_parent_application__echo(self):
def test_parent_application__exception_raised(self):
called = False

class MyEmbedder(SnowfakeryApplication):
Expand All @@ -74,10 +74,10 @@ def echo(self, *args, **kwargs):
meth = "snowfakery.output_streams.DebugOutputStream.close"
with mock.patch(meth) as close:
close.side_effect = AssertionError
generate_data(
yaml_file="examples/company.yml", parent_application=MyEmbedder()
)
assert called
with pytest.raises(AssertionError):
generate_data(
yaml_file="examples/company.yml", parent_application=MyEmbedder()
)

def test_parent_application__early_finish(self, generated_rows):
class MyEmbedder(SnowfakeryApplication):
Expand All @@ -89,14 +89,9 @@ def check_if_finished(self, idmanager):
assert self.__class__.count < 100, "Runaway recipe!"
return idmanager["Employee"] >= 10

meth = "snowfakery.output_streams.DebugOutputStream.close"
with mock.patch(meth) as close:
close.side_effect = AssertionError
generate_data(
yaml_file="examples/company.yml", parent_application=MyEmbedder()
)
# called 5 times, after generating 2 employees each
assert MyEmbedder.count == 5
generate_data(yaml_file="examples/company.yml", parent_application=MyEmbedder())
# called 5 times, after generating 2 employees each
assert MyEmbedder.count == 5

def test_embedding__cannot_infer_output_format(self):
with pytest.raises(exc.DataGenError, match="No format"):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_output_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from contextlib import redirect_stdout
from unittest import mock


import pytest
Expand Down Expand Up @@ -375,3 +376,26 @@ def test_external_output_stream__failure(self):
generate_cli.callback(
yaml_file=sample_yaml, output_format="no.such.output.Stream"
)


class TestMultiplexOutputStream:
@mock.patch("snowfakery.output_streams.DebugOutputStream.close")
def test_cannot_close_multiple_streams(self, close):
close.side_effect = AssertionError
with TemporaryDirectory() as t:
files = [Path(t) / "1.txt", Path(t) / "2.txt"]
with pytest.raises(IOError) as e:
generate_cli.callback(
yaml_file="examples/company.yml", output_files=files
)
assert "Could not close streams:" in str(e.value)

@mock.patch("snowfakery.output_streams.DebugOutputStream.close")
def test_cannot_close_one_stream(self, close):
close.side_effect = AssertionError
with TemporaryDirectory() as t:
files = [Path(t) / "1.txt", Path(t) / "2.jpg"]
with pytest.raises(AssertionError):
generate_cli.callback(
yaml_file="examples/company.yml", output_files=files
)

0 comments on commit f60cd1c

Please sign in to comment.