Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Aug 14, 2024
1 parent 57a19cb commit 0bd6edc
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 253 deletions.
83 changes: 20 additions & 63 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def _compute_array_job_index():
if os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"):
offset = int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"))
if os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"):
return offset + int(
os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))
)
return offset + int(os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")))
return offset


Expand All @@ -95,17 +93,13 @@ def _dispatch_execute(
# Step1
local_inputs_file = os.path.join(ctx.execution_state.working_dir, "inputs.pb")
ctx.file_access.get_data(inputs_path, local_inputs_file)
input_proto = utils.load_proto_from_file(
_literals_pb2.LiteralMap, local_inputs_file
)
input_proto = utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto)

# Step2
# Decorate the dispatch execute function before calling it, this wraps all exceptions into one
# of the FlyteScopedExceptions
outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)(
ctx, idl_input_literals
)
outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)(ctx, idl_input_literals)
if inspect.iscoroutine(outputs):
# Handle eager-mode (async) tasks
logger.info("Output is a coroutine")
Expand All @@ -114,9 +108,7 @@ def _dispatch_execute(
# Step3a
if isinstance(outputs, VoidPromise):
logger.warning("Task produces no outputs")
output_file_dict = {
_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})
}
output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
elif isinstance(outputs, _literal_models.LiteralMap):
output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
Expand All @@ -135,9 +127,7 @@ def _dispatch_execute(
# Handle user-scoped errors
except _scoped_exceptions.FlyteScopedUserException as e:
if isinstance(e.value, IgnoreOutputs):
logger.warning(
f"User-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!"
)
logger.warning(f"User-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!")
return
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
_error_models.ContainerError(
Expand All @@ -154,9 +144,7 @@ def _dispatch_execute(
# Handle system-scoped errors
except _scoped_exceptions.FlyteScopedSystemException as e:
if isinstance(e.value, IgnoreOutputs):
logger.warning(
f"System-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!"
)
logger.warning(f"System-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!")
return
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
_error_models.ContainerError(
Expand All @@ -183,34 +171,23 @@ def _dispatch_execute(
_execution_models.ExecutionError.ErrorKind.SYSTEM,
)
)
logger.error(
f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}"
)
logger.error(f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}")
logger.error("!! Begin Unknown System Error Captured by Flyte !!")
logger.error(exc_str)
logger.error("!! End Error Captured by Flyte !!")

for k, v in output_file_dict.items():
utils.write_proto_to_file(
v.to_flyte_idl(), os.path.join(ctx.execution_state.engine_dir, k)
)
utils.write_proto_to_file(v.to_flyte_idl(), os.path.join(ctx.execution_state.engine_dir, k))

ctx.file_access.put_data(
ctx.execution_state.engine_dir, output_prefix, is_multipart=True
)
logger.info(
f"Engine folder written successfully to the output prefix {output_prefix}"
)
ctx.file_access.put_data(ctx.execution_state.engine_dir, output_prefix, is_multipart=True)
logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")

if not getattr(task_def, "disable_deck", True):
_output_deck(task_def.name.split(".")[-1], ctx.user_space_params)

logger.debug("Finished _dispatch_execute")

if (
os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true"
and _constants.ERROR_FILE_NAME in output_file_dict
):
if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict:
# This env is set by the flytepropeller
# AWS batch job get the status from the exit code, so once we catch the error,
# we should return the error code here
Expand Down Expand Up @@ -285,12 +262,8 @@ def setup_execution(

checkpointer = None
if checkpoint_path is not None:
checkpointer = SyncCheckpoint(
checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint
)
logger.debug(
f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}"
)
checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

execution_parameters = ExecutionParameters(
execution_id=_identifier.WorkflowExecutionIdentifier(
Expand Down Expand Up @@ -318,9 +291,7 @@ def setup_execution(
raw_output_prefix=raw_output_data_prefix,
output_metadata_prefix=output_metadata_prefix,
checkpoint=checkpointer,
task_id=_identifier.Identifier(
_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version
),
task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
)

metadata = {
Expand All @@ -337,9 +308,7 @@ def setup_execution(
execution_metadata=metadata,
)
except TypeError: # would be thrown from DataPersistencePlugins.find_plugin
logger.error(
f"No data plugin found for raw output prefix {raw_output_data_prefix}"
)
logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
raise

ctx = ctx.new_builder().with_file_access(file_access).build()
Expand Down Expand Up @@ -511,9 +480,7 @@ def _execute_map_task(
map_task = cloudpickle.load(f)
else:
mtr = load_object_from_module(resolver)()
map_task = mtr.load_task(
loader_args=resolver_args, max_concurrency=max_concurrency
)
map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency)
# Special case for the map task resolver, we need to append the task index to the output prefix.
# TODO: (https://github.com/flyteorg/flyte/issues/5011) Remove legacy map task
if mtr.name() == "flytekit.core.legacy_map_task.MapTaskResolver":
Expand Down Expand Up @@ -542,11 +509,7 @@ def normalize_inputs(
raw_output_data_prefix = None
if checkpoint_path == "{{.checkpointOutputPrefix}}":
checkpoint_path = None
if (
prev_checkpoint == "{{.prevCheckpointPrefix}}"
or prev_checkpoint == ""
or prev_checkpoint == '""'
):
if prev_checkpoint == "{{.prevCheckpointPrefix}}" or prev_checkpoint == "" or prev_checkpoint == '""':
prev_checkpoint = None

return raw_output_data_prefix, checkpoint_path, prev_checkpoint
Expand All @@ -573,9 +536,7 @@ def _pass_through():
default=False,
help="Use this to mark if the distribution is pickled.",
)
@click.option(
"--pkl-file", required=False, help="Location where pickled file can be found."
)
@click.option("--pkl-file", required=False, help="Location where pickled file can be found.")
@click.argument(
"resolver-args",
type=click.UNPROCESSED,
Expand Down Expand Up @@ -648,9 +609,7 @@ def fast_execute_task_cmd(
if pickled:
click.secho("Received pickled object")
dest_file = os.path.join(os.getcwd(), "pickled.tar.gz")
FlyteContextManager.current_context().file_access.get_data(
additional_distribution, dest_file
)
FlyteContextManager.current_context().file_access.get_data(additional_distribution, dest_file)
cmd_extend = ["--pickled", "--pkl-file", dest_file]
else:
if not dest_dir:
Expand Down Expand Up @@ -700,9 +659,7 @@ def handle_sigterm(signum, frame):
default=False,
help="Use this to mark if the distribution is pickled.",
)
@click.option(
"--pkl-file", required=False, help="Location where pickled file can be found."
)
@click.option("--pkl-file", required=False, help="Location where pickled file can be found.")
@click.argument(
"resolver-args",
type=click.UNPROCESSED,
Expand Down
Loading

0 comments on commit 0bd6edc

Please sign in to comment.