Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameters, scheduling and python deployment execution #19

Merged
merged 3 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/simple/example_cron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from giza_actions.action import Action, action
from giza_actions.model import GizaModel
from giza_actions.task import task

@task
def preprocess():
print(f"Preprocessing...")


@task
def transform():
print(f"Transforming...")


@action(log_prints=True)
def inference():
preprocess()
transform()

if __name__ == '__main__':
action_deploy = Action(entrypoint=inference, name="inference-local-action")
action_deploy.serve(name="inference-local-action", cron="* * * * *")
22 changes: 22 additions & 0 deletions examples/simple/example_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from giza_actions.action import Action, action
from giza_actions.model import GizaModel
from giza_actions.task import task

@task
def preprocess():
print(f"Preprocessing...")


@task
def transform():
print(f"Transforming...")


@action(log_prints=True)
def inference():
preprocess()
transform()

if __name__ == '__main__':
action_deploy = Action(entrypoint=inference, name="inference-local-action")
action_deploy.serve(name="inference-local-action", interval=10)
25 changes: 25 additions & 0 deletions examples/simple/example_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from giza_actions.action import Action, action
from giza_actions.model import GizaModel
from giza_actions.task import task

@task
def preprocess(example_parameter: bool = False):
print(f"Preprocessing with example={example_parameter}")
print(f"Preprocessing...")


@task
def transform(example_parameter: bool = False):
print(f"Transforming with example={example_parameter}")
print(f"Transforming...")


@action(log_prints=True)
def inference(example_parameter: bool = False):
print(f"Running inference with example={example_parameter}")
preprocess(example_parameter=example_parameter)
transform(example_parameter=example_parameter)

if __name__ == '__main__':
action_deploy = Action(entrypoint=inference, name="inference-local-action")
action_deploy.serve(name="inference-local-action", parameters={"example_parameter": False})
28 changes: 23 additions & 5 deletions giza_actions/action.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from functools import partial, wraps
from pathlib import Path
from typing import Optional

from giza_actions.utils import get_workspace_uri # noqa: E402

Expand All @@ -9,6 +10,7 @@

from prefect import Flow # noqa: E402
from prefect import flow as _flow # noqa: E402
from prefect.client.schemas.schedules import construct_schedule # noqa: E402
from prefect.settings import PREFECT_API_URL # noqa: E402
from prefect.settings import ( # noqa: E402
PREFECT_LOGGING_SETTINGS_PATH,
Expand Down Expand Up @@ -80,6 +82,9 @@ def get_flow(self):
async def serve(
self,
name: str,
cron: Optional[str] = None,
interval: Optional[str] = None,
parameters: Optional[dict] = None,
print_starting_message: bool = True,
):
"""
Expand All @@ -88,6 +93,9 @@ async def serve(
Args:
name (str): The name to assign to the runner. If a file path is provided, it uses the file name without the extension.
print_starting_message (bool, optional): Whether to print a starting message. Defaults to True.
interval: An interval on which to schedule runs. Accepts either a number
or a timedelta object. If a number is given, it will be interpreted as seconds.
cron: A cron schedule for runs.
"""

workspace_url = get_workspace_uri()
Expand All @@ -103,10 +111,20 @@ async def serve(
# Non filepath strings will pass through unchanged
name = Path(name).stem

schedule = None

if interval or cron:
schedule = construct_schedule(
interval=interval,
cron=cron,
)

runner = Runner(name=name, pause_on_shutdown=False)
deployment_id = await runner.add_flow(
self._flow,
name=name,
schedule=schedule,
parameters=parameters,
)
if print_starting_message:
help_message = (
Expand All @@ -124,7 +142,7 @@ async def serve(
await runner.start(webserver=False)


def action(func=None, **task_init_kwargs):
def action(func=None, *task_init_args, **task_init_kwargs):
"""
Decorator to convert a function into a Prefect flow.

Expand All @@ -136,10 +154,10 @@ def action(func=None, **task_init_kwargs):
Flow: The Prefect flow created from the function.
"""
if func is None:
return partial(action, **task_init_kwargs)
return partial(action, *task_init_args, **task_init_kwargs)

@wraps(func)
def safe_func(**kwargs):
def safe_func(*args, **kwargs):
"""
A wrapper function that calls the original function with its arguments.

Expand All @@ -149,7 +167,7 @@ def safe_func(**kwargs):
Returns:
The return value of the original function.
"""
return func(**kwargs)
return func(*args, **kwargs)

safe_func.__name__ = func.__name__
return _flow(safe_func, **task_init_kwargs)
return _flow(safe_func, *task_init_args, **task_init_kwargs)
16 changes: 16 additions & 0 deletions giza_actions/deployments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os # noqa: E402
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this one need noqa?


from giza_actions.utils import get_workspace_uri # noqa: E402

os.environ["PREFECT_API_URL"] = f"{get_workspace_uri()}/api"
os.environ["PREFECT_UI_URL"] = get_workspace_uri()

from prefect.deployments import run_deployment # noqa: E402


def run_action_deployment(name: str, parameters: dict = None):
deployment_run = run_deployment(name=name, parameters=parameters)
print(
f"Deployment run name: {deployment_run.name} exited with state: {deployment_run.state_name}"
)
return deployment_run
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-actions"
version = "0.1.3"
version = "0.2.0"
description = "A Python SDK for Giza platform"
authors = [
"Francisco Algaba <[email protected]>",
Expand Down
Loading