Skip to content

Commit

Permalink
add remat wrapper to layer
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli committed Jan 13, 2025
1 parent 373a0dc commit f1eb799
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
24 changes: 23 additions & 1 deletion keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from keras.src.backend import KerasTensor
from keras.src.backend.common import global_state
from keras.src.backend.common.name_scope import current_path
from keras.src.backend.common.remat_scope import get_current_remat_mode
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.distribution import distribution_lib
from keras.src.dtype_policies import DTypePolicyMap
from keras.src.layers import input_spec
from keras.src.metrics.metric import Metric
from keras.src.ops.core import remat
from keras.src.ops.operation import Operation
from keras.src.saving.keras_saveable import KerasSaveable
from keras.src.utils import python_utils
Expand All @@ -60,6 +62,26 @@
)


def remat_wrapper(layer_call):
"""Wrap the layer's call method to enable rematerialization dynamically.
Args:
layer_call: The original `call` method of a layer.
Returns:
callable: The wrapped method with rematerialization logic applied.
"""

def wrapped_call(*args, **kwargs):
remat_mode = get_current_remat_mode()
if remat_mode["mode"] == "full":
return remat(layer_call)(*args, **kwargs)
# TODO: implement other modes
return layer_call(*args, **kwargs)

return wrapped_call


@keras_export(["keras.Layer", "keras.layers.Layer"])
class Layer(BackendLayer, Operation, KerasSaveable):
"""This is the class from which all layers inherit.
Expand Down Expand Up @@ -1040,7 +1062,7 @@ def stateless_call(
if self.dtype_policy.quantization_mode is not None:
outputs = self.quantized_call(*args, **kwargs)
else:
outputs = self.call(*args, **kwargs)
outputs = remat_wrapper(self.call)(*args, **kwargs)
if return_losses:
losses = self.losses

Expand Down
18 changes: 18 additions & 0 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.src import ops
from keras.src import testing
from keras.src.backend.common import global_state
from keras.src.backend.common.remat_scope import RematScope


class LayerTest(testing.TestCase):
Expand Down Expand Up @@ -165,6 +166,23 @@ def test_not_implemented_error(self, method, args):
else:
getattr(layer, method)(args)

def test_layer_with_remat(self):
# Create some layer
class SomeLayer(layers.Layer):
def call(self, x):
return x + 1

input_tensor = backend.random.uniform((2, 4))
layer = SomeLayer()
# Case 1: Without rematerialization
output_no_remat = layer(input_tensor)

# Case 2: With rematerialization
with RematScope(mode="full"):
output_with_remat = layer(input_tensor)

self.assertAllClose(output_no_remat, output_with_remat)

def test_rng_seed_tracking(self):
class RNGLayer(layers.Layer):
def __init__(self):
Expand Down

0 comments on commit f1eb799

Please sign in to comment.