jaxlib 0.1.75 (Unreleased)¶
jax 0.2.26 (Unreleased)¶
Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0. (#8634).
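The snippet below sketches what FILL_OR_DROP means for jax.ops.segment_sum. It assumes a current JAX release, and the helper name `total` is illustrative only. The segment ID 5 is out of range for `num_segments=2`, so its data value is dropped in the forward pass and receives a zero gradient in the backward pass:

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0])
# The last segment ID (5) is out of bounds for num_segments=2.
segment_ids = jnp.array([0, 1, 5])

def total(d):
    # Hypothetical helper: sum of all segment sums.
    return jax.ops.segment_sum(d, segment_ids, num_segments=2).sum()

# Forward pass: the out-of-bounds entry (3.0) is dropped, giving [1., 2.].
sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)
# Reverse pass: the dropped entry gets a zero gradient, giving [1., 1., 0.].
grads = jax.grad(total)(data)
```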
jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).
jaxlib 0.1.74 (Nov 17, 2021)¶
Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
Added experimental MLIR Python bindings for use by JAX.
jax 0.2.25 (Nov 10, 2021)¶
jax.distributed.initialize exposes the multi-host GPU backend.
jax.random.permutation supports a new independent keyword argument (#8430).
jax 0.2.24 (Oct 19, 2021)¶
jaxlib 0.1.73 (Oct 18, 2021)¶
Multiple cuDNN versions are now supported for jaxlib GPU wheels:
cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
cuDNN 8.0.5 or newer.
The install commands for GPU jaxlib are as follows:
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.0.5 or newer.
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (Oct 12, 2021)¶
Static arguments to jax.pmap must now be hashable.
Unhashable static arguments have long been disallowed on jax.jit, but they were still permitted on jax.pmap, which compared unhashable static arguments using object identity.
This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of jax.pmap wants to compare static arguments by object identity, they can define __hash__ and __eq__ methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use functools.partial to encapsulate the unhashable static arguments into the function object.
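The sketch below shows both workarounds. It assumes a current JAX release and uses whatever local devices are available, which is one device on a plain CPU machine. `IdentityHashed`, `scale`, and `scale_raw` are hypothetical names used for illustration, not JAX APIs:

```python
import functools
import jax
import jax.numpy as jnp

class IdentityHashed:
    """Hypothetical wrapper: hashes and compares the wrapped object by identity."""
    def __init__(self, obj):
        self.obj = obj
    def __hash__(self):
        return id(self.obj)
    def __eq__(self, other):
        return isinstance(other, IdentityHashed) and self.obj is other.obj

def scale(x, cfg):
    # cfg is static; cfg.obj is an unhashable dict.
    return x * cfg.obj["factor"]

n = jax.local_device_count()
xs = jnp.ones((n, 3))

# Option 1: wrap the unhashable object so pmap can hash it (by identity).
pscale = jax.pmap(scale, static_broadcasted_argnums=(1,))
out1 = pscale(xs, IdentityHashed({"factor": 2.0}))

# Option 2: close over the unhashable object with functools.partial.
def scale_raw(x, cfg):
    return x * cfg["factor"]

out2 = jax.pmap(functools.partial(scale_raw, cfg={"factor": 2.0}))(xs)
```

With the wrapper, recompilation still happens whenever a new dict object is passed, because identity is what gets compared. That is the trade-off the entry above describes.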
jax.util.partial was an accidental export that has now been removed. Use functools.partial from the Python standard library instead.
jax.ops.index_add etc. are deprecated and will be removed in a future JAX release. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y). For now, these functions produce a DeprecationWarning.
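A minimal sketch of the .at replacements, assuming a current JAX release. Each update returns a new array and leaves the original untouched:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
# Replacement for the deprecated jax.ops.index_update: set x[1] to 3.0.
y = x.at[1].set(3.0)
# Replacement for the deprecated jax.ops.index_add: add 1.0 at indices 0 and 2.
z = x.at[jnp.array([0, 2])].add(1.0)
# x itself is unchanged, since JAX arrays are immutable.
```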
An optimized C++ code-path improving the dispatch time for pmap is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the
jax.numpy.unique now supports an optional
jaxlib 0.1.72 (Oct 12, 2021)¶
Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
Fixes https://github.com/google/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
jax 0.2.21 (Sept 23, 2021)¶
jax.api has been removed. Functions that were available as jax.api.* were aliases for functions in jax.*; please use the functions in jax.* instead.
jax.partial and jax.lax.partial were accidental exports that have now been removed. Use functools.partial from the Python standard library instead.
Boolean scalar indices now raise a TypeError; previously this silently returned wrong results (#7925).
When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce an on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
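Here is a sketch of the suggested workaround, assuming a current JAX release; `pad_to_double` is a hypothetical function. The shape arithmetic uses classic NumPy on the static `x.shape`, so it stays a concrete value during tracing and can be used as an array size:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def pad_to_double(x):
    # x.shape is static under jit; NumPy arithmetic on it stays concrete.
    new_len = int(np.array(x.shape[0]) * 2)
    # Using jnp.array(x.shape[0]) here instead would be staged into the
    # traced computation, so it could not be used as a static size.
    return jnp.concatenate([x, jnp.zeros(new_len - x.shape[0], dtype=x.dtype)])

out = pad_to_double(jnp.arange(3.0))
```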
jnp.ndarray is now a true base-class for JAX arrays. In particular, this means that for a standard numpy array x, isinstance(x, jnp.ndarray) will now return False.
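A quick check of the isinstance behavior, assuming a current JAX release:

```python
import numpy as np
import jax.numpy as jnp

# A classic NumPy array is not an instance of jnp.ndarray...
np_is_jax = isinstance(np.zeros(3), jnp.ndarray)
# ...while an array created by jax.numpy is.
jax_is_jax = isinstance(jnp.zeros(3), jnp.ndarray)
```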
jax 0.2.20 (Sept 2, 2021)¶
jaxlib 0.1.71 (Sep 1, 2021)¶
Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
jax 0.2.19 (Aug 12, 2021)¶
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
A jit decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as
This change should largely be transparent to most users. However, there is one known behavioral change: large integer constants may now produce an error when passed directly to a JAX operator (e.g., x + 2**40). The workaround is to cast the constant to an explicit type (e.g., np.float64(2**40)).
Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g.,
Fixed some leaked trace errors from the previous release (#7613).
jaxlib 0.1.70 (Aug 9, 2021)¶
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.
jax 0.2.18 (July 21 2021)¶
The minimum jaxlib version is now 0.1.69.
The backend argument to jax.dlpack.from_dlpack() has been removed.
Added a polar decomposition (
Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid axis value, or with an empty reduction dimension. (#7196)
jaxlib 0.1.69 (July 9 2021)¶
Fixed bugs in the TFRT CPU backend that produced incorrect results.
jax 0.2.17 (July 9 2021)¶
Default to the older “stream_executor” CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
New SciPy function
Reverse-mode autodiff functions (
jax.linear_transpose()) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only
jax 0.2.15 (June 23 2021)¶
Support for NumPy 1.16 has been dropped, per the deprecation policy.
Fixed bug that prevented round-tripping from JAX to TF and back:
jaxlib 0.1.68 (June 23 2021)¶
Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a TPU buffer to CPU.
jax 0.2.14 (June 10 2021)¶
jax2tf.convert() now has support for
A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.
A new traceback filtering mode using __tracebackhide__ is now enabled by default in sufficiently recent versions of IPython.
jax2tf.convert() supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g.,
jax2tf.convert() generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
New SciPy function
jax2tf.convert() now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX (#6883).
jax2tf.convert() now scopes the enable_xla conversion parameter properly to apply only during the just-in-time conversion (#6720).
XlaDot TensorFlow op, for better fidelity w.r.t. JAX numerical precision (#6717).
jax2tf.convert() now has support for inequality comparisons and min/max for complex numbers (#6892).
jaxlib 0.1.67 (May 17 2021)¶
jaxlib 0.1.66 (May 11 2021)¶
CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.
NVIDIA now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.
There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).
Jaxlib now bundles libdevice.10.bc in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.
Added automatic support for static keyword arguments to the jit() decorator.
Added support for pretransformation exception traces.
Initial support for pruning unused arguments from jit()-transformed computations. Pruning is still a work in progress.
Improved the string representation of
Added support for XLA’s variadic ReduceWindow.
Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
Fix a bug that meant that JAX garbage collection was not triggered by
jax 0.2.13 (May 3 2021)¶
When combined with jaxlib 0.1.66, jax.jit() now supports static keyword arguments. A new static_argnames option has been added to specify keyword arguments as static.
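A minimal sketch of static_argnames, assuming a current JAX release; `raise_to` is a hypothetical function. Because `power` is static, the Python loop unrolls at trace time, and each distinct `power` value triggers its own compilation:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("power",))
def raise_to(x, power):
    # power is a concrete Python int here, so range() works during tracing.
    out = jnp.ones_like(x)
    for _ in range(power):
        out = out * x
    return out

result = raise_to(jnp.array([2.0, 3.0]), power=3)  # [8., 27.]
```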
jax.numpy.nonzero() has a new optional size argument that allows it to be used within jit.
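Here is a sketch of the size argument under jit, assuming a current JAX release; `nonzero_indices` is a hypothetical helper. Fixing `size` makes the output shape static. If there are fewer nonzero entries than `size`, the result is padded, with 0 by default:

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_indices(x):
    # size=3 gives a fixed-length result; missing entries are padded with 0.
    (idx,) = jnp.nonzero(x, size=3)
    return idx

result = nonzero_indices(jnp.array([0, 5, 0, 7, 0]))  # [1, 3, 0]
```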
New function jax.scipy.linalg.eigh_tridiagonal() that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
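A small sketch, assuming a current JAX release. The matrix is passed as its main diagonal `d` and off-diagonal `e`, and `eigvals_only=True` is requested because only eigenvalues are supported. For [[2, 1], [1, 2]], the eigenvalues are 1 and 3:

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

d = jnp.array([2.0, 2.0])  # main diagonal
e = jnp.array([1.0])       # sub/super-diagonal
# Only eigenvalues are supported, so request eigvals_only=True.
w = eigh_tridiagonal(d, e, eigvals_only=True)
```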
The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an UnfilteredStackTrace exception containing the original trace as the __cause__ of the filtered exception. Filtered stack traces now also work with Python 3.6.
If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a __cause__ of the exception a JaxStackTraceBeforeTransformation object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
Similarly, the argument to local_devices() has been renamed from
Arguments to jax.jit() other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to jit.
jaxlib 0.1.65 (April 7 2021)¶
jax 0.2.12 (April 1 2021)¶
The minimum jaxlib version is now 0.1.64.
Some profiler API names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed, so please change your code.
Omnistaging can no longer be disabled. See omnistaging for more information.
Python integers larger than the maximum int64 value will now lead to an overflow in all cases, rather than being silently converted to uint64 in some cases (#6047).
Outside X64 mode, Python integers outside the range representable by int32 will now lead to an OverflowError rather than having their value silently truncated.
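A sketch of the int32 check, assuming a current JAX release in its default configuration (jax_enable_x64 off). The largest int32 value converts normally, and one past it raises OverflowError instead of being truncated:

```python
import jax.numpy as jnp

small = jnp.array(2**31 - 1)  # largest int32 value: converts fine
try:
    jnp.array(2**31)          # one past the int32 maximum
    overflowed = False
except OverflowError:
    overflowed = True
```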
jax 0.2.11 (March 23 2021)¶
jax.flatten_util.ravel_pytree to handle integer dtypes.
#6129 fixed a bug with handling some constants like
#6145 fixed batching issues with incomplete beta functions
#6014 fixed H2D transfers during tracing
#6165 avoids OverflowErrors when converting some large Python integers to floats
The minimum jaxlib version is now 0.1.62.
jaxlib 0.1.64 (March 18 2021)¶
jaxlib 0.1.63 (March 17 2021)¶
jax 0.2.10 (March 5 2021)¶
jax.scipy.stats.chi2() is now available as a distribution with logpdf and pdf methods.
jax.scipy.stats.betabinom() is now available as a distribution with logpmf and pmf methods.
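A short sketch of both distributions, assuming a current JAX release. It checks two closed forms. The chi-squared density with df=2 is exp(-x/2)/2. The beta-binomial with a = b = 1 is uniform on {0, ..., n}:

```python
import jax.numpy as jnp
from jax.scipy.stats import betabinom, chi2

# chi-squared, df=2: pdf(2) = exp(-1) / 2
p = chi2.pdf(2.0, df=2)
logp = chi2.logpdf(2.0, df=2)

# beta-binomial with a = b = 1 over n = 4 trials: each outcome has probability 1/5
n = 4
pmf = betabinom.pmf(jnp.arange(n + 1), n, 1.0, 1.0)
log_pmf = betabinom.logpmf(jnp.arange(n + 1), n, 1.0, 1.0)
```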
Extended the batching rule for lax.pad to support batching of the padding values.
JAX’s promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example, the expression jnp.bfloat16(1) + 0.1 * jnp.arange(10) previously returned a float64 array, and now returns a bfloat16 array. JAX’s type promotion behavior is described at Type promotion semantics.
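The example from the entry can be checked directly, assuming a current JAX release. The weakly typed Python float keeps `0.1 * jnp.arange(10)` weakly typed, so the bfloat16 operand determines the result's precision:

```python
import jax.numpy as jnp

# 0.1 * jnp.arange(10) is weakly typed, so it does not force float32/float64.
result = jnp.bfloat16(1) + 0.1 * jnp.arange(10)
```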
jax.numpy.linspace()now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
jax.numpy.i0()no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
jax.numpy functions no longer accept tuples or lists in place of array arguments:
jax.numpy.reshape(). In general, jax.numpy functions should be used with scalar or array arguments.
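Here is a sketch of the explicit conversion this requires, assuming a current JAX release. Convert a list with jnp.asarray first, because passing the list directly raises a TypeError:

```python
import jax.numpy as jnp

values = [1, 2, 3]
# Convert explicitly before calling a jax.numpy function.
total = jnp.sum(jnp.asarray(values))
try:
    jnp.sum(values)  # a plain list is rejected
    raised = False
except TypeError:
    raised = True
```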
jaxlib 0.1.62 (March 9 2021)¶
jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn’t support AVX, you can build a jaxlib from source using the
jaxlib 0.1.61 (February 12 2021)¶
jaxlib 0.1.60 (February 3 2021)¶
Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
uint8 are now considered safe to cast to the bfloat16 NumPy extension type.
jax 0.2.9 (January 26 2021)¶
jax.ops.segment_sum()now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
jaxlib 0.1.59 (January 15 2021)¶
jax 0.2.8 (January 12 2021)¶
jax.numpy.arccosh now returns the same branch as numpy.arccosh for complex inputs (#5156)
host_callback.id_tap now works for jax.pmap also. There is an optional parameter for id_print to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).
jax.numpy.pad now takes keyword arguments. Positional argument constant_values has been removed. In addition, passing unsupported keyword arguments raises an error.
Removed support for
jax.experimental.host_callback.id_tap(). (This support has been deprecated for a few months.)
Changed the printing of tuples for jax.experimental.host_callback.id_print() to use '(' instead of '['.
Changed jax.experimental.host_callback.id_print() in the presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
host_callback.outfeed_receiverhas been removed (it is not necessary, and was deprecated a few months ago).
New flag for debugging inf, analogous to that for NaN.
jax 0.2.7 (Dec 4 2020)¶
Add multi-host support to
Add support for differentiating eigenvalues computed by
Add support for building on Windows platforms
Add support for general in_axes and out_axes in
Add complex support for
Fix higher-than-second order derivatives of
Fix some hard-to-hit bugs around symbolic zeros in transpose rules
jax.experimental.optixhas been deleted, in favor of the standalone
Indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See #4564.
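A sketch of the two unambiguous spellings, assuming a current JAX release. Use a tuple for multidimensional indexing, or an array for integer-array indexing. A bare list raises TypeError:

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)
idx = [1, 2]
element = x[tuple(idx)]   # same as x[1, 2], i.e. 6
rows = x[jnp.array(idx)]  # rows 1 and 2, shape (2, 4)
try:
    x[idx]                # non-tuple sequence index
    raised = False
except TypeError:
    raised = True
```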
jax 0.2.6 (Nov 18 2020)¶
Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
Breaking change cleanup
Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
Improve consistency of type promotion behavior (#4744):
Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.
Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example, jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.
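Both promotion examples above can be checked directly, assuming a current JAX release:

```python
import jax.numpy as jnp

# A complex Python scalar respects the float32 precision of the JAX operand.
c = jnp.float32(1) + 1j
# Three-term promotion involving uint64 is independent of argument order.
r1 = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
r2 = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
```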
The contents of the (undocumented) jax.lax_linalg linear algebra module are now exposed publicly as jax.lax.linalg.
jax.random.PRNGKeynow produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:
When jax_enable_x64=False, negative seeds passed as Python integers now return a different result outside JIT mode. For example, random.PRNGKey(-1) previously returned [4294967295, 4294967295], and now returns [0, 4294967295]. This matches the behavior in JIT.
Seeds outside the range representable by int64 outside JIT now result in an OverflowError rather than a TypeError. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with jax_enable_x64=False outside JIT, you can use:
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray now raises ValueError when trying to access its value after it has been deleted.
jaxlib 0.1.58 (January 12ish 2021)¶
Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32).
Fixed a crash when constant-folding certain int16 operations. (#4971)
jaxlib 0.1.57 (November 12 2020)¶
Fixed manylinux2010 compliance issues in GPU wheels.
Switched the CPU FFT implementation from Eigen to PocketFFT.
Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
Add support for retaining ownership when passing arrays to DLPack (#4636).
Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
Fixed a bug in profiler where tools are missing (#4427).
Dropped support for CUDA 10.0.
jax 0.2.5 (October 27 2020)¶
jax 0.2.4 (October 19 2020)¶
jaxlib 0.1.56 (October 14, 2020)¶
jax 0.2.3 (October 14 2020)¶
The reason for another release so soon is that we needed to temporarily roll back a new jit fastpath while we look into a performance degradation.
jax 0.2.1 (October 6 2020)¶
jax (0.2.0) (September 23 2020)¶
jax (0.1.77) (September 15 2020)¶
New simplified interface for
jaxlib 0.1.55 (September 8, 2020)¶
Fix bug in DLPackManagedTensorToBuffer (#4196)
jax 0.1.75 (July 30, 2020)¶
make jnp.abs() work for unsigned inputs (#3914)
“Omnistaging” behavior added behind a flag, disabled by default (#3370)
jax 0.1.74 (July 29, 2020)¶
TPU support for half-precision arithmetic (#3878)
Prevent some accidental dtype warnings (#3874)
Fix a multi-threading bug in custom derivatives (#3845, #3869)
Faster searchsorted implementation (#3873)
Better test coverage for jax.numpy sorting algorithms (#3836)
jaxlib 0.1.52 (July 22, 2020)¶
jax 0.1.73 (July 22, 2020)¶
The minimum jaxlib version is now 0.1.51.
hfft and ihfft (#3664)
The scan primitive supports an unroll parameter for loop unrolling when lowering to XLA (#3738).
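A minimal sketch of the unroll parameter, assuming a current JAX release; `step` is a hypothetical scan body computing a running sum. unroll=2 only changes how the loop is lowered (two iterations per loop body), so the results are the same as with the default unroll=1:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

xs = jnp.arange(1.0, 6.0)
# unroll=2 affects only the lowered loop structure, not the values.
final, cumsum = jax.lax.scan(step, jnp.zeros(()), xs, unroll=2)
```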
Fix reduction repeated axis error (#3618)
Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
make psum transpose handle zero cotangents (#3653)
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
Support differentiation through jax.lax.all_to_all (#3733)
address nan issue in jax.scipy.special.zeta (#3777)
Many improvements to jax2tf
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
Enable XLA SPMD partitioning by default. (#3151)
Add support for 0d transpose convolution (#3643)
Make LU gradient work for low-rank matrices (#3610)
support multiple_results and custom JVPs in jet (#3657)
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
Implement complex convolutions on CPU and GPU. (#3735)
Make jnp.take work for empty slices of empty arrays. (#3751)
Relax dimension ordering rules for dot_general. (#3778)
Enable buffer donation for GPU. (#3800)
Add support for base dilation and window dilation to reduce window op… (#3803)
jaxlib 0.1.51 (July 2, 2020)¶
Add new runtime support for host_callback.
jax 0.1.72 (June 28, 2020)¶
jax 0.1.71 (June 25, 2020)¶
jaxlib 0.1.50 (June 25, 2020)¶
Add support for CUDA 11.0.
Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
jaxlib 0.1.49 (June 19, 2020)¶
Fix build issue that could result in slow compiles (https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1)
jaxlib 0.1.48 (June 12, 2020)¶
Adds support for fast traceback collection.
Adds preliminary support for on-device heap profiling.
Complex128 support for FFTs on CPU and GPU.
Improved tanh accuracy on GPU.
float64 scatters on GPU are much faster.
Complex matrix multiplication on CPU should be much faster.
Stable sorts on CPU should actually be stable now.
Concurrency bug fix in CPU backend.
jax 0.1.70 (June 8, 2020)¶
jax 0.1.68 (May 21, 2020)¶
jax 0.1.67 (May 12, 2020)¶
The visibility of names exported from jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.
jaxlib 0.1.47 (May 8, 2020)¶
Fixes crash for outfeed.
jaxlib 0.1.46 (May 5, 2020)¶
Fixes crash for linear algebra functions on Mac OS X (#432).
Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
jax 0.1.65 (April 30, 2020)¶
jaxlib 0.1.45 (April 21, 2020)¶
Fixes segfault: #2755
Plumb is_stable option on Sort HLO through to Python.
jax 0.1.64 (April 21, 2020)¶
Improves error message for reverse-mode differentiation of
jaxlib 0.1.44 (April 16, 2020)¶
Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
jax 0.1.63 (April 12, 2020)¶
Changed how Tracers are printed to show more useful information for debugging #2591.
Added several new rules for
Fix some missing cases of broadcasting in
Add docstring for
correlate1d & 2d,
jaxlib 0.1.43 (March 31, 2020)¶
Fixed a performance regression for Resnet-50 on GPU.
jax 0.1.62 (March 21, 2020)¶
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0. (zero times NaN is zero). This change means that some programs, when differentiated, will produce NaNs where they previously produced correct values, though it ensures that NaNs, rather than silently incorrect results, are produced for other programs. See #2447 and #1052 for details.
Added the all_gather parallel convenience function.
More type annotations in core code.
jaxlib 0.1.42 (March 19, 2020)¶
jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)¶
Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
jax 0.1.60 (March 17, 2020)¶
static_broadcast_argnums argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to
Improved error messages for when tracers are mistakenly saved in global state.
jax.experimental.jet for exponentially faster higher-order automatic differentiation.
Added more correctness checking to arguments of
The minimum jaxlib version is now 0.1.41.
jaxlib 0.1.40 (March 4, 2020)¶
Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
Includes prototype support for multihost GPU computations that communicate via NCCL.
Improves performance of NCCL collectives on GPU.
Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)¶
The minimum jaxlib version is now 0.1.38.
Jaxpr by removing the
Jaxpr.bound_subjaxprs. The call primitives (
remat_call) get a new parameter
call_jaxpr with a fully-closed (no
constvars) jaxpr. Also, added a new field
Reverse-mode automatic differentiation (e.g.
lax.cond, making it now differentiable in both modes (#2091)
JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.
jaxlib 0.1.39 (February 11, 2020)¶
jaxlib 0.1.38 (January 29, 2020)¶
CUDA 9.0 is no longer supported.
CUDA 10.2 wheels are now built by default.
jax 0.1.58 (January 28, 2020)¶
JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
Forward-mode automatic differentiation (jvp) of while loop (#1980)
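A sketch of forward-mode differentiation through lax.while_loop, assuming a current JAX release; `double_until_big` is a hypothetical function. Starting from 1.5, the loop doubles three times to reach 12, so the derivative with respect to the input is 2**3 = 8:

```python
import jax
import jax.numpy as jnp
from jax import lax

def double_until_big(x):
    # Keep doubling x while it is at most 10.
    return lax.while_loop(lambda v: v <= 10.0, lambda v: v * 2.0, x)

# Forward mode (jvp) differentiates through the loop.
primal_out, tangent_out = jax.jvp(
    double_until_big, (jnp.float32(1.5),), (jnp.float32(1.0),)
)
```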
New NumPy and SciPy functions:
Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
Notable bug fixes¶
With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.