jax.numpy.allclose

jax.numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)

Check if two arrays are element-wise approximately equal within a tolerance.

JAX implementation of numpy.allclose().
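As a quick sanity check, the sketch below (assuming both NumPy and JAX are installed) runs numpy.allclose() and jnp.allclose on the same inputs. The answers agree, but the return types differ: NumPy returns a Python bool, while JAX returns a 0-d boolean Array. The input values are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

a = [1.0, 2.0, 3.0]
b = [1.0, 2.0, 3.0 + 1e-6]  # last entry differs by far less than the default tolerance

np_result = np.allclose(a, b)  # NumPy accepts lists and returns a Python bool
# jax.numpy functions generally reject Python lists, so convert to arrays first.
jnp_result = jnp.allclose(jnp.asarray(a), jnp.asarray(b))

print(type(np_result))   # <class 'bool'>
print(repr(jnp_result))  # Array(True, dtype=bool)
```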

This function checks the following condition element-wise, after broadcasting a and b against each other, and returns True only if it holds for every element:

\[|a - b| \le \mathtt{atol} + \mathtt{rtol} * |b|\]
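To make the condition concrete, the following sketch evaluates it directly with jnp.abs and jnp.all and compares the result with jnp.allclose, using the same inputs as one of the examples on this page. allclose_formula is a hypothetical helper written for illustration; it omits the special handling of infinities and NaNs, so it should only be expected to match jnp.allclose for finite inputs.

```python
import jax.numpy as jnp

def allclose_formula(a, b, rtol=1e-05, atol=1e-08):
    """Hypothetical helper: checks |a - b| <= atol + rtol * |b| for every element.

    Unlike jnp.allclose, it does not special-case inf or NaN.
    """
    a, b = jnp.asarray(a), jnp.asarray(b)
    return jnp.all(jnp.abs(a - b) <= atol + rtol * jnp.abs(b))

a = jnp.array([1e6, 2e6, 3e6])
b = jnp.array([1.00001e6, 2.00002e6, 3.00009e6])

# Differences are 10, 20 and 90; with the default rtol the last allowance is only about 30.
print(bool(allclose_formula(a, b)), bool(jnp.allclose(a, b)))  # False False
# atol=1e3 adds 1000 to every allowance, so all three differences fit.
print(bool(allclose_formula(a, b, atol=1e3)), bool(jnp.allclose(a, b, atol=1e3)))  # True True
```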

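An infinity in a is considered equal to an infinity of the same sign in b (jnp.inf matches jnp.inf, and -jnp.inf matches -jnp.inf). Any other comparison involving an infinity, such as jnp.inf against -jnp.inf or against a finite value, is not close.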
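The sketch below checks this infinity handling on a few inputs; the variable names are illustrative only. The raw formula alone would not give these answers: inf - inf is NaN, and for a = inf, b = -inf the inequality |a - b| <= atol + rtol * |b| reads inf <= inf, which is True. That is why infinities have to be treated separately.

```python
import jax.numpy as jnp

inf = jnp.inf

# Infinities of the same sign at the same position are considered close.
same_sign = jnp.allclose(jnp.array([inf, -inf, 1.0]), jnp.array([inf, -inf, 1.0]))
# Opposite-sign infinities are not close.
opposite_sign = jnp.allclose(jnp.array([inf]), jnp.array([-inf]))
# An infinity is not close to a finite value.
inf_vs_finite = jnp.allclose(jnp.array([inf]), jnp.array([1e30]))

print(bool(same_sign), bool(opposite_sign), bool(inf_vs_finite))  # True False False
```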

Parameters:
  • a (ArrayLike) – first input array to compare.

  • b (ArrayLike) – second input array to compare.

  • rtol (ArrayLike) – relative tolerance used for approximate equality. Default = 1e-05.

  • atol (ArrayLike) – absolute tolerance used for approximate equality. Default = 1e-08.

  • equal_nan (bool) – If True, NaNs in a are considered equal to NaNs in b. Default = False.
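The next sketch shows how equal_nan and the two tolerances change the result. Setting rtol and atol both to zero is just one illustrative configuration: for finite inputs the condition then reduces to |a - b| <= 0, i.e. exact equality.

```python
import jax.numpy as jnp

x = jnp.array([jnp.nan, 1.0, 2.0])

# By default a NaN is not close to anything, including another NaN.
print(bool(jnp.allclose(x, x)))                  # False
print(bool(jnp.allclose(x, x, equal_nan=True)))  # True

# With rtol=0 and atol=0 only exact matches pass.
y = jnp.array([1.0, 2.0])
print(bool(jnp.allclose(y, y, rtol=0.0, atol=0.0)))         # True
print(bool(jnp.allclose(y, y + 1e-3, rtol=0.0, atol=0.0)))  # False
```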

Returns:

Boolean scalar array indicating whether the input arrays are element-wise approximately equal within the specified tolerances.

Return type:

Array
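Because the result is a 0-d Array rather than a Python bool, how it is consumed depends on context. Outside of jax.jit a concrete result can be used directly in an if statement; inside a jitted function it is a tracer, so the sketch below uses jnp.where to select a value instead of branching in Python. select_if_close is a hypothetical helper for illustration, not part of the JAX API.

```python
import jax
import jax.numpy as jnp

result = jnp.allclose(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0]))
print(result.shape, result.dtype)  # () bool
if result:  # a concrete 0-d Array converts to a Python bool
    print("close")

@jax.jit
def select_if_close(a, b, fallback):
    # Hypothetical helper. Under jit, jnp.allclose returns a traced 0-d bool,
    # so a Python `if` on it would raise; jnp.where picks a value without branching.
    return jnp.where(jnp.allclose(a, b), b, fallback)

kept = select_if_close(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0]), jnp.zeros(2))
replaced = select_if_close(jnp.array([1.0, 2.0]), jnp.array([1.0, 3.0]), jnp.zeros(2))
print(kept, replaced)  # [1. 2.] [0. 0.]
```

jax.lax.cond is an alternative when the two branches need to run different computations rather than just pick between precomputed values.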

Examples

>>> import jax.numpy as jnp
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]), jnp.array([1e6, 2e6, 3e7]))
Array(False, dtype=bool)
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00008e6, 2.00008e7, 3.00008e8]), rtol=1e3)
Array(True, dtype=bool)
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00001e6, 2.00002e6, 3.00009e6]), atol=1e3)
Array(True, dtype=bool)
>>> jnp.allclose(jnp.array([jnp.nan, 1, 2]),
...              jnp.array([jnp.nan, 1, 2]), equal_nan=True)
Array(True, dtype=bool)
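Because the condition scales rtol by |b| only, jnp.allclose is not symmetric in its arguments: swapping a and b can change the answer when the difference sits near the tolerance. The values below are chosen purely to illustrate this.

```python
import jax.numpy as jnp

a = jnp.array([100.0])
b = jnp.array([111.0])

# |a - b| = 11, and the allowance is rtol * |b| = 11.1 in this order.
ab = jnp.allclose(a, b, rtol=0.1, atol=0.0)
# Swapped, the allowance becomes rtol * |a| = 10, which is smaller than 11.
ba = jnp.allclose(b, a, rtol=0.1, atol=0.0)

print(bool(ab), bool(ba))  # True False
```

If a symmetric check is needed, one option is to evaluate both orders and combine them with jnp.logical_and.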