Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to enable RayGraphAdapter and HamiltonTracker to work well together #1103

Merged
merged 33 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
74b152b
Update graph_functions.py
elijahbenizzy Aug 15, 2024
41decaa
Adds comments to lifecycle base
elijahbenizzy Aug 15, 2024
5b73b8f
Update h_ray.py with comments for ray tracking compatibility
elijahbenizzy Aug 15, 2024
aa3ac05
Replicate previous error
Aug 19, 2024
e519180
Inline function, unsure if catching errors and exceptions to be handa…
Aug 20, 2024
2dca334
BaseDoRemoteExecute has the added Callable function that sandwiched …
Aug 20, 2024
04f1a1b
method fails, says AssertionError about ray.remote decorator
Aug 20, 2024
b77860e
simple script for now to check telemetry, execution yield the ray.rem…
Aug 20, 2024
c8358f8
passing pointer through and arguments to lifecycle wrapper into ray.r…
Aug 20, 2024
e77f6f7
post-execute hook for node not called
Aug 21, 2024
f7e81a0
finally executed only when exception occurs, hamilton tracker not exe…
Aug 21, 2024
3a1cccd
atexit.register does not work, node keeps running in UI
Aug 23, 2024
09b47de
added stop() method, but doesn't get called
Aug 23, 2024
933991f
Ray telemetry works for single node, problem with connected nodes
Aug 23, 2024
d1e6ea0
Ray telemetry works for single node, problem with connected nodes
Aug 23, 2024
b528a45
Ray telemetry works for single node, problem with connected nodes
Aug 23, 2024
7acdd38
Fixes ray object dereferencing
skrawcz Aug 25, 2024
377dc71
ray works checkpoint, pre-commit fixed
Aug 25, 2024
8e3eacd
fixed graph level telemetry proposal
Aug 25, 2024
5aa592b
pinned ruff
Aug 26, 2024
36e8fcb
Correct output, added option to start ray cluster
Aug 28, 2024
f17d4f2
Unit test mimics the DoNodeExecute unit test
Aug 31, 2024
abd87f7
Refactored driver so all tests pass
Aug 31, 2024
4e28677
Workaround to not break ray by calling init on an open cluster
Aug 31, 2024
d9e86a9
raw_execute does not have post_graph_execute and is private now
Sep 1, 2024
6953417
Correct version for deprecation warning
Sep 1, 2024
e988db5
all tests work
Sep 1, 2024
51c8544
this looks better
Sep 2, 2024
3acd95c
ruff version comment
Sep 2, 2024
a556558
Refactored pre- and post-graph-execute hooks outside of raw_execute w…
Sep 3, 2024
c1b55ee
added readme, notebook and made script cli interactive
Sep 6, 2024
089d1de
made cluster init optional through inserting config dict
Sep 6, 2024
1985ef7
User has option to shutdown ray cluster
jernejfrank Sep 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
args: [ --fix ]
# Run the formatter.
- id: ruff-format
# args: [ --diff ] # Use for previewing changes
# args: [ --diff ] # Use for previewing changes
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
Expand Down
Binary file added dag_example_module.png
jernejfrank marked this conversation as resolved.
Show resolved Hide resolved
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions hamilton/dev_utils/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class deprecated:
@deprecate(
warn_starting=(1,10,0)
fail_starting=(2,0,0),
use_instead=parameterize_values,
reason='We have redefined the parameterization decorators to consist of `parametrize`, `parametrize_inputs`, and `parametrize_values`
use_this=parameterize_values,
explanation='We have redefined the parameterization decorators to consist of `parametrize`, `parametrize_inputs`, and `parametrize_values`
migration_guide="https://github.com/dagworks-inc/hamilton/..."
)
class parameterized(...):
Expand All @@ -66,7 +66,7 @@ class parameterized(...):
explanation: str
migration_guide: Optional[
str
] # If this is None, this means that the use_instead is a drop in replacement
] # If this is None, this means that the use_this is a drop in replacement
current_version: Union[Tuple[int, int, int], Version] = dataclasses.field(
default_factory=lambda: CURRENT_VERSION
)
Expand Down
112 changes: 80 additions & 32 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pandas as pd

