Skip to content

Commit

Permalink
Standardize layer interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 18, 2024
1 parent ee65e2f commit 1a824f1
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tf_shell_ml/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,6 @@ def call(self, inputs, training=False):
self.outputs = inputs * dropout_mask
return self.outputs

def backward(self, dy):
def backward(self, dy, _):
d_x = dy * self._layer_intermediate
return [], d_x
2 changes: 1 addition & 1 deletion tf_shell_ml/globalaveragepool1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_config(self):
)
return config

def backward(self, dy):
def backward(self, dy, _):
dx = tf_shell.expand_dims(dy, axis=1)
dx = tf_shell.broadcast_to(
dx, (dx.shape[0], self._layer_intermediate, dx.shape[2])
Expand Down
4 changes: 2 additions & 2 deletions tf_shell_ml/test/dropout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def _test_dropout_back(self, per_batch):
notrain_y = dropout_layer(x, training=True)
dy = tf.ones_like(notrain_y)

dw, dx = dropout_layer.backward(dy)
dw, dx = dropout_layer.backward(dy, None)

enc_dy = tf_shell.to_encrypted(dy, key, context)
enc_dw, enc_dx = dropout_layer.backward(enc_dy)
enc_dw, enc_dx = dropout_layer.backward(enc_dy, None)
dec_dx = tf_shell.to_tensorflow(enc_dx, key)

self.assertEmpty(dw)
Expand Down

0 comments on commit 1a824f1

Please sign in to comment.