diff --git a/README.md b/README.md index efcb12c..72c81f6 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ One important difference between `tree_math` and `jax.numpy` is that dot products in `tree_math` default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior -[in the future](https://github.com/google/jax/pull/7859). +[in the future](https://github.com/jax-ml/jax/pull/7859). It would be nice to have a `Matrix` class to make it possible to use tree-math for numerical algorithms such as @@ -86,7 +86,7 @@ feature, please comment on [this GitHub issue](https://github.com/google/tree-ma Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the [pseudocode from Wikipedia](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method), -unlike the [implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121). +unlike the [implementation in JAX](https://github.com/jax-ml/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121). Both versions support arbitrary pytrees as input: ```python diff --git a/tree_math/_src/vector.py b/tree_math/_src/vector.py index 66fb195..b9857fc 100644 --- a/tree_math/_src/vector.py +++ b/tree_math/_src/vector.py @@ -104,7 +104,7 @@ def dot(left, right, *, precision="highest"): Note that unlike jax.numpy.dot, tree_math.dot defaults to full (highest) precision. This is more useful for numerical algorithms and will be the default for jax.numpy in the future: - https://github.com/google/jax/pull/7859 + https://github.com/jax-ml/jax/pull/7859 Args: left: left argument.