Skip to content

Commit

Permalink
Run template CLI command and bugfix (#3225)
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi authored Nov 28, 2024
1 parent 1b16f16 commit 34f3fe9
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 2 deletions.
80 changes: 80 additions & 0 deletions src/zenml/cli/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,86 @@ def run_pipeline(
pipeline_instance()


@pipeline.command(
"create-run-template",
help="Create a run template for a pipeline. The SOURCE argument needs to "
"be an importable source path resolving to a ZenML pipeline instance, e.g. "
"`my_module.my_pipeline_instance`.",
)
@click.argument("source")
@click.option(
"--name",
"-n",
type=str,
required=True,
help="Name for the template",
)
@click.option(
"--config",
"-c",
"config_path",
type=click.Path(exists=True, dir_okay=False),
required=False,
help="Path to configuration file for the build.",
)
@click.option(
"--stack",
"-s",
"stack_name_or_id",
type=str,
required=False,
help="Name or ID of the stack to use for the build.",
)
def create_run_template(
source: str,
name: str,
config_path: Optional[str] = None,
stack_name_or_id: Optional[str] = None,
) -> None:
"""Create a run template for a pipeline.
Args:
source: Importable source resolving to a pipeline instance.
name: Name of the run template.
config_path: Path to pipeline configuration file.
stack_name_or_id: Name or ID of the stack for which the template should
be created.
"""
if not Client().root:
cli_utils.warning(
"You're running the `zenml pipeline create-run-template` command "
"without a ZenML repository. Your current working directory will "
"be used as the source root relative to which the registered step "
"classes will be resolved. To silence this warning, run `zenml "
"init` at your source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = pipeline_instance.with_options(
config_path=config_path
)
template = pipeline_instance.create_run_template(name=name)

cli_utils.declare(f"Created run template `{template.id}`.")


@pipeline.command("list", help="List all registered pipelines.")
@list_options(PipelineFilter)
def list_pipelines(**kwargs: Any) -> None:
Expand Down
8 changes: 6 additions & 2 deletions src/zenml/pipelines/pipeline_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ def _create_deployment(
config_path: Optional[str] = None,
unlisted: bool = False,
prevent_build_reuse: bool = False,
skip_schedule_registration: bool = False,
) -> PipelineDeploymentResponse:
"""Create a pipeline deployment.
Expand All @@ -609,6 +610,7 @@ def _create_deployment(
to any pipeline).
prevent_build_reuse: DEPRECATED: Use
`DockerSettings.prevent_build_reuse` instead.
skip_schedule_registration: Whether to skip schedule registration.
Returns:
The pipeline deployment.
Expand Down Expand Up @@ -649,7 +651,7 @@ def _create_deployment(
stack.validate()

schedule_id = None
if schedule:
if schedule and not skip_schedule_registration:
if not stack.orchestrator.config.is_schedulable:
raise ValueError(
f"Stack {stack.name} does not support scheduling. "
Expand Down Expand Up @@ -1445,7 +1447,9 @@ def create_run_template(
The created run template.
"""
self._prepare_if_possible()
deployment = self._create_deployment(**self._run_args)
deployment = self._create_deployment(
**self._run_args, skip_schedule_registration=True
)

return Client().create_run_template(
name=name, deployment_id=deployment.id, **kwargs
Expand Down

0 comments on commit 34f3fe9

Please sign in to comment.