from hamilton import common, graph_types, htypes
from hamilton.dev_utils import deprecation
from hamilton.execution import executors, graph_functions, grouping, state
from hamilton.graph_types import HamiltonNode
from hamilton.io import materialization
Expand Down Expand Up @@ -580,10 +581,12 @@ def execute(
)
start_time = time.time()
run_successful = True
error = None
telemetry_error = None
execution_error = None
outputs = None
_final_vars = self._create_final_vars(final_vars)
try:
outputs = self.raw_execute(_final_vars, overrides, display_graph, inputs=inputs)
outputs = self.__raw_execute(_final_vars, overrides, display_graph, inputs=inputs)
if self.adapter.does_method("do_build_result", is_async=False):
# Build the result if we have a result builder
return self.adapter.call_lifecycle_method_sync("do_build_result", outputs=outputs)
Expand All @@ -592,12 +595,22 @@ def execute(
except Exception as e:
run_successful = False
logger.error(SLACK_ERROR_MESSAGE)
error = telemetry.sanitize_error(*sys.exc_info())
execution_error = e
telemetry_error = telemetry.sanitize_error(*sys.exc_info())
raise e
finally:
if self.adapter.does_hook("post_graph_execute", is_async=False):
self.adapter.call_all_lifecycle_hooks_sync(
"post_graph_execute",
run_id=self.run_id,
graph=self.function_graph,
success=run_successful,
error=execution_error,
results=outputs,
)
duration = time.time() - start_time
self.capture_execute_telemetry(
error, _final_vars, inputs, overrides, run_successful, duration
telemetry_error, _final_vars, inputs, overrides, run_successful, duration
)

def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -> List[str]:
Expand Down Expand Up @@ -649,6 +662,13 @@ def capture_execute_telemetry(
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Error caught in processing telemetry: \n{e}")

@deprecation.deprecated(
warn_starting=(1, 75, 0),
fail_starting=(2, 0, 0),
use_this=None,
explanation="This has become a private method and does not guarantee that all the adapters work correctly.",
migration_guide="Don't use this entry point for execution directly. Always go through `.execute()`.",
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
)
def raw_execute(
self,
final_vars: List[str],
Expand All @@ -657,24 +677,55 @@ def raw_execute(
inputs: Dict[str, Any] = None,
_fn_graph: graph.FunctionGraph = None,
) -> Dict[str, Any]:
"""Raw execute function that does the meat of execute.

Don't use this entry point for execution directly. Always go through `.execute()`.
"""Don't use this entry point for execution directly. Always go through `.execute()`.
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
In case you are using `.raw_execute()` directly, please switch to `.execute()` using a
`base.DictResult()`. Note: `base.DictResult()` is the default return of execute if you are
using the `driver.Builder()` class to create a `Driver()` object.
"""
success = True
error = None
results = None
try:
return self.__raw_execute(final_vars, overrides, display_graph, inputs=inputs)
except Exception as e:
success = False
logger.error(SLACK_ERROR_MESSAGE)
error = e
raise e
finally:
if self.adapter.does_hook("post_graph_execute", is_async=False):
self.adapter.call_all_lifecycle_hooks_sync(
"post_graph_execute",
run_id=self.run_id,
graph=self.function_graph,
success=success,
error=error,
results=results,
)

def __raw_execute(
self,
final_vars: List[str],
overrides: Dict[str, Any] = None,
display_graph: bool = False,
inputs: Dict[str, Any] = None,
_fn_graph: graph.FunctionGraph = None,
) -> Dict[str, Any]:
"""Raw execute function that does the meat of execute.

Private method since the result building and post_graph_execute lifecycle hooks are performed outside and so this returns an incomplete result.

:param final_vars: Final variables to compute
:param overrides: Overrides to run.
:param display_graph: DEPRECATED. DO NOT USE. Whether or not to display the graph when running it
:param inputs: Runtime inputs to the DAG
:return:
"""
function_graph = _fn_graph if _fn_graph is not None else self.graph
run_id = str(uuid.uuid4())
nodes, user_nodes = function_graph.get_upstream_nodes(final_vars, inputs, overrides)
self.function_graph = _fn_graph if _fn_graph is not None else self.graph
self.run_id = str(uuid.uuid4())
nodes, user_nodes = self.function_graph.get_upstream_nodes(final_vars, inputs, overrides)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 you're assigning to self to have access to these after the function has run, right?

Might need to think through this a little more since this now means an execute mutates the Driver object... e.g. if you're running in a multi-threaded situation where you have one driver object and execute concurrently in multiple threads. I'd be inclined to change the return type instead of mutating self here. @elijahbenizzy thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this structure as is: the pre-execute hooks are in one function and the post-execute hooks are outside. The self was a workaround to get the test to pass and not run into issues with an unspecified run_id when an exception occurs.

Maybe we can simplify this and just put the result_builder into __raw_execute? This would make everything neater, planning to code this up tonight.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if this solution is acceptable or if it breaks something somewhere else?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So yeah, results builder is not in raw execute. The problem here is that it's not thread-safe -- this is something we rely on (E.G. being able to have multiple calls to the driver at once). So I don't think this will work?

IMO it's OK having it in two different places, but I think there are better solutions here:

  1. (simple, hackyish) Have __raw_execute return the run_id/function graph -- function graph can actually always be passed in.
  2. Cleaner -- pass run_id into __raw_execute -- why should it generate? Then if we pass that + the graph in we can easily call it at the beginning.
  3. Probably best -- just move some of the logic outside of __raw_execute, always surrounding it by the pre/post graph hooks. Then it looks like a sandwich, but has the right tooling. The raw_execute function is the only one that does not call the result builder, the rest call the post graph execute hook after calling the result builder.

One thing to note -- the earlier the pre graph execute and later the post graph execute hooks are called the better. This is because it can capture errors -- thus any configuration problems/hamilton bugs will be caught in the right place and displayed in the UI.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was 90% sure my suggestion would not work; it looked too easy :) I adjusted the .materialize() entry point in the same way since it also calls __raw_execute, and all the tests pass now as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks! Yeah we should have tests for this, but it's always an issue of how much you can reasonably cover...

Driver.validate_inputs(
function_graph, self.adapter, user_nodes, inputs, nodes
self.function_graph, self.adapter, user_nodes, inputs, nodes
) # TODO -- validate within the function graph itself
if display_graph: # deprecated flow.
logger.warning(
Expand All @@ -683,46 +734,31 @@ def raw_execute(
)
self.visualize_execution(final_vars, "test-output/execute.gv", {"view": True})
if self.has_cycles(
final_vars, function_graph
final_vars, self.function_graph
): # here for backwards compatible driver behavior.
raise ValueError("Error: cycles detected in your graph.")
all_nodes = nodes | user_nodes
self.graph_executor.validate(list(all_nodes))
if self.adapter.does_hook("pre_graph_execute", is_async=False):
self.adapter.call_all_lifecycle_hooks_sync(
"pre_graph_execute",
run_id=run_id,
graph=function_graph,
run_id=self.run_id,
graph=self.function_graph,
final_vars=final_vars,
inputs=inputs,
overrides=overrides,
)
results = None
error = None
success = False
try:
results = self.graph_executor.execute(
function_graph,
self.function_graph,
final_vars,
overrides if overrides is not None else {},
inputs if inputs is not None else {},
run_id,
self.run_id,
)
success = True
except Exception as e:
error = e
success = False
raise e
finally:
if self.adapter.does_hook("post_graph_execute", is_async=False):
self.adapter.call_all_lifecycle_hooks_sync(
"post_graph_execute",
run_id=run_id,
graph=function_graph,
success=success,
error=error,
results=results,
)
return results

@capture_function_usage
Expand Down Expand Up @@ -1517,6 +1553,8 @@ def materialize(
start_time = time.time()
run_successful = True
error = None
execution_error = None
raw_results_output = None

final_vars = self._create_final_vars(additional_vars)
# This is so the finally logging statement does not accidentally die
Expand Down Expand Up @@ -1550,7 +1588,7 @@ def materialize(
Driver.validate_inputs(function_graph, self.adapter, user_nodes, inputs, nodes)
all_nodes = nodes | user_nodes
self.graph_executor.validate(list(all_nodes))
raw_results = self.raw_execute(
raw_results = self.__raw_execute(
final_vars=final_vars + materializer_vars,
inputs=inputs,
overrides=overrides,
Expand All @@ -1563,9 +1601,19 @@ def materialize(
except Exception as e:
run_successful = False
logger.error(SLACK_ERROR_MESSAGE)
execution_error = e
error = telemetry.sanitize_error(*sys.exc_info())
raise e
finally:
if self.adapter.does_hook("post_graph_execute", is_async=False):
self.adapter.call_all_lifecycle_hooks_sync(
"post_graph_execute",
run_id=self.run_id,
graph=self.function_graph,
success=run_successful,
error=execution_error,
results=raw_results_output,
)
duration = time.time() - start_time
self.capture_execute_telemetry(
error, final_vars + materializer_vars, inputs, overrides, run_successful, duration
Expand Down
Loading