Skip to content

Commit

Permalink
update scope to return all configs
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli committed Jan 13, 2025
1 parent fef369a commit 373a0dc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
17 changes: 14 additions & 3 deletions keras/src/backend/common/remat_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,24 @@ def __exit__(self, *args, **kwargs):


def get_current_remat_mode():
"""Get the current rematerialization mode.
"""Get the current rematerialization mode and associated settings.
Returns:
Keras rematerialization mode.
dict: A dictionary containing the rematerialization mode and other
settings.
Example:
{
"mode": "list_of_layers",
"output_size_threshold": 1024,
"layer_names": ["dense_1", "conv2d_1"]
}
"""
remat_scope_stack = global_state.get_global_attribute("remat_scope_stack")
if remat_scope_stack is None or not remat_scope_stack:
return None
active_scope = remat_scope_stack[-1]
return active_scope.mode
return {
"mode": active_scope.mode,
"output_size_threshold": active_scope.output_size_threshold,
"layer_names": active_scope.layer_names,
}
8 changes: 4 additions & 4 deletions keras/src/backend/common/remat_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_remat_scope_activation(self):

with RematScope(mode="full"):
self.assertEqual(
get_current_remat_mode(), "full"
get_current_remat_mode()["mode"], "full"
) # Mode is set to "full"

self.assertIsNone(
Expand All @@ -27,16 +27,16 @@ def test_remat_scope_nested(self):
"""Test nested scopes with different rematerialization modes."""
with RematScope(mode="full"):
self.assertEqual(
get_current_remat_mode(), "full"
get_current_remat_mode()["mode"], "full"
) # Outer scope is "full"

with RematScope(mode="activations"):
self.assertEqual(
get_current_remat_mode(), "activations"
get_current_remat_mode()["mode"], "activations"
) # Inner scope is "activations"

self.assertEqual(
get_current_remat_mode(), "full"
get_current_remat_mode()["mode"], "full"
) # Back to outer scope

self.assertIsNone(
Expand Down

0 comments on commit 373a0dc

Please sign in to comment.