From 7d3ead6fe2090bbfba8527b03f07058dbe814e4a Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Fri, 22 Nov 2024 16:21:01 -0600 Subject: [PATCH] Add async engine specific tests --- tests/test_flow_engine.py | 82 +++++++++++++++++++++++++++++++-------- 1 file changed, 65 insertions(+), 17 deletions(-) diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index c291804285f8..34b05363eac0 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -32,6 +32,7 @@ Pause, ) from prefect.flow_engine import ( + AsyncFlowRunEngine, FlowRunEngine, load_flow_and_flow_run, run_flow, @@ -57,24 +58,24 @@ async def foo(): class TestFlowRunEngine: - async def test_basic_init(self): + def test_basic_init(self): engine = FlowRunEngine(flow=foo) assert isinstance(engine.flow, Flow) assert engine.flow.name == "foo" assert engine.parameters == {} - async def test_empty_init(self): + def test_empty_init(self): with pytest.raises( TypeError, match="missing 1 required positional argument: 'flow'" ): FlowRunEngine() - async def test_client_attr_raises_informative_error(self): + def test_client_attr_raises_informative_error(self): engine = FlowRunEngine(flow=foo) with pytest.raises(RuntimeError, match="not started"): engine.client - async def test_client_attr_returns_client_after_starting(self): + def test_client_attr_returns_client_after_starting(self): engine = FlowRunEngine(flow=foo) with engine.initialize_run(): client = engine.client @@ -83,21 +84,33 @@ async def test_client_attr_returns_client_after_starting(self): with pytest.raises(RuntimeError, match="not started"): engine.client - async def test_load_flow_from_entrypoint(self, monkeypatch, tmp_path, flow_run): - flow_code = """ - from prefect import flow - @flow - def dog(): - return "woof!" - """ - fpath = tmp_path / "f.py" - fpath.write_text(dedent(flow_code)) +class TestAsyncFlowRunEngine: + def test_basic_init(self): + engine = AsyncFlowRunEngine(flow=foo) + assert isinstance(engine.flow, Flow) + assert engine.flow.name == "foo" + assert engine.parameters == {} - monkeypatch.setenv("PREFECT__FLOW_ENTRYPOINT", f"{fpath}:dog") - loaded_flow_run, flow = load_flow_and_flow_run(flow_run.id) - assert loaded_flow_run.id == flow_run.id - assert flow.fn() == "woof!" + def test_empty_init(self): + with pytest.raises( + TypeError, match="missing 1 required positional argument: 'flow'" + ): + AsyncFlowRunEngine() + + def test_client_attr_raises_informative_error(self): + engine = AsyncFlowRunEngine(flow=foo) + with pytest.raises(RuntimeError, match="not started"): + engine.client + + async def test_client_attr_returns_client_after_starting(self): + engine = AsyncFlowRunEngine(flow=foo) + async with engine.initialize_run(): + client = engine.client + assert isinstance(client, SyncPrefectClient) + + with pytest.raises(RuntimeError, match="not started"): + engine.client class TestStartFlowRunEngine: @@ -119,6 +132,25 @@ def flow_with_retries(): engine.begin_run() +class TestStartAsyncFlowRunEngine: + async def test_start_updates_empirical_policy_on_provided_flow_run( + self, prefect_client: PrefectClient + ): + @flow(retries=3, retry_delay_seconds=10) + def flow_with_retries(): + pass + + flow_run = await prefect_client.create_flow_run(flow_with_retries) + + engine = AsyncFlowRunEngine(flow=flow_with_retries, flow_run=flow_run) + async with engine.start(): + assert engine.flow_run.empirical_policy.retries == 3 + assert engine.flow_run.empirical_policy.retry_delay == 10 + + # avoid error on teardown + await engine.begin_run() + + class TestFlowRunsAsync: async def test_basic(self): @flow @@ -1744,6 +1776,22 @@ def g(required: str, model: TheModel = {"x": [1, 2, 3]}): # type: ignore class TestLoadFlowAndFlowRun: + def test_load_flow_from_entrypoint(self, monkeypatch, tmp_path, flow_run): + flow_code = """ + from prefect import flow + + @flow + def dog(): + return "woof!" + """ + fpath = tmp_path / "f.py" + fpath.write_text(dedent(flow_code)) + + monkeypatch.setenv("PREFECT__FLOW_ENTRYPOINT", f"{fpath}:dog") + loaded_flow_run, flow = load_flow_and_flow_run(flow_run.id) + assert loaded_flow_run.id == flow_run.id + assert flow.fn() == "woof!" + async def test_load_flow_from_script_with_module_level_sync_compatible_call( self, prefect_client: PrefectClient, tmp_path ):