Sourced from jax's releases.
JAX v0.5.0
As of this release, JAX now uses effort-based versioning. Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this.
Breaking changes
Enable
jax_threefry_partitionable
by default (see the update note).This release drops support for Mac x86 wheels. Mac ARM of course remains supported. For a recent discussion, see https://github.com/jax-ml/jax/discussions/22936.
Two key factors motivated this decision:
- The Mac x86 build (only) has a number of test failures and crashes. We would prefer to ship no release than a broken release.
- Mac x86 hardware is end-of-life and cannot be easily obtained for developers at this point. So it is difficult for us to fix this kind of problem even if we wanted to.
We are open to readding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.
Changes:
- The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
- The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
jax.numpy.einsum
now defaults tooptimize='auto'
rather thanoptimize='optimal'
. This avoids exponentially-scaling trace-time in the case of many arguments ([#25214](https://github.com/jax-ml/jax/issues/25214)
).jax.numpy.linalg.solve
no longer supports batched 1D arguments on the right hand side. To recover the previous behavior in these cases, usesolve(a, b[..., None]).squeeze(-1)
.New Features
jax.numpy.fft.fftn
,jax.numpy.fft.rfftn
,jax.numpy.fft.ifftn
, andjax.numpy.fft.irfftn
now support transforms in more than 3 dimensions, which was previously the limit. See[#25606](https://github.com/jax-ml/jax/issues/25606)
for more details.- Support added for user defined state in the FFI via the new
jax.ffi.register_ffi_type_id
function.- The AOT lowering
.as_text()
method now supports thedebug_info
option to include debugging information, e.g., source location, in the output.Deprecations
- From
jax.interpreters.xla
,abstractify
andpytype_aval_mappings
are now deprecated, having been replaced by symbols of the same name injax.core
.
... (truncated)
Sourced from jax's changelog.
jax 0.5.0 (Jan 17, 2025)
As of this release, JAX now uses effort-based versioning. Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this.
Breaking changes
Enable
jax_threefry_partitionable
by default (see the update note).This release drops support for Mac x86 wheels. Mac ARM of course remains supported. For a recent discussion, see https://github.com/jax-ml/jax/discussions/22936.
Two key factors motivated this decision:
- The Mac x86 build (only) has a number of test failures and crashes. We would prefer to ship no release than a broken release.
- Mac x86 hardware is end-of-life and cannot be easily obtained for developers at this point. So it is difficult for us to fix this kind of problem even if we wanted to.
We are open to readding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.
Changes:
- The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
- The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
- {func}
jax.numpy.einsum
now defaults tooptimize='auto'
rather thanoptimize='optimal'
. This avoids exponentially-scaling trace-time in the case of many arguments ({jax-issue}[#25214](https://github.com/jax-ml/jax/issues/25214)
).- {func}
jax.numpy.linalg.solve
no longer supports batched 1D arguments on the right hand side. To recover the previous behavior in these cases, usesolve(a, b[..., None]).squeeze(-1)
.New Features
- {func}
jax.numpy.fft.fftn
, {func}jax.numpy.fft.rfftn
, {func}jax.numpy.fft.ifftn
, and {func}jax.numpy.fft.irfftn
now support transforms in more than 3 dimensions, which was previously the limit. See {jax-issue}[#25606](https://github.com/jax-ml/jax/issues/25606)
for more details.- Support added for user defined state in the FFI via the new {func}
jax.ffi.register_ffi_type_id
function.- The AOT lowering
.as_text()
method now supports thedebug_info
option to include debugging information, e.g., source location, in the output.Deprecations
... (truncated)
c25fb92
Release JAX 0.5.0a527aba
Reverts f1b894d14a28ac22a037fb79177b991275c75a18ce85b89
[sharding_in_types] Error out for reshape for splits like this:
(4, 6, 8)
-...7cac76d
Update XLA dependency to use revisiond3be190
[Mosaic GPU] Delete unused declarations of
mosaic_gpu_memcpy_async_h2d
.d34c40f
[mosaic_gpu] Added a serialization passaf66719
[sharding_in_types] Rename .at[...].get(out_spec)
to
`.at[...].get(out_shar...97cd748
Rename out_type -> out_sharding parameter on einsum49224d6
Replace Auto/User/Collective AxisTypes names with
Hidden/Visible/Collective.bd22bfe
[Mosaic TPU] Use large to compact 2nd minor retiling for conversions
going bo...