Skip to content

Commit

Permalink
Add async engine specific tests
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Nov 22, 2024
1 parent 5252970 commit 7d3ead6
Showing 1 changed file with 65 additions and 17 deletions.
82 changes: 65 additions & 17 deletions tests/test_flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Pause,
)
from prefect.flow_engine import (
AsyncFlowRunEngine,
FlowRunEngine,
load_flow_and_flow_run,
run_flow,
Expand All @@ -57,24 +58,24 @@ async def foo():


class TestFlowRunEngine:
async def test_basic_init(self):
def test_basic_init(self):
engine = FlowRunEngine(flow=foo)
assert isinstance(engine.flow, Flow)
assert engine.flow.name == "foo"
assert engine.parameters == {}

async def test_empty_init(self):
def test_empty_init(self):
with pytest.raises(
TypeError, match="missing 1 required positional argument: 'flow'"
):
FlowRunEngine()

async def test_client_attr_raises_informative_error(self):
def test_client_attr_raises_informative_error(self):
engine = FlowRunEngine(flow=foo)
with pytest.raises(RuntimeError, match="not started"):
engine.client

async def test_client_attr_returns_client_after_starting(self):
def test_client_attr_returns_client_after_starting(self):
engine = FlowRunEngine(flow=foo)
with engine.initialize_run():
client = engine.client
Expand All @@ -83,21 +84,33 @@ async def test_client_attr_returns_client_after_starting(self):
with pytest.raises(RuntimeError, match="not started"):
engine.client

async def test_load_flow_from_entrypoint(self, monkeypatch, tmp_path, flow_run):
flow_code = """
from prefect import flow

@flow
def dog():
return "woof!"
"""
fpath = tmp_path / "f.py"
fpath.write_text(dedent(flow_code))
class TestAsyncFlowRunEngine:
def test_basic_init(self):
engine = AsyncFlowRunEngine(flow=foo)
assert isinstance(engine.flow, Flow)
assert engine.flow.name == "foo"
assert engine.parameters == {}

monkeypatch.setenv("PREFECT__FLOW_ENTRYPOINT", f"{fpath}:dog")
loaded_flow_run, flow = load_flow_and_flow_run(flow_run.id)
assert loaded_flow_run.id == flow_run.id
assert flow.fn() == "woof!"
def test_empty_init(self):
with pytest.raises(
TypeError, match="missing 1 required positional argument: 'flow'"
):
AsyncFlowRunEngine()

def test_client_attr_raises_informative_error(self):
engine = AsyncFlowRunEngine(flow=foo)
with pytest.raises(RuntimeError, match="not started"):
engine.client

async def test_client_attr_returns_client_after_starting(self):
engine = AsyncFlowRunEngine(flow=foo)
async with engine.initialize_run():
client = engine.client
assert isinstance(client, SyncPrefectClient)

with pytest.raises(RuntimeError, match="not started"):
engine.client


class TestStartFlowRunEngine:
Expand All @@ -119,6 +132,25 @@ def flow_with_retries():
engine.begin_run()


class TestStartAsyncFlowRunEngine:
async def test_start_updates_empirical_policy_on_provided_flow_run(
self, prefect_client: PrefectClient
):
@flow(retries=3, retry_delay_seconds=10)
def flow_with_retries():
pass

flow_run = await prefect_client.create_flow_run(flow_with_retries)

engine = AsyncFlowRunEngine(flow=flow_with_retries, flow_run=flow_run)
async with engine.start():
assert engine.flow_run.empirical_policy.retries == 3
assert engine.flow_run.empirical_policy.retry_delay == 10

# avoid error on teardown
await engine.begin_run()


class TestFlowRunsAsync:
async def test_basic(self):
@flow
Expand Down Expand Up @@ -1744,6 +1776,22 @@ def g(required: str, model: TheModel = {"x": [1, 2, 3]}): # type: ignore


class TestLoadFlowAndFlowRun:
def test_load_flow_from_entrypoint(self, monkeypatch, tmp_path, flow_run):
flow_code = """
from prefect import flow
@flow
def dog():
return "woof!"
"""
fpath = tmp_path / "f.py"
fpath.write_text(dedent(flow_code))

monkeypatch.setenv("PREFECT__FLOW_ENTRYPOINT", f"{fpath}:dog")
loaded_flow_run, flow = load_flow_and_flow_run(flow_run.id)
assert loaded_flow_run.id == flow_run.id
assert flow.fn() == "woof!"

async def test_load_flow_from_script_with_module_level_sync_compatible_call(
self, prefect_client: PrefectClient, tmp_path
):
Expand Down

0 comments on commit 7d3ead6

Please sign in to comment.