diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
index ffccd70daa..dfb748c001 100644
--- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
+++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
@@ -478,7 +478,9 @@ def prepare_or_run_pipeline(
         )
 
         # Yield metadata based on the generated execution object
-        yield from self.compute_metadata(execution=execution)
+        yield from self.compute_metadata(
+            execution=execution, settings=settings
+        )
 
         # mainly for testing purposes, we wait for the pipeline to finish
         if settings.synchronous:
@@ -577,12 +579,15 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
         raise ValueError("Unknown status for the pipeline execution.")
 
     def compute_metadata(
-        self, execution: Any
+        self,
+        execution: Any,
+        settings: SagemakerOrchestratorSettings,
     ) -> Iterator[Dict[str, MetadataType]]:
         """Generate run metadata based on the generated Sagemaker Execution.
 
         Args:
             execution: The corresponding _PipelineExecution object.
+            settings: The Sagemaker orchestrator settings.
 
         Yields:
             A dictionary of metadata related to the pipeline run.
@@ -599,7 +604,9 @@ def compute_metadata(
             metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
 
         # URL to the corresponding CloudWatch page
-        if logs_url := self._compute_orchestrator_logs_url(execution):
+        if logs_url := self._compute_orchestrator_logs_url(
+            execution, settings
+        ):
             metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)
 
         yield metadata
@@ -643,11 +650,13 @@ def _compute_orchestrator_url(
     @staticmethod
     def _compute_orchestrator_logs_url(
         pipeline_execution: Any,
+        settings: SagemakerOrchestratorSettings,
     ) -> Optional[str]:
         """Generate the CloudWatch URL upon pipeline execution.
 
         Args:
             pipeline_execution: The corresponding _PipelineExecution object.
+            settings: The Sagemaker orchestrator settings.
 
         Returns:
             the URL querying the pipeline logs in CloudWatch on AWS.
@@ -657,10 +666,16 @@ def _compute_orchestrator_logs_url(
                 pipeline_execution.arn
             )
 
+            use_training_jobs = True
+            if settings.use_training_step is not None:
+                use_training_jobs = settings.use_training_step
+
+            job_type = "Training" if use_training_jobs else "Processing"
+
             return (
                 f"https://{region_name}.console.aws.amazon.com/"
                 f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
-                f"/$252Faws$252Fsagemaker$252FTrainingJobs$3FlogStreamNameFilter"
+                f"/$252Faws$252Fsagemaker$252F{job_type}Jobs$3FlogStreamNameFilter"
                 f"$3Dpipelines-{execution_id}-"
             )
         except Exception as e:
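
For reference, here is a minimal standalone sketch of the URL logic the patch introduces. It is not the orchestrator method itself; the helper name sketch_cloudwatch_logs_url and its arguments are illustrative only. It shows how the updated code selects the TrainingJobs or ProcessingJobs CloudWatch log group from the use_training_step setting, defaulting to training jobs when the setting is left unset.

# Standalone sketch (illustrative, not ZenML API): approximates the URL
# construction in the patched _compute_orchestrator_logs_url.
from typing import Optional


def sketch_cloudwatch_logs_url(
    region_name: str,
    execution_id: str,
    use_training_step: Optional[bool],
) -> str:
    # Default to training jobs, mirroring the patch: only an explicit
    # use_training_step=False switches the log group to processing jobs.
    use_training_jobs = True
    if use_training_step is not None:
        use_training_jobs = use_training_step

    job_type = "Training" if use_training_jobs else "Processing"

    # "$252F" and friends are the CloudWatch console's escaped form of the
    # URL-encoded log-group path, copied verbatim from the patch.
    return (
        f"https://{region_name}.console.aws.amazon.com/"
        f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
        f"/$252Faws$252Fsagemaker$252F{job_type}Jobs$3FlogStreamNameFilter"
        f"$3Dpipelines-{execution_id}-"
    )


# Example: a pipeline run configured with use_training_step=False now links
# to the /aws/sagemaker/ProcessingJobs log group instead of TrainingJobs.
print(sketch_cloudwatch_logs_url("eu-west-1", "abc123", False))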