Change Log

These are the release notes for JAX.

jax 0.2.9 (Unreleased)

  • GitHub commits.

  • New features:

  • Bug fixes:

  • Breaking changes:

    • jax.ops.segment_sum now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
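
A minimal sketch of the out-of-range behavior of jax.ops.segment_sum, assuming the current segment_sum(data, segment_ids, num_segments) signature. The element whose segment ID falls outside [0, num_segments) contributes to no segment. The data values are illustrative.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0])
# Segment ID 5 lies outside [0, num_segments) for num_segments=2.
segment_ids = jnp.array([0, 1, 1, 5])

sums = ops.segment_sum(data, segment_ids, num_segments=2)
print(sums)  # the last element (4.0) is dropped, not folded into a segment
```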

next version

  • GitHub commits.

  • New features:

    • Extend the jax.experimental.loops module with support for pytrees. Improved error checking and error messages.

jax 0.2.8 (January 12 2021)

  • GitHub commits.

  • New features:

  • Bug fixes:

    • jax.numpy.arccosh now returns the same branch as numpy.arccosh for complex inputs (#5156)

    • host_callback.id_tap now works for jax.pmap as well. There is an optional parameter for id_tap and id_print to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).

  • Breaking changes:

jax 0.2.7 (Dec 4 2020)

  • GitHub commits.

  • New features:

    • Add jax.device_put_replicated

    • Add multi-host support to jax.experimental.sharded_jit

    • Add support for differentiating eigenvalues computed by jax.numpy.linalg.eig

    • Add support for building on Windows platforms

    • Add support for general in_axes and out_axes in jax.pmap

    • Add complex support for jax.numpy.linalg.slogdet

  • Bug fixes:

    • Fix higher-than-second order derivatives of jax.numpy.sinc at zero

    • Fix some hard-to-hit bugs around symbolic zeros in transpose rules

  • Breaking changes:

    • jax.experimental.optix has been deleted, in favor of the standalone optax Python package.

    • Indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See #4564.
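
The sketch below exercises several of the 0.2.7 changes on a recent JAX release, assuming the CPU backend (where jnp.linalg.eig is available): reverse-mode differentiation through eigenvalues, slogdet on a complex matrix, higher-order derivatives of jnp.sinc at zero, and the TypeError for non-tuple sequence indices. The matrices and reference values are illustrative choices, not taken from the linked issues.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Eigenvalue gradients: the eigenvalues sum to the trace, so the gradient
# of sum(Re(lambda)) with respect to A should be the identity matrix.
A = jnp.array([[2.0, 1.0, 0.0],
               [0.0, 3.0, 1.0],
               [0.0, 0.0, 5.0]])
eig_grad = jax.grad(lambda a: jnp.real(jnp.linalg.eigvals(a)).sum())(A)

# Complex slogdet: sign has unit modulus, and sign * exp(logabsdet) = det.
C = jnp.array([[1 + 1j, 2.0], [3j, 4 - 1j]])  # det(C) = 5 - 3j
sign, logabsdet = jnp.linalg.slogdet(C)
det_from_slogdet = sign * jnp.exp(logabsdet)

# sinc(x) = sin(pi x) / (pi x) = 1 - (pi x)^2 / 6 + (pi x)^4 / 120 - ...,
# so f''(0) = -pi^2 / 3 and f''''(0) = pi^4 / 5.
d2 = jax.grad(jax.grad(jnp.sinc))(0.0)
d4 = jax.grad(jax.grad(jax.grad(jax.grad(jnp.sinc))))(0.0)

# Non-tuple sequences are rejected as indices; use a tuple or an array.
x = jnp.arange(6).reshape(2, 3)
try:
    x[[0, 1]]
    list_index_rejected = False
except TypeError:
    list_index_rejected = True
rows = x[jnp.array([0, 1])]  # integer-array indexing
col0 = x[(slice(None), 0)]   # equivalent to x[:, 0]
```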

jax 0.2.6 (Nov 18 2020)

  • GitHub commits.

  • New Features:

    • Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See

  • Breaking change cleanup

    • Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.

    • Improve consistency of type promotion behavior (#4744):

      • Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.

      • Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.

    • The contents of the (undocumented) jax.lax_linalg linear algebra module are now exposed publicly as jax.lax.linalg.

    • jax.random.PRNGKey now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:

      • With jax_enable_x64=False, negative seeds passed as Python integers now return a different result outside JIT mode. For example, jax.random.PRNGKey(-1) previously returned [4294967295, 4294967295], and now returns [0, 4294967295]. This matches the behavior in JIT.

      • Seeds outside the range representable by int64 outside JIT now result in an OverflowError rather than a TypeError. This matches the behavior in JIT.

      To recover the keys returned previously for negative integers with jax_enable_x64=False outside JIT, you can use:

      from jax import random
      key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)

    • DeviceArray now raises RuntimeError instead of ValueError when trying to access its value after it has been deleted.
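
A rough illustration of three of the 0.2.6 cleanups, written against a recent JAX release. The helper name scale is hypothetical, and 64-bit mode is switched on so that the float64/complex128 outcomes mentioned above would actually be possible. The exact exception class raised for a non-hashable static argument may differ between versions, so the sketch catches both TypeError and ValueError.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Assumption: 64-bit mode, so wider dtypes are available to promote to.
jax.config.update("jax_enable_x64", True)

# 1. Static arguments to jit must be hashable (scale is a hypothetical helper).
@partial(jax.jit, static_argnums=1)
def scale(x, factors):
    return x * sum(factors)  # factors is a plain Python value fixed at trace time

ok = scale(jnp.ones(3), (1.0, 2.0))  # tuple: hashable, accepted
try:
    scale(jnp.ones(3), [1.0, 2.0])   # list: not hashable, rejected
    list_rejected = False
except (TypeError, ValueError):      # exception type may vary by version
    list_rejected = True

# 2. A Python complex scalar keeps the precision of a float32 array.
z = jnp.float32(1) + 1j

# 3. uint64 / signed int / third type: the result no longer depends on order.
t1 = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
t2 = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
```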

jaxlib 0.1.58 (Unreleased)

  • Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32). (#4903)

  • Fixed a crash when constant-folding certain int16 operations. (#4971)

  • Added an is_leaf predicate to pytree.flatten.
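
This entry concerns jaxlib's low-level pytree.flatten. As a sketch, the same predicate is exposed in current JAX as the is_leaf argument of jax.tree_util.tree_flatten; the tree below is an arbitrary example.

```python
from jax.tree_util import tree_flatten

tree = {"a": (1, 2), "b": [3, 4]}

# Default: containers are traversed all the way down to their scalars.
leaves, _ = tree_flatten(tree)

# With is_leaf, any subtree for which the predicate is True is kept whole.
tuple_leaves, _ = tree_flatten(tree, is_leaf=lambda node: isinstance(node, tuple))
```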

jaxlib 0.1.57 (November 12 2020)

  • Fixed manylinux2010 compliance issues in GPU wheels.

  • Switched the CPU FFT implementation from Eigen to PocketFFT.

  • Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).

  • Add support for retaining ownership when passing arrays to DLPack (#4636).

  • Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.

  • Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).

  • Fixed a bug in the profiler where tools were missing (#4427).

  • Dropped support for CUDA 10.0.

jax 0.2.5 (October 27 2020)

jax 0.2.4 (October 19 2020)

  • GitHub commits.

  • Improvements:

    • Add support for remat to jax.experimental.host_callback. See #4608.

  • Deprecations

    • Indexing with non-tuple sequences is now deprecated, following a similar deprecation in NumPy. In a future release, this will result in a TypeError. See #4564.

jaxlib 0.1.56 (October 14, 2020)

jax 0.2.3 (October 14 2020)

  • GitHub commits.

  • The reason for another release so soon is that we need to temporarily roll back a new jit fast path while we look into a performance degradation.

jax 0.2.2 (October 13 2020)

jax 0.2.1 (October 6 2020)

jax 0.2.0 (September 23 2020)

jax 0.1.77 (September 15 2020)

jaxlib 0.1.55 (September 8, 2020)

  • Update XLA:

    • Fix bug in DLPackManagedTensorToBuffer (#4196)

jax 0.1.76 (September 8, 2020)

jax 0.1.75 (July 30, 2020)

  • GitHub commits.

  • Bug Fixes:

    • Make jnp.abs() work for unsigned inputs (#3914)

  • Improvements:

    • “Omnistaging” behavior added behind a flag, disabled by default (#3370)

jax 0.1.74 (July 29, 2020)

  • GitHub commits.

  • New Features:

    • BFGS (#3101)

    • TPU support for half-precision arithmetic (#3878)

  • Bug Fixes:

    • Prevent some accidental dtype warnings (#3874)

    • Fix a multi-threading bug in custom derivatives (#3845, #3869)

  • Improvements:

    • Faster searchsorted implementation (#3873)

    • Better test coverage for jax.numpy sorting algorithms (#3836)
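
As a sketch of how BFGS is reached in current JAX, the example below calls jax.scipy.optimize.minimize with method="BFGS" on an illustrative convex quadratic whose minimizer, [3, 3], is known in advance.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def loss(x):
    # Convex quadratic with its minimum at x = [3, 3].
    return jnp.sum((x - 3.0) ** 2)

result = minimize(loss, jnp.zeros(2), method="BFGS")
print(result.x, result.success)
```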

jaxlib 0.1.52 (July 22, 2020)

  • Update XLA.

jax 0.1.73 (July 22, 2020)

  • GitHub commits.

  • The minimum jaxlib version is now 0.1.51.

  • New Features:

    • jax.image.resize. (#3703)

    • hfft and ihfft (#3664)

    • jax.numpy.intersect1d (#3726)

    • jax.numpy.lexsort (#3812)

    • lax.scan and the scan primitive support an unroll parameter for loop unrolling when lowering to XLA (#3738).

  • Bug Fixes:

    • Fix reduction repeated axis error (#3618)

    • Fix shape rule for lax.pad for input dimensions of size 0. (#3608)

    • Make psum transpose handle zero cotangents (#3653)

    • Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)

    • Support differentiation through jax.lax.all_to_all (#3733)

    • Address nan issue in jax.scipy.special.zeta (#3777)

  • Improvements:

    • Many improvements to jax2tf

    • Reimplement argmin/argmax using a single pass variadic reduction. (#3611)

    • Enable XLA SPMD partitioning by default. (#3151)

    • Add support for 0d transpose convolution (#3643)

    • Make LU gradient work for low-rank matrices (#3610)

    • Support multiple_results and custom JVPs in jet (#3657)

    • Generalize reduce-window padding to support (lo, hi) pairs. (#3728)

    • Implement complex convolutions on CPU and GPU. (#3735)

    • Make jnp.take work for empty slices of empty arrays. (#3751)

    • Relax dimension ordering rules for dot_general. (#3778)

    • Enable buffer donation for GPU. (#3800)

    • Add support for base dilation and window dilation to reduce window op… (#3803)
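
A short sketch touching four of the 0.1.73 additions (jax.image.resize, jax.numpy.intersect1d, jax.numpy.lexsort, and the unroll parameter of lax.scan), assuming a recent JAX release. Unrolling only changes how the loop is lowered, so both scans should produce the same values. All inputs are illustrative.

```python
import numpy as np
import jax.numpy as jnp
from jax import image, lax

# jax.image.resize: nearest-neighbour 2x upsampling of a 2x2 image.
img = jnp.array([[1.0, 2.0], [3.0, 4.0]])
up = image.resize(img, shape=(4, 4), method="nearest")

# jax.numpy.intersect1d: sorted unique values present in both inputs.
common = jnp.intersect1d(jnp.array([1, 3, 4, 3]), jnp.array([3, 1, 2, 1]))

# jax.numpy.lexsort: the last key is the primary key, as in NumPy.
a = jnp.array([1, 5, 1, 4, 3, 4, 4])
b = jnp.array([9, 4, 0, 4, 0, 2, 1])
order = jnp.lexsort((b, a))  # sort by a, break ties with b

# lax.scan with unroll: a running sum, lowered with the body unrolled 4x.
def step(carry, x):
    carry = carry + x
    return carry, carry

xs = jnp.arange(10.0)
_, cumsum_plain = lax.scan(step, jnp.float32(0), xs)
_, cumsum_unrolled = lax.scan(step, jnp.float32(0), xs, unroll=4)
```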

jaxlib 0.1.51 (July 2, 2020)

  • Update XLA.

  • Add new runtime support for host_callback.

jax 0.1.72 (June 28, 2020)

jax 0.1.71 (June 25, 2020)

  • GitHub commits.

  • The minimum jaxlib version is now 0.1.48.

  • Bug fixes:

    • Allow jax.experimental.ode.odeint dynamics functions to close over values with respect to which we’re differentiating #3562.

jaxlib 0.1.50 (June 25, 2020)

  • Add support for CUDA 11.0.

  • Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)

  • Update XLA.

jaxlib 0.1.49 (June 19, 2020)

jaxlib 0.1.48 (June 12, 2020)

  • New features:

    • Adds support for fast traceback collection.

    • Adds preliminary support for on-device heap profiling.

    • Implements np.nextafter for bfloat16 types.

    • Complex128 support for FFTs on CPU and GPU.

  • Bugfixes:

    • Improved float64 tanh accuracy on GPU.

    • float64 scatters on GPU are much faster.

    • Complex matrix multiplication on CPU should be much faster.

    • Stable sorts on CPU should actually be stable now.

    • Concurrency bug fix in CPU backend.

jax 0.1.70 (June 8, 2020)

  • GitHub commits.

  • New features:

    • lax.switch introduces indexed conditionals with multiple branches, together with a generalization of the cond primitive #3318.
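
A minimal sketch of lax.switch, assuming the current signature lax.switch(index, branches, *operands). An out-of-range index is clamped to the valid range rather than raising; the branch functions are illustrative.

```python
from jax import lax

branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

picked = lax.switch(1, branches, 5.0)   # applies branches[1]
clamped = lax.switch(7, branches, 5.0)  # index clamped to the last branch
```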

jax 0.1.69 (June 3, 2020)

jax 0.1.68 (May 21, 2020)

  • GitHub commits.

  • New features:

    • lax.cond supports a single-operand form, taken as the argument to both branches #2993.

  • Notable changes:

    • The format of the transforms keyword for the jax.experimental.host_callback.id_tap primitive has changed #3132.
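
In current JAX, the single-operand form of lax.cond is written lax.cond(pred, true_fun, false_fun, operand), and the operand is passed to whichever branch runs. A minimal sketch with a hypothetical helper abs_via_cond:

```python
from jax import lax

def abs_via_cond(x):
    # The single operand x is handed to the branch selected by the predicate.
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

vals = [float(abs_via_cond(v)) for v in (3.0, -2.0)]
```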

jax 0.1.67 (May 12, 2020)

  • GitHub commits.

  • New features:

    • Support for reduction over subsets of a pmapped axis using axis_index_groups #2382.

    • Experimental support for printing and calling host-side Python functions from compiled code. See id_print and id_tap (#3006).

  • Notable changes:

    • The visibility of names exported from jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.

jaxlib 0.1.47 (May 8, 2020)

  • Fixes crash for outfeed.

jax 0.1.66 (May 5, 2020)

jaxlib 0.1.46 (May 5, 2020)

  • Fixes crash for linear algebra functions on Mac OS X (#432).

  • Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).

jax 0.1.65 (April 30, 2020)

  • GitHub commits.

  • New features:

    • Differentiation of determinants of singular matrices #2809.

  • Bug fixes:

    • Fix odeint() differentiation with respect to time of ODEs with time-dependent dynamics #2817, also add ODE CI testing.

    • Fix lax_linalg.qr() differentiation #2867.
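
A minimal sketch of the singular-determinant case: for a 2x2 matrix [[a, b], [c, d]] the gradient of a*d - b*c is the cofactor matrix [[d, -c], [-b, a]], which stays finite even when the determinant is zero. The matrix is an illustrative choice.

```python
import jax
import jax.numpy as jnp

# Singular matrix: the second row is twice the first, so det(A) = 0.
A = jnp.array([[1.0, 2.0], [2.0, 4.0]])

# Expected gradient: the cofactor matrix [[4, -2], [-2, 1]].
g = jax.grad(jnp.linalg.det)(A)
```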

jaxlib 0.1.45 (April 21, 2020)

jax 0.1.64 (April 21, 2020)

jaxlib 0.1.44 (April 16, 2020)

  • Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.

  • Bugfix for batch_group_count convolutions.

  • Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.

jax 0.1.63 (April 12, 2020)

  • GitHub commits.

  • Added jax.custom_jvp and jax.custom_vjp from #2026, see the tutorial notebook. Deprecated jax.custom_transforms and removed it from the docs (though it still works).

  • Add #2566.

  • Changed how Tracers are printed to show more useful information for debugging #2591.

  • Made jax.numpy.isclose handle nan and inf correctly #2501.

  • Added several new rules for jax.experimental.jet #2537.

  • Fixed jax.experimental.stax.BatchNorm when scale/center isn’t provided.

  • Fix some missing cases of broadcasting in jax.numpy.einsum #2512.

  • Implement jax.numpy.cumsum and jax.numpy.cumprod in terms of a parallel prefix scan #2596 and make reduce_prod differentiable to arbitrary order #2597.

  • Add batch_group_count to conv_general_dilated #2635.

  • Add docstring for test_util.check_grads #2656.

  • Add callback_transform #2665.

  • Implement rollaxis, convolve/correlate 1d & 2d, copysign, trunc, roots, and quantile/percentile interpolation options.
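
A sketch of both new decorators. The log1pexp function follows a well-known pattern: without the custom rule, the float32 derivative exp(x) / (1 + exp(x)) at x = 100 evaluates to inf/inf. The gradient-clipping function clip_gradient is a hypothetical illustration of custom_vjp.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x), which stays finite for large x.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

slope = jax.grad(log1pexp)(100.0)

# custom_vjp: identity on the forward pass, clipped cotangent on the way back.
@jax.custom_vjp
def clip_gradient(x):
    return x

def clip_fwd(x):
    return x, None  # no residuals needed

def clip_bwd(_, g):
    return (jnp.clip(g, -1.0, 1.0),)

clip_gradient.defvjp(clip_fwd, clip_bwd)

# The upstream cotangent 5.0 is clipped to 1.0.
clipped = jax.grad(lambda x: 5.0 * clip_gradient(x))(2.0)
```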

jaxlib 0.1.43 (March 31, 2020)

  • Fixed a performance regression for Resnet-50 on GPU.

jax 0.1.62 (March 21, 2020)

  • GitHub commits.

  • JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.

  • Removed the internal function lax._safe_mul, which implemented the convention 0.0 * nan == 0.0. This change means that some programs will produce nans when differentiated where they previously produced correct values, though it ensures that nans, rather than silently incorrect results, are produced for other programs. See #2447 and #1052 for details.

  • Added an all_gather parallel convenience function.

  • More type annotations in core code.
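
all_gather is typically used under pmap. So that the sketch runs on a single device, it instead uses vmap with a named axis, which JAX also accepts for collectives; each mapped instance receives the full gathered vector.

```python
import jax
import jax.numpy as jnp
from jax import lax

# vmap with axis_name stands in for pmap so this runs on one device.
gathered = jax.vmap(lambda x: lax.all_gather(x, axis_name="i"), axis_name="i")(
    jnp.arange(3.0)
)
print(gathered.shape)  # one gathered copy per mapped instance
```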

jaxlib 0.1.42 (March 19, 2020)

  • jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.

  • JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.

jax 0.1.61 (March 17, 2020)

  • GitHub commits.

  • Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.

jax 0.1.60 (March 17, 2020)

  • GitHub commits.

  • New features:

    • jax.pmap() has a static_broadcasted_argnums argument, which allows the user to specify arguments that should be treated as compile-time constants and broadcast to all devices. It works analogously to static_argnums in jax.jit().

    • Improved error messages for when tracers are mistakenly saved in global state.

    • Added jax.nn.one_hot() utility function.

    • Added jax.experimental.jet for exponentially faster higher-order automatic differentiation.

    • Added more correctness checking to arguments of jax.lax.broadcast_in_dim().

  • The minimum jaxlib version is now 0.1.41.
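
A sketch of two of the 0.1.60 additions, assuming a recent JAX release in which the pmap argument is spelled static_broadcasted_argnums. The input's leading axis is sized to jax.local_device_count(), so the example also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Argument 1 (the exponent) is a compile-time constant broadcast to all devices.
power = jax.pmap(lambda x, n: x ** n, static_broadcasted_argnums=1)
x = jnp.arange(n_dev * 3.0).reshape(n_dev, 3)
squared = power(x, 2)

# jax.nn.one_hot: integer labels to one-hot rows.
labels_1h = jax.nn.one_hot(jnp.array([0, 2]), 3)
```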

jaxlib 0.1.40 (March 4, 2020)

  • Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.

  • Includes prototype support for multihost GPU computations that communicate via NCCL.

  • Improves performance of NCCL collectives on GPU.

  • Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.

  • Supports device assignments known at XLA compilation time.

jax 0.1.59 (February 11, 2020)

  • GitHub commits.

  • Breaking changes

    • The minimum jaxlib version is now 0.1.38.

    • Simplified Jaxpr by removing the Jaxpr.freevars and Jaxpr.bound_subjaxprs. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully-closed (no constvars) jaxpr. Also, added a new field call_primitive to primitives.

  • New features:

    • Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes (

    • JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.

    • JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.

    • JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.

    • Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.
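
A minimal sketch checking that reverse mode (jax.grad) and forward mode (jax.jvp) agree through lax.cond. For x = 1.0 the sin branch is taken, so both derivatives should equal cos(1). The function f is illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)

rev = jax.grad(f)(1.0)               # reverse mode
_, fwd = jax.jvp(f, (1.0,), (1.0,))  # forward mode
```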

jaxlib 0.1.39 (February 11, 2020)

  • Updates XLA.

jaxlib 0.1.38 (January 29, 2020)

  • CUDA 9.0 is no longer supported.

  • CUDA 10.2 wheels are now built by default.

jax 0.1.58 (January 28, 2020)

  • Notable bug fixes:

    • With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.