Skip to content

Commit

Permalink
Resolve TODO
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Dec 25, 2024
1 parent 711d60e commit 4b190a2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tjax/_src/math_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def abs_square(x: JaxComplexArray) -> JaxRealArray: ...
def abs_square(x: ComplexArray) -> RealArray: ...
def abs_square(x: ComplexArray) -> RealArray:
xp = get_namespace(x)
# TODO: remove workaround when Jax is 0.4.27.
return xp.square(x.real) + xp.square(xp.asarray(x.imag))
return xp.square(x.real) + xp.square(x.imag)


# TODO: Remove this when the Array API has it with broadcasting under xp.linalg.norm.
# TODO: Remove when it's added to the Array API:
# https://github.com/data-apis/array-api/issues/242
@overload
def outer_product(x: JaxRealArray, y: JaxRealArray) -> JaxRealArray: ...
@overload
Expand Down

0 comments on commit 4b190a2

Please sign in to comment.