Change Log
These are the release notes for JAX.
jax 0.1.63 (unreleased)
jax 0.1.62 (March 21, 2020)
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- Removed the internal function lax._safe_mul, which implemented the convention that 0. * nan == 0. This change means that some programs will produce nans when differentiated where they previously produced correct values, though it ensures that other programs produce nans rather than silently incorrect results. See #2447 and #1052 for details.
- Added an all_gather parallel convenience function.
- More type annotations in core code.
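The effect of dropping lax._safe_mul follows from ordinary IEEE arithmetic, where 0. * nan is nan. The sketch below illustrates that general effect under this assumption; it is not one of the programs from #2447 or #1052. It masks log with jnp.where: at x = 0 the gradient combines a zero cotangent with log's unbounded derivative, which yields nan. The helper names naive_log and safe_log are hypothetical, and the "double where" fix is a common user-side workaround, not something this release added.

```python
import jax
import jax.numpy as jnp

# IEEE 754 semantics: zero times nan is nan, not zero.
ieee_product = jnp.float32(0.) * jnp.nan

def naive_log(x):
    # Hypothetical helper: the x <= 0 branch is masked out in the forward pass,
    # but log's derivative at 0 is still evaluated, so the gradient becomes nan.
    return jnp.where(x > 0., jnp.log(x), 0.)

def safe_log(x):
    # Hypothetical helper ("double where"): give log a harmless input where it
    # is masked out, so its derivative stays finite there.
    safe_x = jnp.where(x > 0., x, 1.)
    return jnp.where(x > 0., jnp.log(safe_x), 0.)

g_naive = jax.grad(naive_log)(0.)
g_safe = jax.grad(safe_log)(0.)
```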
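A minimal sketch of all_gather inside jax.pmap, assuming the current jax.lax.all_gather(x, axis_name) signature; the axis name "devices" and the function gather_all are hypothetical. On a CPU-only machine jax.local_device_count() is usually 1, so the result is a 1 x 1 array, but the same code gathers across every local device when more are present.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # number of local devices to map over

def gather_all(x):
    # Hypothetical helper: each device contributes its own value and
    # receives the values from all n devices along the named axis.
    return jax.lax.all_gather(x, axis_name="devices")

gathered = jax.pmap(gather_all, axis_name="devices")(
    jnp.arange(n, dtype=jnp.float32)
)
# gathered has shape (n, n): row i is what device i received.
```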
jaxlib 0.1.42 (March 19, 2020)
- jaxlib 0.1.41 broke Cloud TPU support due to an API incompatibility. This release restores it.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)
- Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
jax 0.1.60 (March 17, 2020)
- GitHub commits.
- New features:
  - jax.pmap() has a static_broadcast_argnums argument, which allows the user to specify arguments that should be treated as compile-time constants and broadcast to all devices. It works analogously to static_argnums in jax.jit().
  - Improved error messages for when tracers are mistakenly saved in global state.
  - Added the jax.nn.one_hot() utility function.
  - Added jax.experimental.jet for exponentially faster higher-order automatic differentiation.
  - Added more sanity checking to arguments of jax.lax.broadcast_in_dim().
- The minimum jaxlib version is now 0.1.41.
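A sketch of marking a pmap argument as static. Current JAX releases spell the keyword static_broadcasted_argnums; the function scale and its mode argument are hypothetical. Because mode is static, ordinary Python branching on it is allowed, and its value is passed unchanged to every device instead of being split along the mapped axis.

```python
import jax
import jax.numpy as jnp

def scale(x, mode):
    # Hypothetical function: `mode` is a plain Python string fixed at
    # compile time, so Python control flow on it is fine.
    if mode == "double":
        return x * 2.
    return x

n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)  # mapped argument: one element per device
doubled = jax.pmap(scale, static_broadcasted_argnums=(1,))(xs, "double")
```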
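Two of the smaller additions, sketched with the current signatures jax.nn.one_hot(x, num_classes) and lax.broadcast_in_dim(operand, shape, broadcast_dimensions). The exact exception type raised for an invalid broadcast is an implementation detail, so the sketch catches both TypeError and ValueError.

```python
import jax
import jax.numpy as jnp
from jax import lax

labels = jnp.array([0, 2, 1])
encoded = jax.nn.one_hot(labels, 3)  # one row per label, shape (3, 3)

row = jnp.array([1., 2., 3.])
# Map operand dimension 0 onto output dimension 1, giving shape (2, 3).
tiled = lax.broadcast_in_dim(row, shape=(2, 3), broadcast_dimensions=(1,))

try:
    # Invalid: operand size 3 cannot map onto an output dimension of size 2.
    lax.broadcast_in_dim(row, shape=(3, 2), broadcast_dimensions=(1,))
    rejected = False
except (TypeError, ValueError):
    rejected = True
```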
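A minimal sketch of jax.experimental.jet; the module is experimental, so its interface may differ between releases. As we read its docstring, jet(fun, primals, series) takes, for each primal, a tuple of derivatives of an input path x(t) and returns the corresponding derivatives of fun(x(t)). With the path x(t) = x0 + t, those are simply the first and second derivatives of sin at x0.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x0 = 0.5
# Input path x(t) = x0 + t: first derivative 1, second derivative 0.
primal_out, (d1, d2) = jet(jnp.sin, (x0,), ((1., 0.),))
# For this path, d1 should equal cos(x0) and d2 should equal -sin(x0).
```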
jaxlib 0.1.40 (March 4, 2020)¶
- Adds experimental support in jaxlib for the TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
- Includes prototype support for multihost GPU computations that communicate via NCCL.
- Improves performance of NCCL collectives on GPU.
- Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
- Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)
- GitHub commits.
- Breaking changes:
  - The minimum jaxlib version is now 0.1.38.
  - Simplified Jaxpr by removing Jaxpr.freevars and Jaxpr.bound_subjaxprs. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter, call_jaxpr, with a fully closed (no constvars) jaxpr. Also added a new field, call_primitive, to primitives.
- New features:
  - Reverse-mode automatic differentiation (e.g. grad) of lax.cond, so it is now differentiable in both modes (https://github.com/google/jax/pull/2091).
  - JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
  - JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
  - JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
  - Added the JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
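A sketch of reverse-mode differentiation through lax.cond. It uses the lax.cond(pred, true_fun, false_fun, operand) calling convention of current releases (releases from early 2020 took operands and branch functions in a different order), and the function piecewise is hypothetical.

```python
import jax
from jax import lax

def piecewise(x):
    # Hypothetical function: square positive inputs, negate the rest.
    return lax.cond(x > 0, lambda y: y ** 2, lambda y: -y, x)

g_pos = jax.grad(piecewise)(3.)   # derivative of y**2 at 3
g_neg = jax.grad(piecewise)(-2.)  # derivative of -y anywhere
```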
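A sketch of handing a CPU-resident JAX array to NumPy, once through np.asarray and once through DLPack. It assumes NumPy 1.22 or newer for np.from_dlpack. Whether a given call avoids a copy depends on the JAX and NumPy versions, so the sketch only checks that the values agree; the __cuda_array_interface__ path needs a GPU and is not shown.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Place the array explicitly on the CPU so both conversions below apply.
x = jax.device_put(jnp.arange(4, dtype=jnp.float32), jax.devices("cpu")[0])

# Plain NumPy conversion of a CPU-resident JAX array.
as_numpy = np.asarray(x)

# DLPack exchange: NumPy imports any object that implements __dlpack__.
via_dlpack = np.from_dlpack(x)
```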
jaxlib 0.1.39 (February 11, 2020)
- Updates XLA.
jaxlib 0.1.38 (January 29, 2020)
- CUDA 9.0 is no longer supported.
- CUDA 10.2 wheels are now built by default.
jax 0.1.58 (January 28, 2020)
Breaking changes
- JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
New features
- Forward-mode automatic differentiation (jvp) of while loops (https://github.com/google/jax/pull/1980).
- New NumPy and SciPy functions:
  - jax.numpy.fft.fft2()
  - jax.numpy.fft.ifft2()
  - jax.numpy.fft.rfft()
  - jax.numpy.fft.irfft()
  - jax.numpy.fft.rfft2()
  - jax.numpy.fft.irfft2()
  - jax.numpy.fft.rfftn()
  - jax.numpy.fft.irfftn()
  - jax.numpy.fft.fftfreq()
  - jax.numpy.fft.rfftfreq()
  - jax.numpy.linalg.matrix_rank()
  - jax.numpy.linalg.matrix_power()
  - jax.scipy.special.betainc()
- Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
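A sketch of forward-mode differentiation through lax.while_loop; double_until_large is a hypothetical function. The loop runs a data-dependent number of iterations, and jax.jvp pushes the tangent through each one. Reverse-mode differentiation (jax.grad) of lax.while_loop is still not supported; lax.scan, which has a fixed trip count, is the usual alternative when reverse mode is needed.

```python
import jax
from jax import lax

def double_until_large(x):
    # Hypothetical function: keep doubling while the value is below 10.
    return lax.while_loop(lambda v: v < 10., lambda v: v * 2., x)

primal, tangent = jax.jvp(double_until_large, (1.,), (1.,))
# 1 -> 2 -> 4 -> 8 -> 16: four doublings, so the derivative is 2**4.
```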
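A few of the new functions exercised on small, arbitrary inputs whose results can be checked by hand or against NumPy. For the last line, the regularized incomplete beta function I_x(a, b) reduces to x when a = b = 1.

```python
import jax.numpy as jnp
from jax.scipy.special import betainc

signal = jnp.array([0., 1., 0., -1.])
spectrum = jnp.fft.rfft(signal)           # n // 2 + 1 = 3 complex coefficients
roundtrip = jnp.fft.irfft(spectrum, n=4)  # inverse transform back to the real signal
freqs = jnp.fft.fftfreq(4)                # sample frequencies, in cycles per sample

# The second row is twice the first, so the rank is 1.
rank = jnp.linalg.matrix_rank(jnp.array([[1., 2.], [2., 4.]]))
cubed = jnp.linalg.matrix_power(jnp.array([[1, 1], [0, 1]]), 3)

# With a = b = 1, betainc(a, b, x) should return x itself.
beta_vals = betainc(1., 1., jnp.array([0.25, 0.5, 0.75]))
```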
Notable bug fixes
- With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.