Skip to content

Commit

Permalink
Add bitwise ops (keras-team#20126)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Aug 15, 2024
1 parent 4cef332 commit d72a0ea
Show file tree
Hide file tree
Showing 10 changed files with 747 additions and 0 deletions.
9 changes: 9 additions & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
from keras.src.ops.numpy import bincount
from keras.src.ops.numpy import bitwise_and
from keras.src.ops.numpy import bitwise_invert
from keras.src.ops.numpy import bitwise_left_shift
from keras.src.ops.numpy import bitwise_not
from keras.src.ops.numpy import bitwise_or
from keras.src.ops.numpy import bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor
from keras.src.ops.numpy import broadcast_to
from keras.src.ops.numpy import ceil
from keras.src.ops.numpy import clip
Expand Down Expand Up @@ -156,6 +163,7 @@
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
from keras.src.ops.numpy import isnan
from keras.src.ops.numpy import left_shift
from keras.src.ops.numpy import less
from keras.src.ops.numpy import less_equal
from keras.src.ops.numpy import linspace
Expand Down Expand Up @@ -197,6 +205,7 @@
from keras.src.ops.numpy import reciprocal
from keras.src.ops.numpy import repeat
from keras.src.ops.numpy import reshape
from keras.src.ops.numpy import right_shift
from keras.src.ops.numpy import roll
from keras.src.ops.numpy import round
from keras.src.ops.numpy import searchsorted
Expand Down
9 changes: 9 additions & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
from keras.src.ops.numpy import bincount
from keras.src.ops.numpy import bitwise_and
from keras.src.ops.numpy import bitwise_invert
from keras.src.ops.numpy import bitwise_left_shift
from keras.src.ops.numpy import bitwise_not
from keras.src.ops.numpy import bitwise_or
from keras.src.ops.numpy import bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor
from keras.src.ops.numpy import broadcast_to
from keras.src.ops.numpy import ceil
from keras.src.ops.numpy import clip
Expand Down Expand Up @@ -70,6 +77,7 @@
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
from keras.src.ops.numpy import isnan
from keras.src.ops.numpy import left_shift
from keras.src.ops.numpy import less
from keras.src.ops.numpy import less_equal
from keras.src.ops.numpy import linspace
Expand Down Expand Up @@ -111,6 +119,7 @@
from keras.src.ops.numpy import reciprocal
from keras.src.ops.numpy import repeat
from keras.src.ops.numpy import reshape
from keras.src.ops.numpy import right_shift
from keras.src.ops.numpy import roll
from keras.src.ops.numpy import round
from keras.src.ops.numpy import select
Expand Down
9 changes: 9 additions & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
from keras.src.ops.numpy import bincount
from keras.src.ops.numpy import bitwise_and
from keras.src.ops.numpy import bitwise_invert
from keras.src.ops.numpy import bitwise_left_shift
from keras.src.ops.numpy import bitwise_not
from keras.src.ops.numpy import bitwise_or
from keras.src.ops.numpy import bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor
from keras.src.ops.numpy import broadcast_to
from keras.src.ops.numpy import ceil
from keras.src.ops.numpy import clip
Expand Down Expand Up @@ -156,6 +163,7 @@
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
from keras.src.ops.numpy import isnan
from keras.src.ops.numpy import left_shift
from keras.src.ops.numpy import less
from keras.src.ops.numpy import less_equal
from keras.src.ops.numpy import linspace
Expand Down Expand Up @@ -197,6 +205,7 @@
from keras.src.ops.numpy import reciprocal
from keras.src.ops.numpy import repeat
from keras.src.ops.numpy import reshape
from keras.src.ops.numpy import right_shift
from keras.src.ops.numpy import roll
from keras.src.ops.numpy import round
from keras.src.ops.numpy import searchsorted
Expand Down
9 changes: 9 additions & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
from keras.src.ops.numpy import bincount
from keras.src.ops.numpy import bitwise_and
from keras.src.ops.numpy import bitwise_invert
from keras.src.ops.numpy import bitwise_left_shift
from keras.src.ops.numpy import bitwise_not
from keras.src.ops.numpy import bitwise_or
from keras.src.ops.numpy import bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor
from keras.src.ops.numpy import broadcast_to
from keras.src.ops.numpy import ceil
from keras.src.ops.numpy import clip
Expand Down Expand Up @@ -70,6 +77,7 @@
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
from keras.src.ops.numpy import isnan
from keras.src.ops.numpy import left_shift
from keras.src.ops.numpy import less
from keras.src.ops.numpy import less_equal
from keras.src.ops.numpy import linspace
Expand Down Expand Up @@ -111,6 +119,7 @@
from keras.src.ops.numpy import reciprocal
from keras.src.ops.numpy import repeat
from keras.src.ops.numpy import reshape
from keras.src.ops.numpy import right_shift
from keras.src.ops.numpy import roll
from keras.src.ops.numpy import round
from keras.src.ops.numpy import select
Expand Down
47 changes: 47 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,53 @@ def average(x, axis=None, weights=None):
return jnp.average(x, weights=weights, axis=axis)


def bitwise_and(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return jnp.bitwise_and(x, y)


def bitwise_invert(x):
x = convert_to_tensor(x)
return jnp.bitwise_invert(x)


def bitwise_not(x):
return bitwise_invert(x)


def bitwise_or(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return jnp.bitwise_or(x, y)


def bitwise_xor(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return jnp.bitwise_xor(x, y)


def bitwise_left_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return jnp.bitwise_left_shift(x, y)


def left_shift(x, y):
return bitwise_left_shift(x, y)


def bitwise_right_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return jnp.bitwise_right_shift(x, y)


def right_shift(x, y):
return bitwise_right_shift(x, y)


def broadcast_to(x, shape):
x = convert_to_tensor(x)
return jnp.broadcast_to(x, shape)
Expand Down
47 changes: 47 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,53 @@ def bincount_fn(arr_w):
return np.bincount(x, weights, minlength).astype(dtype)


def bitwise_and(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return np.bitwise_and(x, y)


def bitwise_invert(x):
x = convert_to_tensor(x)
return np.bitwise_not(x)


def bitwise_not(x):
return bitwise_invert(x)


def bitwise_or(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return np.bitwise_or(x, y)


def bitwise_xor(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return np.bitwise_xor(x, y)


def bitwise_left_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return np.left_shift(x, y)


def left_shift(x, y):
return bitwise_left_shift(x, y)


def bitwise_right_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return np.right_shift(x, y)


def right_shift(x, y):
return bitwise_right_shift(x, y)


def broadcast_to(x, shape):
return np.broadcast_to(x, shape)

Expand Down
62 changes: 62 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,68 @@ def _rank_not_equal_case():
return avg


def bitwise_and(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
dtype = dtypes.result_type(x.dtype, y.dtype)
x = tf.cast(x, dtype)
y = tf.cast(y, dtype)
return tf.bitwise.bitwise_and(x, y)


def bitwise_invert(x):
x = convert_to_tensor(x)
return tf.bitwise.invert(x)


def bitwise_not(x):
return bitwise_invert(x)


def bitwise_or(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
dtype = dtypes.result_type(x.dtype, y.dtype)
x = tf.cast(x, dtype)
y = tf.cast(y, dtype)
return tf.bitwise.bitwise_or(x, y)


def bitwise_xor(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
dtype = dtypes.result_type(x.dtype, y.dtype)
x = tf.cast(x, dtype)
y = tf.cast(y, dtype)
return tf.bitwise.bitwise_xor(x, y)


def bitwise_left_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
dtype = dtypes.result_type(x.dtype, y.dtype)
x = tf.cast(x, dtype)
y = tf.cast(y, dtype)
return tf.bitwise.left_shift(x, y)


def left_shift(x, y):
return bitwise_left_shift(x, y)


def bitwise_right_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
dtype = dtypes.result_type(x.dtype, y.dtype)
x = tf.cast(x, dtype)
y = tf.cast(y, dtype)
return tf.bitwise.right_shift(x, y)


def right_shift(x, y):
return bitwise_right_shift(x, y)


def broadcast_to(x, shape):
return tf.broadcast_to(x, shape)

Expand Down
47 changes: 47 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,53 @@ def bincount_fn(arr_w):
return cast(torch.bincount(x, weights, minlength), dtype)


def bitwise_and(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return torch.bitwise_and(x, y)


def bitwise_invert(x):
x = convert_to_tensor(x)
return torch.bitwise_not(x)


def bitwise_not(x):
return bitwise_invert(x)


def bitwise_or(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return torch.bitwise_or(x, y)


def bitwise_xor(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return torch.bitwise_xor(x, y)


def bitwise_left_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return torch.bitwise_left_shift(x, y)


def left_shift(x, y):
return bitwise_left_shift(x, y)


def bitwise_right_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return torch.bitwise_right_shift(x, y)


def right_shift(x, y):
return bitwise_right_shift(x, y)


def broadcast_to(x, shape):
x = convert_to_tensor(x)
return torch.broadcast_to(x, shape)
Expand Down
Loading

0 comments on commit d72a0ea

Please sign in to comment.