jax.numpy.broadcast_arrays

jax.numpy.broadcast_arrays(*args)

Broadcast arrays to a common shape.

JAX implementation of numpy.broadcast_arrays(). JAX uses NumPy-style broadcasting rules, which are described in the NumPy broadcasting documentation.
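In brief, the NumPy-style rule aligns all shapes at their trailing dimension. In each position, the sizes must either match or be 1, and a size-1 dimension is stretched to the other size. The following is a minimal pure-Python sketch of that shape rule. The helper `broadcast_shape` is hypothetical and written only for illustration; it is not part of JAX, whose own function for this calculation is `jnp.broadcast_shapes`.

```python
def broadcast_shape(*shapes: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: common shape under NumPy-style broadcasting."""
    if not shapes:
        return ()
    ndim = max(len(s) for s in shapes)
    # Left-pad each shape with 1s so that all shapes are aligned
    # at their trailing dimensions and have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        # Size-1 dimensions stretch; all other sizes must agree.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


print(broadcast_shape((3,), ()))               # (3,)
print(broadcast_shape((1, 3), (2, 1)))         # (2, 3)
print(broadcast_shape((2, 1, 3), (4, 1), ()))  # (2, 4, 3)
```

The first two calls correspond to the shapes used in the examples further down this page.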

Parameters:

args (ArrayLike) – zero or more array-like objects to be broadcast.

Returns:

a list containing one array per input, each broadcast to the common shape.

Return type:

list[Array]
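The following sketch shows the return value for three inputs of different ranks. The output is a Python list with one array per input, and every output has the common shape computed by `jnp.broadcast_shapes`. Broadcasting changes only the shapes, so each output keeps the dtype of its input. The input arrays are arbitrary choices made for illustration.

```python
import jax.numpy as jnp

x = jnp.ones((2, 1, 3))          # shape (2, 1, 3)
y = jnp.arange(4).reshape(4, 1)  # shape (4, 1)
z = jnp.float32(5.0)             # 0-d array, shape ()

outs = jnp.broadcast_arrays(x, y, z)

# One output per input, all with the common broadcast shape.
print(len(outs))                # 3
print([o.shape for o in outs])  # [(2, 4, 3), (2, 4, 3), (2, 4, 3)]

# The common shape matches jnp.broadcast_shapes applied to the input shapes.
print(jnp.broadcast_shapes(x.shape, y.shape, z.shape))  # (2, 4, 3)

# Shapes change, dtypes do not: each output keeps its input's dtype.
print([o.dtype == i.dtype for o, i in zip(outs, (x, y, z))])  # [True, True, True]
```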


Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(3)
>>> y = jnp.int32(1)
>>> jnp.broadcast_arrays(x, y)
[Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]])
>>> y = jnp.array([[10],
...                [20]])
>>> x2, y2 = jnp.broadcast_arrays(x, y)
>>> x2
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> y2
Array([[10, 10, 10],
       [20, 20, 20]], dtype=int32)
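Inputs whose shapes violate the broadcasting rules are rejected. In the sketch below, the trailing dimensions are 3 and 4, and neither of them is 1, so no common shape exists. This example assumes that current JAX releases raise the error as a `ValueError`, as they do when shapes are incompatible for broadcasting.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3))
b = jnp.ones((4,))

# Trailing dimensions 3 and 4 differ and neither is 1, so broadcasting fails.
try:
    jnp.broadcast_arrays(a, b)
    failed = False
except ValueError as err:
    failed = True
    print(f"ValueError: {err}")

print(failed)  # True
```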