Rank promotion warning

NumPy broadcasting rules allow automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended but can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.

Here’s an example of rank promotion:

>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0,  2,  2],
       [ 3,  5,  5],
       [ 6,  8,  8],
       [ 9, 11, 11]])
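
The sketch below, which uses only plain NumPy, makes the promotion explicit. The one-dimensional y is treated as if it had shape (1, 3) before broadcasting against x. The second half shows how the same rule can silently turn a shape mistake into a differently shaped result. The names col, vec, and diff are illustrative only.

```python
import numpy as np

x = np.arange(12).reshape(4, 3)  # shape (4, 3), rank 2
y = np.array([0, 1, 0])          # shape (3,),   rank 1

# Broadcasting promotes y to rank 2 by prepending a length-1 axis,
# so x + y behaves like x + y[np.newaxis, :] with y viewed as shape (1, 3).
implicit = x + y
explicit = x + y[np.newaxis, :]
print(np.array_equal(implicit, explicit))     # True
print(np.broadcast_shapes(x.shape, y.shape))  # (4, 3)

# The same rule can hide a bug: subtracting a rank-1 array from a column
# vector of the same length does not give an elementwise difference.
col = np.ones((3, 1))  # shape (3, 1)
vec = np.ones(3)       # shape (3,), promoted to (1, 3)
diff = col - vec       # broadcasts to (3, 3) without any error
print(diff.shape)      # (3, 3)
```

Nothing in the second computation fails. The only symptom is the unexpected (3, 3) shape, which may not surface until much later in a program.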

To avoid potential surprises, jax.numpy is configurable so that expressions requiring rank promotion can lead to a warning or an error, or can be allowed just like in regular NumPy. The configuration option is named jax_numpy_rank_promotion, and it can take the string values allow, warn, and raise. The default setting is allow, which allows rank promotion without warning or error. The warn setting raises a warning on the first occurrence of rank promotion, and the raise setting raises an error on rank promotion.
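
To illustrate what the three settings mean, here is a small pure-Python sketch of the decision they control. The helper check_rank_promotion is hypothetical and is not part of JAX or JAX's implementation. It also assumes that rank-0 scalars are ignored, so that combining an array with a scalar is not counted as rank promotion.

```python
import warnings

ALLOWED_MODES = ("allow", "warn", "raise")


def check_rank_promotion(*shapes, mode="allow"):
    """Hypothetical helper, not part of JAX.

    Returns True if operands with the given shapes would need rank
    promotion, and warns or raises according to `mode`.
    """
    if mode not in ALLOWED_MODES:
        raise ValueError(f"mode must be one of {ALLOWED_MODES}, got {mode!r}")
    # Assumption: rank-0 scalars are ignored, so array + scalar never counts.
    ranks = {len(s) for s in shapes if len(s) > 0}
    promotes = len(ranks) > 1
    if promotes and mode != "allow":
        msg = f"operands with shapes {shapes} require rank promotion"
        if mode == "raise":
            raise ValueError(msg)
        # This sketch warns on every call; Python's default warning filter
        # typically shows a given warning only once per call site.
        warnings.warn(msg)
    return promotes


print(check_rank_promotion((4, 3), (3,), mode="allow"))    # True, silently
print(check_rank_promotion((4, 3), (4, 3), mode="raise"))  # False, same rank
```

Under raise, the call with shapes (4, 3) and (3,) would fail instead of returning. The usual way to satisfy that setting is to make the promotion explicit, for example by indexing with y[None, :] so that both operands already have the same rank.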

As with most other JAX configuration options, you can set this option in several ways. One is by using jax.config in your code:

import jax

# Allow rank promotion without any warning or error.
jax.config.update("jax_numpy_rank_promotion", "allow")

You can also set the option with the environment variable JAX_NUMPY_RANK_PROMOTION, for example JAX_NUMPY_RANK_PROMOTION='raise'. Finally, when using absl-py, the option can be set with a command-line flag.
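
If you cannot change the shell environment, a minimal sketch is to set the variable from Python itself. This assumes that JAX reads JAX_NUMPY_RANK_PROMOTION when its configuration is initialized at import time, so the assignment must run before jax is first imported in the process. A later change to os.environ will not affect an already imported jax.

```python
import os

# Must run before the first `import jax` in this process (assumption: JAX
# reads the environment variable when its configuration is initialized).
os.environ["JAX_NUMPY_RANK_PROMOTION"] = "raise"

try:
    import jax  # picks up the environment variable during import
except ImportError:
    jax = None  # JAX not installed; the variable is still inherited by child processes
```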