Change log
jax 0.3.14 (Unreleased)
Changes
jax.numpy.linalg.slogdet() now accepts an optional method argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.
jax.numpy.linalg.qr() now supports mode="raw".
pickle, copy.copy, and copy.deepcopy now have more complete support when used on jax arrays (#10659). In particular:
pickle and deepcopy previously returned np.ndarray objects when used on a DeviceArray; now DeviceArray objects are returned. For deepcopy, the copied array is on the same device as the original. For pickle, the deserialized array will be on the default device.
Within function transformations (i.e. traced code), deepcopy and copy previously were no-ops. Now they use the same mechanism as DeviceArray.copy().
Calling pickle on a traced array now results in an explicit ConcretizationTypeError.
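The sketch below exercises two of the changes above: selecting the slogdet implementation, and copying or pickling arrays. It assumes a current JAX release, in which slogdet() takes method="lu" or method="qr" as a keyword, the DeviceArray type named here has been superseded by jax.Array, and pickling a tracer raises jax.errors.ConcretizationTypeError. The matrix is illustrative.

```python
import copy
import pickle

import jax
import jax.numpy as jnp

# slogdet: pick the LU- or QR-based implementation with `method`.
a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])  # det(a) = 2*3 - 1*1 = 5
sign_lu, logdet_lu = jnp.linalg.slogdet(a, method="lu")
sign_qr, logdet_qr = jnp.linalg.slogdet(a, method="qr")
det_lu = float(sign_lu * jnp.exp(logdet_lu))
det_qr = float(sign_qr * jnp.exp(logdet_qr))
print(det_lu, det_qr)  # both approximately 5.0

# copy.copy, copy.deepcopy, and a pickle round trip all return JAX arrays.
x = jnp.arange(4.0)
for y in (copy.copy(x), copy.deepcopy(x), pickle.loads(pickle.dumps(x))):
    assert isinstance(y, jax.Array) and bool((y == x).all())

# Pickling a traced value inside jit raises ConcretizationTypeError.
@jax.jit
def try_pickle(v):
    pickle.dumps(v)
    return v

try:
    try_pickle(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
print(raised)  # True
```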
jaxlib 0.3.11 (Unreleased)
jax 0.3.13 (May 16, 2022)
jax 0.3.11 (May 15, 2022)
Changes
jax.lax.eigh() now accepts an optional sort_eigenvalues argument that allows users to opt out of eigenvalue sorting on TPU.
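A minimal sketch of the keyword, assuming that in current releases the function is reached as jax.lax.linalg.eigh and returns (eigenvectors, eigenvalues), the reverse of jax.numpy.linalg.eigh. On CPU the eigenvalues come back sorted either way; the flag matters mainly on TPU.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

# A diagonal matrix with eigenvalues 3, 1, 2 (unsorted on the diagonal).
a = jnp.diag(jnp.array([3.0, 1.0, 2.0]))

# sort_eigenvalues is keyword-only; pass False to skip sorting on TPU.
v, w = lax_linalg.eigh(a, sort_eigenvalues=True)
print(w)  # ascending eigenvalues, approximately [1. 2. 3.]
```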
Deprecations
Non-array arguments to functions in jax.lax.linalg are now marked keyword-only. As a backward-compatibility step, passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use jax.numpy.linalg instead.
jax.scipy.linalg.polar_unitary(), which was a JAX extension to the scipy API, is deprecated. Use jax.scipy.linalg.polar() instead.
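A sketch of migrating from polar_unitary(), assuming polar() keeps its SciPy-style return value (unitary factor, positive semidefinite factor) with the default side="right", so that a == u @ p. The unitary factor is simply the first element:

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# Formerly: u = jax.scipy.linalg.polar_unitary(a)
u, p = polar(a)

print(jnp.allclose(u @ p, a, atol=1e-4))              # reconstructs a
print(jnp.allclose(u.T @ u, jnp.eye(2), atol=1e-4))  # u is orthogonal
```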
jax 0.3.10 (May 3, 2022)
jaxlib 0.3.10 (May 3, 2022)
Changes
A TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.
jax 0.3.9 (May 2, 2022)
Changes
Added support for fully asynchronous checkpointing for GlobalDeviceArray.
jax 0.3.8 (April 29 2022)
Changes
jax.numpy.linalg.svd() on TPUs uses a qdwh-svd solver.
jax.numpy.linalg.cond() on TPUs now accepts complex input.
jax.numpy.linalg.pinv() on TPUs now accepts complex input.
jax.numpy.linalg.matrix_rank() on TPUs now accepts complex input.
jax.scipy.cluster.vq.vq() has been added.
jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
jax.scipy.linalg.qr() now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr (#10452).
jax.numpy.take_along_axis() now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
jax.numpy.take() now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously, out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
jax.numpy.take_along_axis() now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis(). Previously non-integer indices were silently cast to integers.
jax.numpy.ravel_multi_index() now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index(). Previously non-integer dims was silently cast to integers.
jax.numpy.split() now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split(). Previously non-integer axis was silently cast to integers.
jax.numpy.indices() now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices(). Previously non-integer dimensions were silently cast to integers.
jax.numpy.diag() now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag(). Previously non-integer k was silently cast to integers.
Added jax.random.orthogonal().
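The sketch below contrasts the "fill" and "clip" modes for out-of-bounds reads and shows the zero cotangent for an out-of-bounds scatter index. It assumes a current JAX release and a float array (so the fill value is NaN), and it passes mode="drop" to the scatter explicitly rather than relying on the default described above.

```python
import jax
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])  # index 5 is out of bounds for length 3

# Default "fill" mode: the out-of-bounds read returns NaN.
filled = jnp.take_along_axis(x, idx, axis=0)
# "clip" restores the old clamping behavior: index 5 reads x[2].
clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")
# jnp.take also defaults to "fill".
taken = jnp.take(x, idx)

# With "drop" semantics the out-of-bounds update is discarded,
# so its cotangent is zero.
def f(updates):
    return x.at[idx].set(updates, mode="drop").sum()

g = jax.grad(f)(jnp.array([1.0, 1.0]))
print(filled, clipped, taken, g)  # g is [1. 0.]
```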
Deprecations
Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance (#10389). These, along with the previously deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in, e.g., unittest, absl.testing, numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices(). Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
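A sketch of the suggested replacements, assuming numpy.testing stands in for check_close and that querying jax.devices()[0].platform is an acceptable substitute for device_under_test() in your tests; both pairings are illustrative, not an official mapping.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.linspace(0.0, 1.0, 5)

# Instead of jtu.check_close: JAX arrays work directly with numpy.testing.
np.testing.assert_allclose(x, np.linspace(0.0, 1.0, 5), rtol=1e-6)

# Instead of jtu.device_under_test(): ask the public API for the platform.
platform = jax.devices()[0].platform
print(platform)  # e.g. "cpu", "gpu", or "tpu"
```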
jax 0.3.7 (April 15, 2022)
Changes:
Fixed a performance problem if the indices passed to jax.numpy.take_along_axis() were broadcasted (#10281).
jax.scipy.special.expit() and jax.scipy.special.logit() now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile() (#10266).
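A short sketch of the expit/logit promotion and the jnp.tile replacement, assuming a current JAX release:

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# Integer inputs are promoted to floating point.
e = expit(jnp.array([0, 1]))   # [0.5, ~0.731]
l = logit(jnp.array(0.5))      # approximately 0.0

# Replacement for the deprecated x.tile(3) method.
t = jnp.tile(jnp.array([1, 2]), 3)
print(e.dtype, l, t)
```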
jaxlib 0.3.7 (April 15, 2022)
Changes:
Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.
jax 0.3.6 (April 12, 2022)
Changes:
Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU pod. Fixes #10218.
Deprecations:
jax.experimental.loops is being deprecated. See #10278 for an alternative API.
jax 0.3.5 (April 7, 2022)
Changes:
Added jax.random.loggamma() and improved behavior of jax.random.beta() and jax.random.dirichlet() for small parameter values (#9906).
The private lax_numpy submodule is no longer exposed in the jax.numpy namespace (#10029).
Added array creation routines jax.numpy.frombuffer(), jax.numpy.fromfunction(), and jax.numpy.fromstring() (#10049).
DeviceArray.copy() now returns a DeviceArray rather than a np.ndarray (#10069).
jax.experimental.sharded_jit has been deprecated and will be removed soon.
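A sketch of the three new array creation routines, assuming they follow their NumPy counterparts (fromstring requires the sep keyword for text input):

```python
import jax.numpy as jnp

# fromfunction calls the function with an index array per axis.
grid = jnp.fromfunction(lambda i, j: i + j, (2, 3))  # [[0,1,2],[1,2,3]] as floats

# frombuffer reinterprets raw bytes with the requested dtype.
buf = jnp.frombuffer(b"\x01\x02\x03", dtype=jnp.uint8)

# fromstring parses numbers out of text using a separator.
vals = jnp.fromstring("1 2 3", dtype=jnp.int32, sep=" ")
print(grid, buf, vals)
```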
Deprecations:
jax.nn.normalize() is being deprecated. Use jax.nn.standardize() instead (#9899).
jax.tree_util.tree_multimap() is deprecated. Use jax.tree_util.tree_map() instead (#5746).
jax.experimental.sharded_jit is deprecated. Use pjit instead.
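A sketch of the two replacements, assuming a current JAX release; tree_map accepts additional trees with the same structure, which is what tree_multimap used to provide:

```python
import jax
import jax.numpy as jnp

# jax.nn.standardize replaces jax.nn.normalize.
x = jnp.array([1.0, 2.0, 3.0, 4.0])
z = jax.nn.standardize(x)  # zero mean, unit variance (up to a small epsilon)

# tree_map over two trees, formerly tree_multimap.
a = {"w": jnp.ones(2), "b": 1.0}
b = {"w": jnp.full(2, 2.0), "b": 3.0}
summed = jax.tree_util.tree_map(lambda p, q: p + q, a, b)
print(z, summed)
```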
jaxlib 0.3.5 (April 7, 2022)
jax 0.3.4 (March 18, 2022)
jax 0.3.3 (March 17, 2022)
jax 0.3.2 (March 16, 2022)
Changes:
The functions jax.ops.index_update and jax.ops.index_add, which were deprecated in 0.2.22, have been removed. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y).
Moved jax.experimental.ann.approx_*_k into jax.lax. These functions are optimized alternatives to jax.lax.top_k.
jax.numpy.broadcast_arrays() and jax.numpy.broadcast_to() now require scalar or array-like inputs, and will fail if they are passed lists (part of #7737).
The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
pjit now works on CPU (in addition to previous TPU and GPU support).
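A sketch of migrating from the removed jax.ops functions to the .at property; the old calls appear only in comments, and each .at expression returns a new array while leaving x unchanged:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([1, 3])

y = x.at[idx].set(7.0)  # formerly jax.ops.index_update(x, idx, 7.0)
z = x.at[idx].add(2.0)  # formerly jax.ops.index_add(x, idx, 2.0)
print(y)  # [0. 7. 0. 7. 0.]
print(z)  # [0. 2. 0. 2. 0.]
```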
jaxlib 0.3.2 (March 16, 2022)
Changes
XlaComputation.as_hlo_text() now supports printing large constants by passing the boolean flag print_large_constants=True.
Deprecations:
The .block_host_until_ready() method on JAX arrays has been deprecated. Use .block_until_ready() instead.
jax 0.3.1 (Feb 18, 2022)
Changes:
jax.test_util.JaxTestCase and jax.test_util.JaxTestLoader are now deprecated. The suggested replacement is to use parameterized.TestCase directly. For tests that rely on custom asserts such as JaxTestCase.assertAllClose(), the suggested replacement is to use standard numpy testing utilities such as numpy.testing.assert_allclose(), which work directly with JAX arrays (#9620).
jax.test_util.JaxTestCase now sets jax_numpy_rank_promotion='raise' by default (#9562). To recover the previous behavior, use the new jax.test_util.with_config decorator:
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
  ...
Added jax.scipy.linalg.schur(), jax.scipy.linalg.sqrtm(), jax.scipy.signal.csd(), jax.scipy.signal.stft(), jax.scipy.signal.welch().
jax 0.3.0 (Feb 10, 2022)
Changes
jax version has been bumped to 0.3.0. Please see the design doc for the explanation.
jaxlib 0.3.0 (Feb 10, 2022)
Changes
Bazel 5.0.0 is now required to build jaxlib.
jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
jax 0.2.28 (Feb 1, 2022)
jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.
jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.
jaxlib 0.1.76 (Jan 27, 2022)
New features
Includes precompiled SASS for NVidia compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
Breaking changes
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
Bug fixes
Fixed a bug where apparently identical pytreedef objects constructed by different routes do not compare as equal (#9066).
The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
jax 0.2.27 (Jan 18 2022)
Breaking changes:
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --flax_host_callback_ad_transforms flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs (#8678).
Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally, all NaN representations are now treated as equivalent and sorted to the end of the array. Previously, negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent and were sorted according to those bit patterns (#9178).
jax.numpy.unique() now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output (#9184).
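A sketch of the new sorting and uniquing behavior, assuming a current JAX release; the relative order of -0.0 and 0.0 in the sorted output is not significant, so the check compares values only:

```python
import jax.numpy as jnp

x = jnp.array([3.0, jnp.nan, -0.0, 1.0, 0.0, jnp.nan])

# NaNs go to the end; -0.0 and 0.0 compare as equal.
s = jnp.sort(x)

# At most one NaN survives, and the two zeros collapse into one entry.
u = jnp.unique(x)
print(s, u)
```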
Bug fixes:
host_callback now supports ad_checkpoint.checkpoint (#8907).
New features:
Added jax.block_until_ready (#8941).
Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
Added jax.ensure_compile_time_eval to the public API (#7987).
jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).
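A sketch combining two of these additions, assuming a current JAX release: inside a jitted function, ensure_compile_time_eval makes a value concrete at trace time, and jax.block_until_ready waits on every array in a pytree of results.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Evaluated eagerly during tracing, so n is a Python int, not a tracer.
    with jax.ensure_compile_time_eval():
        n = int(jnp.arange(4).sum())  # 0 + 1 + 2 + 3 = 6
    return x * n

# Block until all arrays in the dict have been computed.
out = jax.block_until_ready({"a": f(jnp.ones(2)), "b": jnp.zeros(3)})
print(out["a"])  # [6. 6.]
```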
jaxlib 0.1.75 (Dec 8, 2021)
New features:
Support for Python 3.10.
jax 0.2.26 (Dec 8, 2021)
Bug fixes:
Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).
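A sketch of the drop behavior for an out-of-range segment ID and its zero gradient, assuming a current JAX release with the default mode:

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0])
segment_ids = jnp.array([0, 1, 5])  # 5 is out of range for num_segments=2

# The out-of-range entry is dropped.
sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)

# Its gradient is therefore zero.
g = jax.grad(
    lambda d: jax.ops.segment_sum(d, segment_ids, num_segments=2).sum()
)(data)
print(sums, g)  # [1. 2.] [1. 1. 0.]
```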
jaxlib 0.1.74 (Nov 17, 2021)
Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
Added experimental MLIR Python bindings for use by JAX.
jax 0.2.25 (Nov 10, 2021)
New features:
(Experimental) jax.distributed.initialize exposes multi-host GPU backend.
jax.random.permutation supports new independent keyword argument (#8430).
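A sketch of the independent keyword, assuming the current signature permutation(key, x, axis=0, independent=False). With independent=True each row is shuffled separately along axis=1; with the default, whole columns would move together.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.tile(jnp.arange(5), (3, 1))  # three identical rows [0 1 2 3 4]

p = jax.random.permutation(key, x, axis=1, independent=True)
print(p)  # each row is its own shuffle of 0..4
```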
Breaking changes
Moved jax.experimental.stax to jax.example_libraries.stax
Moved jax.experimental.optimizers to jax.example_libraries.optimizers
New features:
Added jax.lax.linalg.qdwh.
jax 0.2.24 (Oct 19, 2021)
jaxlib 0.1.73 (Oct 18, 2021)
Multiple cuDNN versions are now supported for jaxlib GPU cuda11 wheels:
cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
cuDNN 8.0.5 or newer.
Breaking changes:
The install commands for GPU jaxlib are as follows:
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.0.5 or newer.
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (Oct 12, 2021)
Breaking Changes
Static arguments to jax.pmap must now be hashable.
Unhashable static arguments have long been disallowed on jax.jit, but they were still permitted on jax.pmap; jax.pmap compared unhashable static arguments using object identity.
This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of jax.pmap wants to compare static arguments by object identity, they can define __hash__ and __eq__ methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use functools.partial to encapsulate the unhashable static arguments into the function object.
jax.util.partial was an accidental export that has now been removed. Use functools.partial from the Python standard library instead.
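A sketch of the functools.partial option for an unhashable static argument. The scale function and the config dict are hypothetical; the example assumes jax.pmap is still available in your JAX version and maps over however many local devices exist (one on a default CPU setup).

```python
import functools

import jax
import jax.numpy as jnp

def scale(config, x):
    # config is a plain (unhashable) dict that is only read, never traced.
    return x * config["factor"]

config = {"factor": 3.0}

# Bind the dict into the function instead of passing it as a static argument.
f = jax.pmap(functools.partial(scale, config))

n = jax.local_device_count()
out = f(jnp.ones((n, 2)))
print(out.shape)  # (n, 2)
```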
Deprecations
The functions jax.ops.index_update, jax.ops.index_add, etc. are deprecated and will be removed in a future JAX release. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y). For now, these functions produce a DeprecationWarning.
New features:
An optimized C++ code-path improving the dispatch time for pmap is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the --experimental_cpp_pmap flag (or JAX_CPP_PMAP environment variable).
jax.numpy.unique now supports an optional fill_value argument (#8121).
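A sketch of fill_value, assuming it is used together with the size argument, which fixes the output shape so that jnp.unique can run under jit; unused slots are padded with fill_value:

```python
import jax
import jax.numpy as jnp

@jax.jit
def padded_unique(x):
    return jnp.unique(x, size=4, fill_value=-1)

result = padded_unique(jnp.array([3, 1, 3, 1]))
print(result)  # [ 1  3 -1 -1]
```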
jaxlib 0.1.72 (Oct 12, 2021)
Breaking changes:
Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
Bug fixes:
Fixes https://github.com/google/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
jax 0.2.21 (Sept 23, 2021)
Breaking Changes
jax.api has been removed. Functions that were available as jax.api.* were aliases for functions in jax.*; please use the functions in jax.* instead.
jax.partial and jax.lax.partial were accidental exports that have now been removed. Use functools.partial from the Python standard library instead.
Boolean scalar indices now raise a TypeError; previously this silently returned wrong results (#7925).
Many more jax.numpy functions now require array-like inputs, and will error if passed a list (#7747 #7802 #7907). See #7737 for a discussion of the rationale behind this change.
When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce an on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
jnp.ndarray is now a true base class for JAX arrays. In particular, this means that for a standard numpy array x, isinstance(x, jnp.ndarray) will now return False (#7927).
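A sketch of two of these changes, using jnp.sum as one example of a function that rejects lists (an assumption about which functions are affected in your version), plus the base-class check:

```python
import numpy as np
import jax.numpy as jnp

# Lists are rejected; convert them explicitly first.
try:
    jnp.sum([1, 2, 3])
    rejected = False
except TypeError:
    rejected = True
total = jnp.sum(jnp.asarray([1, 2, 3]))  # 6

# jnp.ndarray matches JAX arrays but not NumPy arrays.
is_jax = isinstance(jnp.zeros(3), jnp.ndarray)  # True
is_np = isinstance(np.zeros(3), jnp.ndarray)    # False
print(rejected, total, is_jax, is_np)
```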
New features:
Added jax.numpy.insert() implementation (#7936).
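A sketch of jax.numpy.insert(), assuming it mirrors numpy.insert for scalar and array positions:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
a = jnp.insert(x, 1, 10)                                # [ 1 10  2  3]
b = jnp.insert(x, jnp.array([0, 3]), jnp.array([7, 8]))  # [7 1 2 3 8]
print(a, b)
```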
jax 0.2.20 (Sept 2, 2021)
Breaking Changes
jaxlib 0.1.71 (Sep 1, 2021)
Breaking changes:
Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
jax 0.2.19 (Aug 12, 2021)
Breaking changes:
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The jit decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as +.
This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., x + 2**40). The workaround is to cast the constant to an explicit type (e.g., np.float64(2**40)).
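A sketch of the error and the workaround, assuming the default configuration with 64-bit mode disabled, so that jnp.arange(3) is int32:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(3)  # int32 by default

try:
    x + 2**40  # the Python int does not fit in int32
    overflowed = False
except OverflowError:
    overflowed = True

# Workaround: give the constant an explicit type.
y = x + np.float64(2**40)
print(overflowed, y.dtype)  # True and a floating dtype
```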
New features:
Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g., jnp.mean (#7317).
Bug fixes:
Fixed some leaked trace errors from the previous release (#7613).
jaxlib 0.1.70 (Aug 9, 2021)
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.
jax 0.2.18 (July 21 2021)
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
The minimum jaxlib version is now 0.1.69.
The backend argument to jax.dlpack.from_dlpack() has been removed.
New features:
Added a polar decomposition (jax.scipy.linalg.polar()).
Bug fixes:
Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid axis value, or with an empty reduction dimension (#7196).
jaxlib 0.1.69 (July 9 2021)
Fixed bugs in the TFRT CPU backend that resulted in incorrect results.
jax 0.2.17 (July 9 2021)
Bug fixes:
Default to the older “stream_executor” CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
New features:
New SciPy function jax.scipy.special.sph_harm().
Reverse-mode autodiff functions (jax.grad(), jax.value_and_grad(), jax.vjp(), and jax.linear_transpose()) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only jax.experimental.maps.xmap()) (#6950).
jax 0.2.16 (June 23 2021)
jax 0.2.15 (June 23 2021)
New features:
#7042 Turned on TFRT CPU backend with significant dispatch performance improvements on CPU.
The jax2tf.convert() supports inequalities and min/max for booleans (#6956).
New SciPy function jax.scipy.special.lpmn_values().
Breaking changes:
Support for NumPy 1.16 has been dropped, per the deprecation policy.
Bug fixes:
Fixed bug that prevented round-tripping from JAX to TF and back: jax2tf.call_tf(jax2tf.convert) (#6947).
jaxlib 0.1.68 (June 23 2021)
Bug fixes:
Fixed a bug in the TFRT CPU backend that produced NaNs when a TPU buffer was transferred to CPU.
jax 0.2.14 (June 10 2021)
New features:
The jax2tf.convert() now has support for pjit and sharded_jit.
A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.
A new traceback filtering mode using __tracebackhide__ is now enabled by default in sufficiently recent versions of IPython.
The jax2tf.convert() supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., jnp.reshape(-1) (#6827).
The jax2tf.convert() generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
New SciPy function jax.scipy.special.lpmn().
Bug fixes:
The jax2tf.convert() now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX (#6883).
The jax2tf.convert() now scopes the enable_xla conversion parameter properly to apply only during the just-in-time conversion (#6720).
The jax2tf.convert() now converts lax.dot_general using the XlaDot TensorFlow op, for better fidelity w.r.t. JAX numerical precision (#6717).
The jax2tf.convert() now has support for inequality comparisons and min/max for complex numbers (#6892).
jaxlib 0.1.67 (May 17 2021)
jaxlib 0.1.66 (May 11 2021)
New features:
CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.
NVidia now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.
There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).
Jaxlib now bundles libdevice.10.bc in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.
Added automatic support for static keyword arguments to the jit() implementation.
Added support for pretransformation exception traces.
Initial support for pruning unused arguments from jit()-transformed computations. Pruning is still a work in progress.
Improved the string representation of PyTreeDef objects.
Added support for XLA’s variadic ReduceWindow.
Bug fixes:
Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
Fixed a bug that meant that JAX garbage collection was not triggered by jit()-transformed functions.
jax 0.2.13 (May 3 2021)
New features:
When combined with jaxlib 0.1.66, jax.jit() now supports static keyword arguments. A new static_argnames option has been added to specify keyword arguments as static.
jax.numpy.nonzero() has a new optional size argument that allows it to be used within jit (#6501).
jax.numpy.unique() now supports the axis argument (#6532).
jax.experimental.host_callback.call() now supports pjit.pjit (#6569).
Added jax.scipy.linalg.eigh_tridiagonal() that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an UnfilteredStackTrace exception containing the original trace as the __cause__ of the filtered exception. Filtered stack traces now also work with Python 3.6.
If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a __cause__ of the exception a JaxStackTraceBeforeTransformation object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
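A sketch of static_argnames and of nonzero with size under jit, assuming a current JAX release; repeat_sum and first_nonzero_indices are hypothetical helpers. With size set, missing entries are padded with 0 by default.

```python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames="n")
def repeat_sum(x, n):
    # n is static, so this Python loop is unrolled at trace time.
    total = x
    for _ in range(n - 1):
        total = total + x
    return total

@jax.jit
def first_nonzero_indices(x):
    # size makes the output shape static.
    (idx,) = jnp.nonzero(x, size=3)
    return idx

r = repeat_sum(jnp.array(2.0), n=3)
idx = first_nonzero_indices(jnp.array([0, 5, 0, 7]))
print(r, idx)  # 6.0 [1 3 0]
```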
Breaking changes:
The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed, so please change your code.
host_id –> process_index()
host_count –> process_count()
host_ids –> range(jax.process_count())
Similarly, the argument to local_devices() has been renamed from host_id to process_index.
Arguments to jax.jit() other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to jit.
Bug fixes:
jaxlib 0.1.65 (April 7 2021)
jax 0.2.12 (April 1 2021)
New features
New profiling APIs: jax.profiler.start_trace(), jax.profiler.stop_trace(), and jax.profiler.trace()
jax.lax.reduce() is now differentiable.
Breaking changes:
The minimum jaxlib version is now 0.1.64.
Some profiler API names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed, so please change your code.
TraceContext –> TraceAnnotation()
StepTraceContext –> StepTraceAnnotation()
trace_function –> annotate_function()
Omnistaging can no longer be disabled. See omnistaging for more information.
Python integers larger than the maximum int64 value will now lead to an overflow in all cases, rather than being silently converted to uint64 in some cases (#6047).
Outside X64 mode, Python integers outside the range representable by int32 will now lead to an OverflowError rather than having their value silently truncated.
Bug fixes:
host_callback now supports empty arrays in arguments and results (#6262).
jax.random.randint() clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype (#5868).
jax 0.2.11 (March 23 2021)
New features:
Bug fixes:
#6136 generalized jax.flatten_util.ravel_pytree to handle integer dtypes.
#6129 fixed a bug with handling some constants like enum.IntEnums.
#6145 fixed batching issues with incomplete beta functions.
#6014 fixed H2D transfers during tracing.
#6165 avoids OverflowErrors when converting some large Python integers to floats.
Breaking changes:
The minimum jaxlib version is now 0.1.62.
jaxlib 0.1.64 (March 18 2021)
jaxlib 0.1.63 (March 17 2021)
jax 0.2.10 (March 5 2021)
New features:
jax.scipy.stats.chi2() is now available as a distribution with logpdf and pdf methods.
jax.scipy.stats.betabinom() is now available as a distribution with logpmf and pmf methods.
Added jax.experimental.jax2tf.call_tf() to call TensorFlow functions from JAX (#5627 and README).
Extended the batching rule for lax.pad to support batching of the padding values.
Bug fixes:
jax.numpy.take() properly handles negative indices (#5768).
Breaking changes:
JAX’s promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example, the expression jnp.bfloat16(1) + 0.1 * jnp.arange(10) previously returned a float64 array, and now returns a bfloat16 array. JAX’s type promotion behavior is described at Type promotion semantics.
jax.numpy.linspace() now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
jax.numpy.i0() no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
Several jax.numpy functions no longer accept tuples or lists in place of array arguments: jax.numpy.pad(), jax.numpy.ravel(), jax.numpy.repeat(), jax.numpy.reshape(). In general, jax.numpy functions should be used with scalars or array arguments.
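The bfloat16 example from the promotion entry above can be checked directly. This assumes current promotion rules, under which the Python float 0.1 is weakly typed and therefore does not widen the bfloat16 operand:

```python
import jax.numpy as jnp

r = jnp.bfloat16(1) + 0.1 * jnp.arange(10)
print(r.dtype)  # bfloat16
```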
jaxlib 0.1.62 (March 9 2021)
New features:
jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn’t support AVX, you can build a jaxlib from source using the --target_cpu_features flag to build.py. --target_cpu_features also replaces --enable_march_native.
jaxlib 0.1.61 (February 12 2021)
jaxlib 0.1.60 (February 3 2021)
Bug fixes:
Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
bool, int8, and uint8 are now considered safe to cast to the bfloat16 NumPy extension type.
jax 0.2.9 (January 26 2021)
New features:
Extend the jax.experimental.loops module with support for pytrees. Improved error checking and error messages.
Add jax.experimental.enable_x64() and jax.experimental.disable_x64(). These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
Breaking changes:
jax.ops.segment_sum() now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
jaxlib 0.1.59 (January 15 2021)
jax 0.2.8 (January 12 2021)
New features:
Add jax.closure_convert() for use with higher-order custom derivative functions (#5244).
Add jax.experimental.host_callback.call() to call a custom Python function on the host and return a result to the device computation (#5243).
Bug fixes:
jax.numpy.arccosh now returns the same branch as numpy.arccosh for complex inputs (#5156).
host_callback.id_tap now works for jax.pmap also. There is an optional parameter for id_tap and id_print to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).
Breaking changes:
jax.numpy.pad now takes keyword arguments. Positional argument constant_values has been removed. In addition, passing unsupported keyword arguments raises an error.
Changes for jax.experimental.host_callback.id_tap() (#5243):
Removed support for kwargs for jax.experimental.host_callback.id_tap(). (This support has been deprecated for a few months.)
Changed the printing of tuples for jax.experimental.host_callback.id_print() to use '(' instead of '['.
Changed jax.experimental.host_callback.id_print() so that, in the presence of JVP, it prints a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
host_callback.outfeed_receiver has been removed (it is not necessary, and was deprecated a few months ago).
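A sketch of the keyword form for jax.numpy.pad, assuming a current JAX release:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
# constant_values must be passed by keyword.
padded = jnp.pad(x, 1, constant_values=9)
print(padded)  # [9 1 2 3 9]
```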
New features:
New flag for debugging inf, analogous to that for NaN (#5224).
jax 0.2.7 (Dec 4 2020)
New features:
Add jax.device_put_replicated
Add multi-host support to jax.experimental.sharded_jit
Add support for differentiating eigenvalues computed by jax.numpy.linalg.eig
Add support for building on Windows platforms
Add support for general in_axes and out_axes in jax.pmap
Add complex support for jax.numpy.linalg.slogdet
Bug fixes:
Fix higher-than-second order derivatives of jax.numpy.sinc at zero
Fix some hard-to-hit bugs around symbolic zeros in transpose rules
Breaking changes:
jax.experimental.optix has been deleted, in favor of the standalone optax Python package.
Indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See #4564.
jax 0.2.6 (Nov 18 2020)
New Features:
Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
Breaking change cleanup
Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
Improve consistency of type promotion behavior (#4744):
Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.
Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example, jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.
The contents of the (undocumented)
jax.lax_linalg
linear algebra module are now exposed publicly asjax.lax.linalg
.jax.random.PRNGKey
now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:With
jax_enable_x64=False
, negative seeds passed as Python integers now return a different result outside JIT mode. For example,jax.random.PRNGKey(-1)
previously returned[4294967295, 4294967295]
, and now returns[0, 4294967295]
. This matches the behavior in JIT.Seeds outside the range representable by
int64
outside JIT now result in anOverflowError
rather than aTypeError
. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with
jax_enable_x64=False
outside JIT, you can use:key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray now raises
RuntimeError
instead ofValueError
when trying to access its value while it has been deleted.
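The two type promotion examples from the 0.2.6 notes can be checked directly. This is a minimal sketch that assumes a current JAX install with the default jax_enable_x64=False. The Python complex scalar is weakly typed, so it adopts the precision of the float32 operand.

```python
import jax.numpy as jnp
import numpy as np

# Weakly typed Python complex + strongly typed float32 -> complex64.
z = jnp.float32(1) + 1j

# The result no longer depends on argument order.
r1 = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
r2 = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
print(z.dtype, r1, r2)
```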
jaxlib 0.1.58 (January 12ish 2021)
Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32). (#4903)
Fixed a crash when constant-folding certain int16 operations. (#4971)
Added an is_leaf predicate to pytree.flatten().
jaxlib 0.1.57 (November 12 2020)
Fixed manylinux2010 compliance issues in GPU wheels.
Switched the CPU FFT implementation from Eigen to PocketFFT.
Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
Add support for retaining ownership when passing arrays to DLPack (#4636).
Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
Fixed a bug in the profiler where tools were missing (#4427).
Dropped support for CUDA 10.0.
jax 0.2.5 (October 27 2020)
Improvements:
Ensure that check_jaxpr does not perform FLOPS. See #4650.
Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
jax 0.2.4 (October 19 2020)
jaxlib 0.1.56 (October 14, 2020)
jax 0.2.3 (October 14 2020)
The reason for another release so soon is that we need to temporarily roll back a new jit fastpath while we look into a performance degradation.
jax 0.2.2 (October 13 2020)
jax 0.2.1 (October 6 2020)
Improvements:
As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the jax.experimental.host_callback.id_print() / jax.experimental.host_callback.id_tap() is not used in the computation.
jax 0.2.0 (September 23 2020)
Improvements:
Omnistaging is on by default. See #3370 and omnistaging.
jax 0.1.77 (September 15 2020)
Breaking changes:
New simplified interface for jax.experimental.host_callback.id_tap() (#4101)
jaxlib 0.1.55 (September 8, 2020)
Update XLA:
Fix bug in DLPackManagedTensorToBuffer (#4196)
jax 0.1.76 (September 8, 2020)
jax 0.1.75 (July 30, 2020)
Bug Fixes:
Make jnp.abs() work for unsigned inputs (#3914)
Improvements:
“Omnistaging” behavior added behind a flag, disabled by default (#3370)
jax 0.1.74 (July 29, 2020)
New Features:
BFGS (#3101)
TPU support for half-precision arithmetic (#3878)
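In current JAX, BFGS is exposed through jax.scipy.optimize.minimize with method="BFGS". This sketch assumes that is the feature the entry above refers to. The quadratic loss is a made-up example with its minimum at (3, -1).

```python
import jax.numpy as jnp
import numpy as np
from jax.scipy.optimize import minimize

# Illustrative convex quadratic; the minimum is at (3, -1).
def loss(x):
    return jnp.sum((x - jnp.array([3.0, -1.0])) ** 2)

result = minimize(loss, jnp.zeros(2), method="BFGS")
print(result.x, result.success)
```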
Bug Fixes:
Prevent some accidental dtype warnings (#3874)
Fix a multi-threading bug in custom derivatives (#3845, #3869)
Improvements:
Faster searchsorted implementation (#3873)
Better test coverage for jax.numpy sorting algorithms (#3836)
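The implementation of jnp.searchsorted changed, but its interface follows numpy.searchsorted. The following is a small usage sketch with arbitrary values, showing both values of the side argument.

```python
import jax.numpy as jnp
import numpy as np

sorted_vals = jnp.array([1, 3, 5, 7])
queries = jnp.array([0, 4, 7, 8])

# side="left" (default): first i with sorted_vals[i] >= q
left = jnp.searchsorted(sorted_vals, queries)
# side="right": first i with sorted_vals[i] > q
right = jnp.searchsorted(sorted_vals, queries, side="right")
print(left, right)
```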
jaxlib 0.1.52 (July 22, 2020)
Update XLA.
jax 0.1.73 (July 22, 2020)
The minimum jaxlib version is now 0.1.51.
New Features:
jax.image.resize. (#3703)
hfft and ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
lax.scan and the scan primitive support an unroll parameter for loop unrolling when lowering to XLA (#3738).
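The following sketch exercises two of the new features, assuming a current JAX install. The unroll argument of lax.scan only changes how the loop is lowered to XLA, so the results match the default unroll=1. A length that is not a multiple of unroll is handled. The scan step function is a hypothetical running sum.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

# Running sum as a scan: carry is the total so far.
def step(carry, x):
    carry = carry + x
    return carry, carry

xs = jnp.arange(10.0)
total1, ys1 = lax.scan(step, jnp.float32(0.0), xs)
# Emit 4 iterations per loop body; 10 is not a multiple of 4.
total4, ys4 = lax.scan(step, jnp.float32(0.0), xs, unroll=4)

common = jnp.intersect1d(jnp.array([1, 2, 3, 4]), jnp.array([3, 4, 5]))
print(ys4, common)
```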
Bug Fixes:
Fix reduction repeated axis error (#3618)
Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
Make psum transpose handle zero cotangents (#3653)
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
Support differentiation through jax.lax.all_to_all (#3733)
Address NaN issue in jax.scipy.special.zeta (#3777)
Improvements:
Many improvements to jax2tf
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
Enable XLA SPMD partitioning by default. (#3151)
Add support for 0d transpose convolution (#3643)
Make LU gradient work for low-rank matrices (#3610)
Support multiple_results and custom JVPs in jet (#3657)
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
Implement complex convolutions on CPU and GPU. (#3735)
Make jnp.take work for empty slices of empty arrays. (#3751)
Relax dimension ordering rules for dot_general. (#3778)
Enable buffer donation for GPU. (#3800)
Add support for base dilation and window dilation to reduce window op… (#3803)
jaxlib 0.1.51 (July 2, 2020)
Update XLA.
Add new runtime support for host_callback.
jax 0.1.72 (June 28, 2020)
Bug fixes:
Fix an odeint bug introduced in the previous release, see #3587.
jax 0.1.71 (June 25, 2020)
The minimum jaxlib version is now 0.1.48.
Bug fixes:
Allow jax.experimental.ode.odeint dynamics functions to close over values with respect to which we’re differentiating #3562.
jaxlib 0.1.50 (June 25, 2020)
Add support for CUDA 11.0.
Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
Update XLA.
jaxlib 0.1.49 (June 19, 2020)
Bug fixes:
Fix build issue that could result in slow compiles (https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1)
jaxlib 0.1.48 (June 12, 2020)
New features:
Adds support for fast traceback collection.
Adds preliminary support for on-device heap profiling.
Implements np.nextafter for bfloat16 types.
Complex128 support for FFTs on CPU and GPU.
Bugfixes:
Improved float64 tanh accuracy on GPU.
float64 scatters on GPU are much faster.
Complex matrix multiplication on CPU should be much faster.
Stable sorts on CPU should actually be stable now.
Concurrency bug fix in CPU backend.
jax 0.1.70 (June 8, 2020)
New features:
lax.switch introduces indexed conditionals with multiple branches, together with a generalization of the cond primitive #3318.
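The following is a minimal lax.switch sketch, assuming a current JAX install where the call form is lax.switch(index, branches, *operands). The three branch functions are arbitrary examples. Out-of-range indices are clamped to the valid range.

```python
from jax import lax

branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

a = lax.switch(1, branches, 3.0)  # applies branches[1]
b = lax.switch(5, branches, 3.0)  # index clamped to the last branch
print(a, b)
```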
jax 0.1.69 (June 3, 2020)
jax 0.1.68 (May 21, 2020)
New features:
lax.cond() supports a single-operand form, taken as the argument to both branches #2993.
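The following sketch shows the single-operand form, assuming current JAX where the signature is lax.cond(pred, true_fun, false_fun, *operands). The safe_reciprocal helper is hypothetical. The same operand is passed to whichever branch is selected, and both branches must return the same type.

```python
import jax
from jax import lax

@jax.jit
def safe_reciprocal(x):
    # x is handed to the selected branch as its only argument.
    return lax.cond(x != 0.0, lambda v: 1.0 / v, lambda v: 0.0 * v, x)

a = safe_reciprocal(4.0)
b = safe_reciprocal(0.0)
print(a, b)
```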
Notable changes:
The format of the transforms keyword for the jax.experimental.host_callback.id_tap() primitive has changed #3132.
jax 0.1.67 (May 12, 2020)
New features:
Support for reduction over subsets of a pmapped axis using axis_index_groups #2382.
Experimental support for printing and calling host-side Python functions from compiled code. See id_print and id_tap (#3006).
Notable changes:
The visibility of names exported from jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.
jaxlib 0.1.47 (May 8, 2020)
Fixes crash for outfeed.
jaxlib 0.1.46 (May 5, 2020)
Fixes crash for linear algebra functions on Mac OS X (#432).
Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
jax 0.1.65 (April 30, 2020)
New features:
Differentiation of determinants of singular matrices #2809.
Bug fixes:
jaxlib 0.1.45 (April 21, 2020)
Fixes segfault: #2755
Plumb is_stable option on Sort HLO through to Python.
jax 0.1.64 (April 21, 2020)
New features:
Add syntactic sugar for functional indexed updates #2684.
Add jax.numpy.unique() #2760.
Add jax.numpy.rint() #2724.
Add more primitive rules for jax.experimental.jet().
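In current JAX, functional indexed updates are spelled x.at[idx].set(v) and x.at[idx].add(v). The exact spelling introduced in #2684 may have differed, so treat this as a sketch against today's API. The sketch also shows jnp.rint, which follows NumPy in rounding ties to the nearest even integer.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.zeros(5)
# JAX arrays are immutable: .at[...] returns an updated copy.
y = x.at[1].set(10.0)
z = y.at[1:3].add(1.0)

# Ties round to even: 0.5 -> 0, 1.5 -> 2, 2.5 -> 2, -0.5 -> -0.
r = jnp.rint(jnp.array([0.5, 1.5, 2.5, -0.5]))
print(z, r)
```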
Bug fixes:
Better errors:
Improves error message for reverse-mode differentiation of lax.while_loop() #2129.
jaxlib 0.1.44 (April 16, 2020)
Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
Bugfix for batch_group_count convolutions.
Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
jax 0.1.63 (April 12, 2020)
Added jax.custom_jvp and jax.custom_vjp from #2026, see the tutorial notebook. Deprecated jax.custom_transforms and removed it from the docs (though it still works).
Add scipy.sparse.linalg.cg #2566.
Changed how Tracers are printed to show more useful information for debugging #2591.
Made jax.numpy.isclose handle nan and inf correctly #2501.
Added several new rules for jax.experimental.jet #2537.
Fixed jax.experimental.stax.BatchNorm when scale/center isn’t provided.
Fix some missing cases of broadcasting in jax.numpy.einsum #2512.
Implement jax.numpy.cumsum and jax.numpy.cumprod in terms of a parallel prefix scan #2596 and make reduce_prod differentiable to arbitrary order #2597.
Add batch_group_count to conv_general_dilated #2635.
Add docstring for test_util.check_grads #2656.
Add callback_transform #2665.
Implement rollaxis, convolve/correlate 1d & 2d, copysign, trunc, roots, and quantile/percentile interpolation options.
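The following is a minimal jax.custom_jvp sketch, modeled on the log(1 + exp(x)) example used in the JAX custom-derivative documentation. Differentiating this function naively gives exp(x) / (1 + exp(x)), which becomes inf / inf = nan for large x in float32. The custom rule below uses the numerically stable form 1 - 1 / (1 + exp(x)) instead.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x); stays finite as e^x -> inf.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

g_large = jax.grad(log1pexp)(100.0)  # no nan
g_zero = jax.grad(log1pexp)(0.0)
print(g_large, g_zero)
```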
jaxlib 0.1.43 (March 31, 2020)
Fixed a performance regression for Resnet-50 on GPU.
jax 0.1.62 (March 21, 2020)
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0. (rather than nan). This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
Added an all_gather parallel convenience function.
More type annotations in core code.
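This sketch of lax.all_gather uses jax.vmap with a named axis rather than pmap, so it runs on a single device. Here, vmap stands in for a multi-device mapping as an illustration. Each mapped instance contributes one value, and all_gather returns the full set of values to every instance.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def gather_all(x):
    # Collect x from every instance along the named axis "i".
    return lax.all_gather(x, axis_name="i")

out = jax.vmap(gather_all, axis_name="i")(jnp.arange(3.0))
print(out)  # each row holds all three inputs
```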
jaxlib 0.1.42 (March 19, 2020)
jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)
Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
jax 0.1.60 (March 17, 2020)
New features:
jax.pmap() has a static_broadcasted_argnums argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to static_argnums in jax.jit().
Improved error messages for when tracers are mistakenly saved in global state.
Added jax.nn.one_hot() utility function.
Added jax.experimental.jet for exponentially faster higher-order automatic differentiation.
Added more correctness checking to arguments of jax.lax.broadcast_in_dim().
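The following sketch shows pmap's static-argument feature, whose keyword is spelled static_broadcasted_argnums in jax.pmap. It assumes a current JAX install and sizes the mapped axis with jax.local_device_count(), so it also runs on a single CPU device. The power argument is a hypothetical example of a hashable compile-time constant shared by every device.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()  # 1 on a typical CPU-only machine
x = jnp.arange(n * 3.0).reshape(n, 3)

def raise_to(v, power):
    # power is static: baked into the compiled program on every device.
    return v ** power

f = jax.pmap(raise_to, static_broadcasted_argnums=(1,))
out = f(x, 2)
print(out)
```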
The minimum jaxlib version is now 0.1.41.
jaxlib 0.1.40 (March 4, 2020)
Adds experimental support in jaxlib for the TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
Includes prototype support for multihost GPU computations that communicate via NCCL.
Improves performance of NCCL collectives on GPU.
Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)
Breaking changes
The minimum jaxlib version is now 0.1.38.
Simplified Jaxpr by removing the Jaxpr.freevars and Jaxpr.bound_subjaxprs. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully-closed (no constvars) jaxpr. Also, added a new field call_primitive to primitives.
New features:
Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes (#2091).
JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
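The following is a minimal sketch of reverse-mode differentiation through lax.cond, assuming a current JAX install. The piecewise function is a made-up example. jax.grad differentiates only the branch that is actually taken.

```python
import jax
from jax import lax

def piecewise(x):
    # x**2 for positive x, -3x otherwise.
    return lax.cond(x > 0.0, lambda v: v ** 2, lambda v: -3.0 * v, x)

g_pos = jax.grad(piecewise)(2.0)   # d/dx x**2 at x = 2
g_neg = jax.grad(piecewise)(-1.0)  # d/dx (-3x)
print(g_pos, g_neg)
```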
jaxlib 0.1.39 (February 11, 2020)
Updates XLA.
jaxlib 0.1.38 (January 29, 2020)
CUDA 9.0 is no longer supported.
CUDA 10.2 wheels are now built by default.
jax 0.1.58 (January 28, 2020)
Breaking changes
JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
New features
Forward-mode automatic differentiation (jvp) of while loop (#1980)
New NumPy and SciPy functions:
Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
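The following sketch shows forward-mode differentiation through lax.while_loop, assuming a current JAX install. The loop is a hypothetical example that multiplies an accumulator by x three times. jax.jvp therefore returns x**3 together with the directional derivative 3 * x**2 * t. Reverse mode (grad) through while_loop is not supported, which is why jvp is used here.

```python
import jax
from jax import lax

def cube_by_loop(x):
    def cond_fun(state):
        i, _ = state
        return i < 3

    def body_fun(state):
        i, acc = state
        return i + 1, acc * x  # x is closed over from the outer scope

    _, result = lax.while_loop(cond_fun, body_fun, (0, 1.0))
    return result

primal, tangent = jax.jvp(cube_by_loop, (2.0,), (1.0,))
print(primal, tangent)
```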
Notable bug fixes
With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.