From f1eb799f3ab1435581895b2156377d050ab48887 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 13 Jan 2025 23:27:23 +0000 Subject: [PATCH] add remat wrapper to layer --- keras/src/layers/layer.py | 24 +++++++++++++++++++++++- keras/src/layers/layer_test.py | 18 ++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index a4f830912d5d..31b00dcc48ba 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -32,11 +32,13 @@ from keras.src.backend import KerasTensor from keras.src.backend.common import global_state from keras.src.backend.common.name_scope import current_path +from keras.src.backend.common.remat_scope import get_current_remat_mode from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec from keras.src.metrics.metric import Metric +from keras.src.ops.core import remat from keras.src.ops.operation import Operation from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils import python_utils @@ -60,6 +62,26 @@ ) +def remat_wrapper(layer_call): + """Wrap the layer's call method to enable rematerialization dynamically. + + Args: + layer_call: The original `call` method of a layer. + + Returns: + callable: The wrapped method with rematerialization logic applied. + """ + + def wrapped_call(*args, **kwargs): + remat_mode = get_current_remat_mode() + if remat_mode["mode"] == "full": + return remat(layer_call)(*args, **kwargs) + # TODO: implement other modes + return layer_call(*args, **kwargs) + + return wrapped_call + + @keras_export(["keras.Layer", "keras.layers.Layer"]) class Layer(BackendLayer, Operation, KerasSaveable): """This is the class from which all layers inherit. @@ -1040,7 +1062,7 @@ def stateless_call( if self.dtype_policy.quantization_mode is not None: outputs = self.quantized_call(*args, **kwargs) else: - outputs = self.call(*args, **kwargs) + outputs = remat_wrapper(self.call)(*args, **kwargs) if return_losses: losses = self.losses diff --git a/keras/src/layers/layer_test.py b/keras/src/layers/layer_test.py index a74018b007c2..f443555ae102 100644 --- a/keras/src/layers/layer_test.py +++ b/keras/src/layers/layer_test.py @@ -12,6 +12,7 @@ from keras.src import ops from keras.src import testing from keras.src.backend.common import global_state +from keras.src.backend.common.remat_scope import RematScope class LayerTest(testing.TestCase): @@ -165,6 +166,23 @@ def test_not_implemented_error(self, method, args): else: getattr(layer, method)(args) + def test_layer_with_remat(self): + # Create some layer + class SomeLayer(layers.Layer): + def call(self, x): + return x + 1 + + input_tensor = backend.random.uniform((2, 4)) + layer = SomeLayer() + # Case 1: Without rematerialization + output_no_remat = layer(input_tensor) + + # Case 2: With rematerialization + with RematScope(mode="full"): + output_with_remat = layer(input_tensor) + + self.assertAllClose(output_no_remat, output_with_remat) + def test_rng_seed_tracking(self): class RNGLayer(layers.Layer): def __init__(self):