diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 0ba28ee3334..16439969e0d 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -68,9 +68,7 @@ def _compute_array_job_index(): if os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"): offset = int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET")) if os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"): - return offset + int( - os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")) - ) + return offset + int(os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))) return offset @@ -95,17 +93,13 @@ def _dispatch_execute( # Step1 local_inputs_file = os.path.join(ctx.execution_state.working_dir, "inputs.pb") ctx.file_access.get_data(inputs_path, local_inputs_file) - input_proto = utils.load_proto_from_file( - _literals_pb2.LiteralMap, local_inputs_file - ) + input_proto = utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto) # Step2 # Decorate the dispatch execute function before calling it, this wraps all exceptions into one # of the FlyteScopedExceptions - outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)( - ctx, idl_input_literals - ) + outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)(ctx, idl_input_literals) if inspect.iscoroutine(outputs): # Handle eager-mode (async) tasks logger.info("Output is a coroutine") @@ -114,9 +108,7 @@ def _dispatch_execute( # Step3a if isinstance(outputs, VoidPromise): logger.warning("Task produces no outputs") - output_file_dict = { - _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={}) - } + output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})} elif isinstance(outputs, _literal_models.LiteralMap): output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs} elif isinstance(outputs, _dynamic_job.DynamicJobSpec): @@ -135,9 +127,7 @@ def _dispatch_execute( # Handle user-scoped errors except _scoped_exceptions.FlyteScopedUserException as e: if isinstance(e.value, IgnoreOutputs): - logger.warning( - f"User-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!" - ) + logger.warning(f"User-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!") return output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( _error_models.ContainerError( @@ -154,9 +144,7 @@ def _dispatch_execute( # Handle system-scoped errors except _scoped_exceptions.FlyteScopedSystemException as e: if isinstance(e.value, IgnoreOutputs): - logger.warning( - f"System-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!" - ) + logger.warning(f"System-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!") return output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( _error_models.ContainerError( @@ -183,34 +171,23 @@ def _dispatch_execute( _execution_models.ExecutionError.ErrorKind.SYSTEM, ) ) - logger.error( - f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}" - ) + logger.error(f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}") logger.error("!! Begin Unknown System Error Captured by Flyte !!") logger.error(exc_str) logger.error("!! End Error Captured by Flyte !!") for k, v in output_file_dict.items(): - utils.write_proto_to_file( - v.to_flyte_idl(), os.path.join(ctx.execution_state.engine_dir, k) - ) + utils.write_proto_to_file(v.to_flyte_idl(), os.path.join(ctx.execution_state.engine_dir, k)) - ctx.file_access.put_data( - ctx.execution_state.engine_dir, output_prefix, is_multipart=True - ) - logger.info( - f"Engine folder written successfully to the output prefix {output_prefix}" - ) + ctx.file_access.put_data(ctx.execution_state.engine_dir, output_prefix, is_multipart=True) + logger.info(f"Engine folder written successfully to the output prefix {output_prefix}") if not getattr(task_def, "disable_deck", True): _output_deck(task_def.name.split(".")[-1], ctx.user_space_params) logger.debug("Finished _dispatch_execute") - if ( - os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" - and _constants.ERROR_FILE_NAME in output_file_dict - ): + if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict: # This env is set by the flytepropeller # AWS batch job get the status from the exit code, so once we catch the error, # we should return the error code here @@ -285,12 +262,8 @@ def setup_execution( checkpointer = None if checkpoint_path is not None: - checkpointer = SyncCheckpoint( - checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint - ) - logger.debug( - f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}" - ) + checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint) + logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}") execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( @@ -318,9 +291,7 @@ def setup_execution( raw_output_prefix=raw_output_data_prefix, output_metadata_prefix=output_metadata_prefix, checkpoint=checkpointer, - task_id=_identifier.Identifier( - _identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version - ), + task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version), ) metadata = { @@ -337,9 +308,7 @@ def setup_execution( execution_metadata=metadata, ) except TypeError: # would be thrown from DataPersistencePlugins.find_plugin - logger.error( - f"No data plugin found for raw output prefix {raw_output_data_prefix}" - ) + logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}") raise ctx = ctx.new_builder().with_file_access(file_access).build() @@ -511,9 +480,7 @@ def _execute_map_task( map_task = cloudpickle.load(f) else: mtr = load_object_from_module(resolver)() - map_task = mtr.load_task( - loader_args=resolver_args, max_concurrency=max_concurrency - ) + map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency) # Special case for the map task resolver, we need to append the task index to the output prefix. # TODO: (https://github.com/flyteorg/flyte/issues/5011) Remove legacy map task if mtr.name() == "flytekit.core.legacy_map_task.MapTaskResolver": @@ -542,11 +509,7 @@ def normalize_inputs( raw_output_data_prefix = None if checkpoint_path == "{{.checkpointOutputPrefix}}": checkpoint_path = None - if ( - prev_checkpoint == "{{.prevCheckpointPrefix}}" - or prev_checkpoint == "" - or prev_checkpoint == '""' - ): + if prev_checkpoint == "{{.prevCheckpointPrefix}}" or prev_checkpoint == "" or prev_checkpoint == '""': prev_checkpoint = None return raw_output_data_prefix, checkpoint_path, prev_checkpoint @@ -573,9 +536,7 @@ def _pass_through(): default=False, help="Use this to mark if the distribution is pickled.", ) -@click.option( - "--pkl-file", required=False, help="Location where pickled file can be found." -) +@click.option("--pkl-file", required=False, help="Location where pickled file can be found.") @click.argument( "resolver-args", type=click.UNPROCESSED, @@ -648,9 +609,7 @@ def fast_execute_task_cmd( if pickled: click.secho("Received pickled object") dest_file = os.path.join(os.getcwd(), "pickled.tar.gz") - FlyteContextManager.current_context().file_access.get_data( - additional_distribution, dest_file - ) + FlyteContextManager.current_context().file_access.get_data(additional_distribution, dest_file) cmd_extend = ["--pickled", "--pkl-file", dest_file] else: if not dest_dir: @@ -700,9 +659,7 @@ def handle_sigterm(signum, frame): default=False, help="Use this to mark if the distribution is pickled.", ) -@click.option( - "--pkl-file", required=False, help="Location where pickled file can be found." -) +@click.option("--pkl-file", required=False, help="Location where pickled file can be found.") @click.argument( "resolver-args", type=click.UNPROCESSED, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 7f2d9449864..a1420baefb2 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -107,12 +107,9 @@ class WorkflowMetadata(object): def __post_init__(self): if ( self.on_failure != WorkflowFailurePolicy.FAIL_IMMEDIATELY - and self.on_failure - != WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE + and self.on_failure != WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE ): - raise FlyteValidationException( - f"Failure policy {self.on_failure} not acceptable" - ) + raise FlyteValidationException(f"Failure policy {self.on_failure} not acceptable") def to_flyte_model(self): if self.on_failure == WorkflowFailurePolicy.FAIL_IMMEDIATELY: @@ -135,21 +132,15 @@ class WorkflowMetadataDefaults(object): def __post_init__(self): # TODO: Get mypy working so we don't have to worry about these checks if self.interruptible is not True and self.interruptible is not False: - raise FlyteValidationException( - f"Interruptible must be boolean, {self.interruptible} invalid" - ) + raise FlyteValidationException(f"Interruptible must be boolean, {self.interruptible} invalid") def to_flyte_model(self): - return _workflow_model.WorkflowMetadataDefaults( - interruptible=self.interruptible - ) + return _workflow_model.WorkflowMetadataDefaults(interruptible=self.interruptible) def construct_input_promises(inputs: List[str]) -> Dict[str, Promise]: return { - input_name: Promise( - var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name) - ) + input_name: Promise(var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name)) for input_name in inputs } @@ -171,9 +162,7 @@ def get_promise( # binding_data.promise.var is the name of the upstream node's output we want return outputs_cache[binding_data.promise.node][binding_data.promise.var] elif binding_data.scalar is not None: - return Promise( - var="placeholder", val=_literal_models.Literal(scalar=binding_data.scalar) - ) + return Promise(var="placeholder", val=_literal_models.Literal(scalar=binding_data.scalar)) elif binding_data.collection is not None: literals = [] for bd in binding_data.collection.bindings: @@ -181,9 +170,7 @@ def get_promise( literals.append(p.val) return Promise( var="placeholder", - val=_literal_models.Literal( - collection=_literal_models.LiteralCollection(literals=literals) - ), + val=_literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literals)), ) elif binding_data.map is not None: literals = {} # type: ignore @@ -192,9 +179,7 @@ def get_promise( literals[k] = p.val return Promise( var="placeholder", - val=_literal_models.Literal( - map=_literal_models.LiteralMap(literals=literals) - ), + val=_literal_models.Literal(map=_literal_models.LiteralMap(literals=literals)), ) raise FlyteValidationException("Binding type unrecognized.") @@ -245,19 +230,15 @@ def __init__( if self.docs is None: self._docs = Documentation( short_description=self._python_interface.docstring.short_description, - long_description=Description( - value=self._python_interface.docstring.long_description - ), + long_description=Description(value=self._python_interface.docstring.long_description), ) else: if self._python_interface.docstring.short_description: - cast(Documentation, self._docs).short_description = ( - self._python_interface.docstring.short_description - ) + cast( + Documentation, self._docs + ).short_description = self._python_interface.docstring.short_description if self._python_interface.docstring.long_description: - self._docs = Description( - value=self._python_interface.docstring.long_description - ) + self._docs = Description(value=self._python_interface.docstring.long_description) FlyteEntities.entities.append(self) super().__init__(**kwargs) @@ -322,9 +303,7 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: interruptible=self.workflow_metadata_defaults.interruptible, ) - def __call__( - self, *args, **kwargs - ) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, Coroutine, None]: + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, Coroutine, None]: """ Workflow needs to fill in default arguments before invoking the call handler. """ @@ -336,14 +315,9 @@ def __call__( return flyte_entity_call_handler(self, *args, **input_kwargs) except Exception as exc: if self.on_failure: - if ( - self.on_failure.python_interface - and "err" in self.on_failure.python_interface.inputs - ): + if self.on_failure.python_interface and "err" in self.on_failure.python_interface.inputs: id = self.failure_node.id if self.failure_node else "" - input_kwargs["err"] = FlyteError( - failed_node_id=id, message=str(exc) - ) + input_kwargs["err"] = FlyteError(failed_node_id=id, message=str(exc)) self.on_failure(**input_kwargs) raise exc @@ -367,9 +341,7 @@ def remote(self, options: Optional[Options] = None, **kwargs) -> FlyteFuture: def compile(self, **kwargs): pass - def local_execute( - self, ctx: FlyteContext, **kwargs - ) -> Union[Tuple[Promise], Promise, VoidPromise, None]: + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. literal_map = translate_inputs_to_literals( @@ -404,9 +376,7 @@ def local_execute( # Because we should've already returned in the above check, we just raise an error here. if len(self.python_interface.outputs) == 0: - raise FlyteValueException( - function_outputs, "Interface output should've been VoidPromise or None." - ) + raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.") expected_output_names = list(self.python_interface.outputs.keys()) if len(expected_output_names) == 1: @@ -414,17 +384,12 @@ def local_execute( # length one. That convention is used for naming outputs - and single-length-NamedTuples are # particularly troublesome but elegant handling of them is not a high priority # Again, we're using the output_tuple_name as a proxy. - if self.python_interface.output_tuple_name and isinstance( - function_outputs, tuple - ): + if self.python_interface.output_tuple_name and isinstance(function_outputs, tuple): wf_outputs_as_map = {expected_output_names[0]: function_outputs[0]} else: wf_outputs_as_map = {expected_output_names[0]: function_outputs} else: - wf_outputs_as_map = { - expected_output_names[i]: function_outputs[i] - for i, _ in enumerate(function_outputs) - } + wf_outputs_as_map = {expected_output_names[i]: function_outputs[i] for i, _ in enumerate(function_outputs)} # Basically we need to repackage the promises coming from the tasks into Promises that match the workflow's # interface. We do that by extracting out the literals, and creating new Promises @@ -435,10 +400,7 @@ def local_execute( native_types=self.python_interface.outputs, ) # Recreate new promises that use the workflow's output names. - new_promises = [ - Promise(var, wf_outputs_as_literal_dict[var]) - for var in expected_output_names - ] + new_promises = [Promise(var, wf_outputs_as_literal_dict[var]) for var in expected_output_names] return create_task_output(new_promises, self.python_interface) @@ -487,9 +449,7 @@ def __init__( failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, ): - metadata = WorkflowMetadata( - on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY - ) + metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) self._compilation_state = CompilationState(prefix="") self._inputs = {} @@ -526,10 +486,7 @@ def inputs(self) -> Dict[str, Promise]: return self._inputs def __repr__(self): - return ( - super().__repr__() - + f"Nodes ({len(self.compilation_state.nodes)}): {self.compilation_state.nodes}" - ) + return super().__repr__() + f"Nodes ({len(self.compilation_state.nodes)}): {self.compilation_state.nodes}" def execute(self, **kwargs): """ @@ -542,9 +499,7 @@ def execute(self, **kwargs): After all nodes are run, we fill in workflow level outputs the same way as any other previous node. """ if not self.ready(): - raise FlyteValidationException( - f"Workflow not ready, wf is currently {self}" - ) + raise FlyteValidationException(f"Workflow not ready, wf is currently {self}") # Create a map that holds the outputs of each node. intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}} # type: ignore @@ -573,26 +528,18 @@ def execute(self, **kwargs): # Because we should've already returned in the above check, we just raise an Exception here. if len(entity.python_interface.outputs) == 0: - raise FlyteValueException( - results, "Interface output should've been VoidPromise or None." - ) + raise FlyteValueException(results, "Interface output should've been VoidPromise or None.") # if there's only one output, if len(expected_output_names) == 1: - if entity.python_interface.output_tuple_name and isinstance( - results, tuple - ): - intermediate_node_outputs[node][expected_output_names[0]] = results[ - 0 - ] + if entity.python_interface.output_tuple_name and isinstance(results, tuple): + intermediate_node_outputs[node][expected_output_names[0]] = results[0] else: intermediate_node_outputs[node][expected_output_names[0]] = results else: if len(results) != len(expected_output_names): - raise FlyteValueException( - results, f"Different lengths {results} {expected_output_names}" - ) + raise FlyteValueException(results, f"Different lengths {results} {expected_output_names}") for idx, r in enumerate(results): intermediate_node_outputs[node][expected_output_names[idx]] = r @@ -611,29 +558,16 @@ def execute(self, **kwargs): # Again use presence of output_tuple_name to understand that we're dealing with a one-element # named tuple if self.python_interface.output_tuple_name: - return ( - get_promise( - self.output_bindings[0].binding, intermediate_node_outputs - ), - ) + return (get_promise(self.output_bindings[0].binding, intermediate_node_outputs),) # Just a normal single element - return get_promise( - self.output_bindings[0].binding, intermediate_node_outputs - ) - return tuple( - [ - get_promise(b.binding, intermediate_node_outputs) - for b in self.output_bindings - ] - ) + return get_promise(self.output_bindings[0].binding, intermediate_node_outputs) + return tuple([get_promise(b.binding, intermediate_node_outputs) for b in self.output_bindings]) def create_conditional(self, name: str) -> ConditionalSection: ctx = FlyteContext.current_context() if ctx.compilation_state is not None: raise RuntimeError("Can't already be compiling") - FlyteContextManager.with_context( - ctx.with_compilation_state(self.compilation_state) - ) + FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) return conditional(name=name) def add_entity( @@ -650,9 +584,7 @@ def add_entity( ctx = FlyteContext.current_context() if ctx.compilation_state is not None: raise RuntimeError("Can't already be compiling") - with FlyteContextManager.with_context( - ctx.with_compilation_state(self.compilation_state) - ) as ctx: + with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx: n = create_node(entity=entity, **kwargs) def get_input_values(input_value): @@ -673,9 +605,7 @@ def get_input_values(input_value): # values but we're only interested in the ones that are Promises so let's filter for those. # There's probably a way to clean this up, maybe key off of the name instead of value? all_input_values = get_input_values(kwargs) - for input_value in filter( - lambda x: isinstance(x, Promise), all_input_values - ): + for input_value in filter(lambda x: isinstance(x, Promise), all_input_values): if input_value in self._unbound_inputs: self._unbound_inputs.remove(input_value) return n # type: ignore @@ -685,16 +615,10 @@ def add_workflow_input(self, input_name: str, python_type: Type) -> Promise: Adds an input to the workflow. """ if input_name in self._inputs: - raise FlyteValidationException( - f"Input {input_name} has already been specified for wf {self.name}." - ) - self._python_interface = self._python_interface.with_inputs( - extra_inputs={input_name: python_type} - ) + raise FlyteValidationException(f"Input {input_name} has already been specified for wf {self.name}.") + self._python_interface = self._python_interface.with_inputs(extra_inputs={input_name: python_type}) self._interface = transform_interface_to_typed_interface(self._python_interface) - self._inputs[input_name] = Promise( - var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name) - ) + self._inputs[input_name] = Promise(var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name)) self._unbound_inputs.add(self._inputs[input_name]) return self._inputs[input_name] @@ -708,9 +632,7 @@ def add_workflow_output( Add an output with the given name from the given node output. """ if output_name in self._python_interface.outputs: - raise FlyteValidationException( - f"Output {output_name} already exists in workflow {self.name}" - ) + raise FlyteValidationException(f"Output {output_name} already exists in workflow {self.name}") if python_type is None: if type(p) == list or type(p) == dict: @@ -719,21 +641,15 @@ def add_workflow_output( f" starting with the container type (e.g. List[int]" ) promise = cast(Promise, p) - python_type = promise.ref.node.flyte_entity.python_interface.outputs[ - promise.var - ] - logger.debug( - f"Inferring python type for wf output {output_name} from Promise provided {python_type}" - ) + python_type = promise.ref.node.flyte_entity.python_interface.outputs[promise.var] + logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}") flyte_type = TypeEngine.to_literal_type(python_type=python_type) ctx = FlyteContext.current_context() if ctx.compilation_state is not None: raise RuntimeError("Can't already be compiling") - with FlyteContextManager.with_context( - ctx.with_compilation_state(self.compilation_state) - ) as ctx: + with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx: b, _ = binding_from_python_std( ctx, output_name, @@ -742,19 +658,13 @@ def add_workflow_output( t_value_type=python_type, ) self._output_bindings.append(b) - self._python_interface = self._python_interface.with_outputs( - extra_outputs={output_name: python_type} - ) - self._interface = transform_interface_to_typed_interface( - self._python_interface - ) + self._python_interface = self._python_interface.with_outputs(extra_outputs={output_name: python_type}) + self._interface = transform_interface_to_typed_interface(self._python_interface) def add_task(self, task: PythonTask, **kwargs) -> Node: return self.add_entity(task, **kwargs) - def add_launch_plan( - self, launch_plan: _annotated_launch_plan.LaunchPlan, **kwargs - ) -> Node: + def add_launch_plan(self, launch_plan: _annotated_launch_plan.LaunchPlan, **kwargs) -> Node: return self.add_entity(launch_plan, **kwargs) def add_subwf(self, sub_wf: WorkflowBase, **kwargs) -> Node: @@ -796,9 +706,7 @@ def __init__( ): name, _, _, _ = extract_task_module(workflow_function) self._workflow_function = workflow_function - native_interface = transform_function_to_interface( - workflow_function, docstring=docstring - ) + native_interface = transform_function_to_interface(workflow_function, docstring=docstring) # TODO do we need this - can this not be in launchplan only? # This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or @@ -821,29 +729,20 @@ def function(self): def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore return f"{self.name}.{t.__module__}.{t.name}" - def _validate_add_on_failure_handler( - self, ctx: FlyteContext, prefix: str, wf_args: Dict[str, Promise] - ): + def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_args: Dict[str, Promise]): # Compare with FlyteContextManager.with_context( - ctx.with_compilation_state( - CompilationState(prefix=prefix, task_resolver=self) - ) + ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self)) ) as inner_comp_ctx: # Now lets compile the failure-node if it exists if self.on_failure: c = wf_args.copy() exception_scopes.user_entry_point(self.on_failure)(**c) inner_nodes = None - if ( - inner_comp_ctx.compilation_state - and inner_comp_ctx.compilation_state.nodes - ): + if inner_comp_ctx.compilation_state and inner_comp_ctx.compilation_state.nodes: inner_nodes = inner_comp_ctx.compilation_state.nodes if not inner_nodes or len(inner_nodes) > 1: - raise AssertionError( - "Unable to compile failure node, only either a task or a workflow can be used" - ) + raise AssertionError("Unable to compile failure node, only either a task or a workflow can be used") self._failure_node = inner_nodes[0] def compile(self, **kwargs): @@ -857,23 +756,15 @@ def compile(self, **kwargs): self.compiled = True ctx = FlyteContextManager.current_context() all_nodes = [] - prefix = ( - ctx.compilation_state.prefix if ctx.compilation_state is not None else "" - ) + prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else "" with FlyteContextManager.with_context( - ctx.with_compilation_state( - CompilationState(prefix=prefix, task_resolver=self) - ) + ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self)) ) as comp_ctx: # Construct the default input promise bindings, but then override with the provided inputs, if any - input_kwargs = construct_input_promises( - [k for k in self.interface.inputs.keys()] - ) + input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()]) input_kwargs.update(kwargs) - workflow_outputs = exception_scopes.user_entry_point( - self._workflow_function - )(**input_kwargs) + workflow_outputs = exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs) all_nodes.extend(comp_ctx.compilation_state.nodes) # This little loop was added as part of the task resolver change. The task resolver interface itself is @@ -882,16 +773,11 @@ def compile(self, **kwargs): # does store state. This loop adds Tasks that are defined within the body of the workflow to the workflow # object itself. for n in comp_ctx.compilation_state.nodes: - if ( - isinstance(n.flyte_entity, PythonAutoContainerTask) - and n.flyte_entity.task_resolver == self - ): + if isinstance(n.flyte_entity, PythonAutoContainerTask) and n.flyte_entity.task_resolver == self: logger.debug(f"WF {self.name} saving task {n.flyte_entity.name}") self.add(n.flyte_entity) - self._validate_add_on_failure_handler( - comp_ctx, comp_ctx.compilation_state.prefix + "f", input_kwargs - ) + self._validate_add_on_failure_handler(comp_ctx, comp_ctx.compilation_state.prefix + "f", input_kwargs) # Iterate through the workflow outputs bindings = [] @@ -926,18 +812,12 @@ def compile(self, **kwargs): ) from e elif len(output_names) > 1: if not isinstance(workflow_outputs, tuple): - raise AssertionError( - "The Workflow specification indicates multiple return values, received only one" - ) + raise AssertionError("The Workflow specification indicates multiple return values, received only one") if len(output_names) != len(workflow_outputs): - raise ValueError( - f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}" - ) + raise ValueError(f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}") for i, out in enumerate(output_names): if isinstance(workflow_outputs[i], ConditionalSection): - raise AssertionError( - "A Conditional block (if-else) should always end with an `else_()` clause" - ) + raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause") t = self.python_interface.outputs[out] try: b, _ = binding_from_python_std( @@ -949,9 +829,7 @@ def compile(self, **kwargs): ) bindings.append(b) except Exception as e: - raise FlyteValidationException( - f"Failed to bind output {out} for function {self.name}: {e}" - ) from e + raise FlyteValidationException(f"Failed to bind output {out} for function {self.name}: {e}") from e # Save all the things necessary to create an WorkflowTemplate, except for the missing project and domain self._nodes = all_nodes @@ -1036,9 +914,7 @@ def workflow( """ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow: - workflow_metadata = WorkflowMetadata( - on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY - ) + workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) @@ -1075,9 +951,7 @@ def __init__( inputs: Dict[str, Type], outputs: Dict[str, Type], ): - super().__init__( - WorkflowReference(project, domain, name, version), inputs, outputs - ) + super().__init__(WorkflowReference(project, domain, name, version), inputs, outputs) def reference_workflow( @@ -1099,8 +973,6 @@ def reference_workflow( def wrapper(fn) -> ReferenceWorkflow: interface = transform_function_to_interface(fn) - return ReferenceWorkflow( - project, domain, name, version, interface.inputs, interface.outputs - ) + return ReferenceWorkflow(project, domain, name, version, interface.inputs, interface.outputs) return wrapper