Skip to content

Commit

Permalink
Add test for divide_where
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Jan 5, 2025
1 parent 1cca5f2 commit 3addc35
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion tests/test_math_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import math
from functools import partial
from typing import Any

import jax.numpy as jnp
import pytest
from jax import Array, vjp
from numpy.testing import assert_allclose

from tjax import JaxRealArray, normalize
from tjax import JaxRealArray, divide_where, normalize


@pytest.mark.parametrize(("x", "axis", "result"),
Expand All @@ -27,3 +30,23 @@ def test_l2(x: Any,
result: JaxRealArray,
) -> None:
assert_allclose(normalize('l2', jnp.asarray(x), axis=axis), jnp.asarray(result))


@pytest.mark.parametrize("k", [0.0, 1.0])
def test_divide_where(k: float) -> None:
s = (5, 3)
w = jnp.ones(s) * k
x = jnp.arange(math.prod(w.shape), dtype='f').reshape(w.shape)
dummy = jnp.ones_like(w[..., 0])

def f(w: Array, x: Array, dummy: Array) -> Array:
total_w = jnp.sum(w, axis=-1)
return divide_where(dividend=jnp.sum(w * x, axis=-1),
divisor=total_w,
where=total_w > 0.0,
otherwise=dummy)

y, vjp_f = vjp(partial(f, dummy=dummy), w, x)
w_bar, x_bar = vjp_f(jnp.ones_like(y))
assert_allclose(w_bar, k / s[-1] * jnp.tile(jnp.asarray([-1, 0, 1]), (*s[:-1], 1)), atol=1e-3)
assert_allclose(x_bar, k / s[-1] * jnp.ones(s))

0 comments on commit 3addc35

Please sign in to comment.