diff --git a/scripts/firedrake-install b/scripts/firedrake-install index 73670610eb..433a611904 100755 --- a/scripts/firedrake-install +++ b/scripts/firedrake-install @@ -1304,12 +1304,6 @@ def build_and_install_jax(): log.info("Installing JAX (backend: %s)" % args.jax) # version_name = "jax" if args.jax == "cpu" else "jax[cuda12]" run_pip_install(["jax"] + ["jaxlib"] + ["ml_dtypes"] + ["opt_einsum"]) - # Test if jaxlib is installed correctly - try: - import jax - except ImportError: - print("Failed to install jax") - raise InstallError("Failed to install jax") def build_and_install_slepc():