Skip to content

Implement blackman function in keras.ops #21235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
from keras.src.ops.numpy import bitwise_or as bitwise_or
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from keras.src.ops.numpy import bitwise_or as bitwise_or
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
from keras.src.ops.numpy import bitwise_or as bitwise_or
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from keras.src.ops.numpy import bitwise_or as bitwise_or
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
from keras.src.ops.numpy import blackman as blackman
from keras.src.ops.numpy import broadcast_to as broadcast_to
from keras.src.ops.numpy import ceil as ceil
from keras.src.ops.numpy import clip as clip
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ def right_shift(x, y):
return bitwise_right_shift(x, y)


def blackman(x):
x = convert_to_tensor(x)
return jnp.blackman(x)


def broadcast_to(x, shape):
x = convert_to_tensor(x)
return jnp.broadcast_to(x, shape)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,11 @@ def right_shift(x, y):
return bitwise_right_shift(x, y)


def blackman(x):
x = convert_to_tensor(x)
return np.blackman(x).astype(config.floatx())


def broadcast_to(x, shape):
return np.broadcast_to(x, shape)

Expand Down
3 changes: 3 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ NumpyDtypeTest::test_any
NumpyDtypeTest::test_argpartition
NumpyDtypeTest::test_array
NumpyDtypeTest::test_bartlett
NumpyDtypeTest::test_blackman
NumpyDtypeTest::test_bitwise
NumpyDtypeTest::test_ceil
NumpyDtypeTest::test_concatenate
Expand Down Expand Up @@ -77,6 +78,7 @@ NumpyOneInputOpsCorrectnessTest::test_any
NumpyOneInputOpsCorrectnessTest::test_argpartition
NumpyOneInputOpsCorrectnessTest::test_array
NumpyOneInputOpsCorrectnessTest::test_bartlett
NumpyOneInputOpsCorrectnessTest::test_blackman
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
NumpyOneInputOpsCorrectnessTest::test_conj
NumpyOneInputOpsCorrectnessTest::test_correlate
Expand Down Expand Up @@ -154,4 +156,5 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot
NumpyTwoInputOpsCorrectnessTest::test_where
NumpyOneInputOpsDynamicShapeTest::test_angle
NumpyOneInputOpsDynamicShapeTest::test_bartlett
NumpyOneInputOpsDynamicShapeTest::test_blackman
NumpyOneInputOpsStaticShapeTest::test_angle
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,12 @@ def bincount(x, weights=None, minlength=0, sparse=False):
return OpenVINOKerasTensor(final_output)


def blackman(x):
raise NotImplementedError(
"`blackman` is not supported with openvino backend"
)


def broadcast_to(x, shape):
assert isinstance(shape, (tuple, list)), (
"`broadcast_to` is supported only for tuple and list `shape`"
Expand Down
12 changes: 12 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,18 @@ def right_shift(x, y):
return bitwise_right_shift(x, y)


def blackman(x):
dtype = config.floatx()
x = tf.cast(x, dtype)
n = tf.range(x, dtype=dtype)
n_minus_1 = tf.cast(x - 1, dtype)
term1 = 0.42
term2 = -0.5 * tf.cos(2 * np.pi * n / n_minus_1)
term3 = 0.08 * tf.cos(4 * np.pi * n / n_minus_1)
window = term1 + term2 + term3
return window


def broadcast_to(x, shape):
return tf.broadcast_to(x, shape)

Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,11 @@ def right_shift(x, y):
return bitwise_right_shift(x, y)


def blackman(x):
x = convert_to_tensor(x)
return torch.signal.windows.blackman(x)


def broadcast_to(x, shape):
x = convert_to_tensor(x)
return torch.broadcast_to(x, shape)
Expand Down
30 changes: 30 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,36 @@ def right_shift(x, y):
return backend.numpy.right_shift(x, y)


class Blackman(Operation):
def call(self, x):
return backend.numpy.blackman(x)

def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=backend.floatx())


@keras_export(["keras.ops.blackman", "keras.ops.numpy.blackman"])
def blackman(x):
"""Blackman window function.
The Blackman window is a taper formed by using a weighted cosine.

Args:
x: Scalar or 1D Tensor. Window length.

Returns:
A 1D tensor containing the Blackman window values.

Example:
>>> x = keras.ops.convert_to_tensor(5)
>>> keras.ops.blackman(x)
array([-1.3877788e-17, 3.4000000e-01, 1.0000000e+00, 3.4000000e-01,
-1.3877788e-17], dtype=float32)
"""
if any_symbolic_tensors((x,)):
return Blackman().symbolic_call(x)
return backend.numpy.blackman(x)


class BroadcastTo(Operation):
def __init__(self, shape):
super().__init__()
Expand Down
26 changes: 26 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,10 @@ def test_bartlett(self):
x = np.random.randint(1, 100 + 1)
self.assertEqual(knp.bartlett(x).shape[0], x)

def test_blackman(self):
x = np.random.randint(1, 100 + 1)
self.assertEqual(knp.blackman(x).shape[0], x)

def test_bitwise_invert(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.bitwise_invert(x).shape, (None, 3))
Expand Down Expand Up @@ -3600,6 +3604,12 @@ def test_bartlett(self):

self.assertAllClose(knp.Bartlett()(x), np.bartlett(x))

def test_blackman(self):
x = np.random.randint(1, 100 + 1)
self.assertAllClose(knp.blackman(x), np.blackman(x))

self.assertAllClose(knp.Blackman()(x), np.blackman(x))

@parameterized.named_parameters(
named_product(sparse_input=(False, True), sparse_arg=(False, True))
)
Expand Down Expand Up @@ -5579,6 +5589,22 @@ def test_bartlett(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_blackman(self, dtype):
import jax.numpy as jnp

x = knp.ones((), dtype=dtype)
x_jax = jnp.ones((), dtype=dtype)
expected_dtype = standardize_dtype(jnp.blackman(x_jax).dtype)

self.assertEqual(
standardize_dtype(knp.blackman(x).dtype), expected_dtype
)
self.assertEqual(
standardize_dtype(knp.Blackman().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=INT_DTYPES))
def test_bincount(self, dtype):
import jax.numpy as jnp
Expand Down