jax.numpy.broadcast_shapes

jax.numpy.broadcast_shapes(*shapes)

Broadcast input shapes to a common output shape.

JAX implementation of numpy.broadcast_shapes(). JAX follows NumPy-style broadcasting rules, which are described in the NumPy broadcasting documentation.
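
As a rough illustration of the NumPy-style rule, shapes are aligned at their trailing dimensions, missing leading dimensions are treated as size 1, and in each position all sizes other than 1 must agree. The pure-Python sketch below uses a hypothetical helper, broadcast_shapes_sketch, written only for this illustration. It is not how JAX implements the function, and its error message simply mirrors the wording used in the examples on this page.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Hypothetical illustration of NumPy-style shape broadcasting."""
    result = []
    # Compare dimensions starting from the trailing axis; shapes that run
    # out of dimensions are padded with size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Sizes equal to 1 stretch to match, so only the others must agree.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(
                f"Incompatible shapes for broadcasting: shapes={list(shapes)}"
            )
        result.append(sizes.pop() if sizes else 1)
    # Dimensions were collected back to front, so restore the usual order.
    return tuple(reversed(result))
```

Under these assumptions, broadcast_shapes_sketch((3, 1), (1, 4), (5, 1, 1)) walks the axes from right to left and combines 1, 4, 1 into 4, then 3, 1, 1 into 3, then the lone 5 with the padded 1s into 5.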

Parameters:

shapes – zero or more shapes, each specified as a sequence of integers

Returns:

The broadcasted shape as a tuple of integers.

Examples

Some compatible shapes:

>>> import jax.numpy as jnp
>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)
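
The returned shape should match the shape produced when arrays of those shapes are combined elementwise. A minimal check, assuming JAX is installed (the variable names here are illustrative only):

```python
import jax.numpy as jnp

x = jnp.zeros((3, 1))
y = jnp.zeros((4,))

# Shape predicted from the input shapes alone, without combining the arrays.
predicted = jnp.broadcast_shapes(x.shape, y.shape)

# Shape actually produced by a broadcasting elementwise operation.
actual = (x + y).shape

assert predicted == actual
```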

Incompatible shapes:

>>> jnp.broadcast_shapes((3, 1), (4, 1))
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]
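
Because incompatible shapes raise ValueError, one possible way to test compatibility up front is to wrap the call in a try/except. The helper name shapes_are_compatible below is hypothetical and is not part of JAX.

```python
import jax.numpy as jnp


def shapes_are_compatible(*shapes):
    """Hypothetical helper: True if the shapes broadcast together."""
    try:
        jnp.broadcast_shapes(*shapes)
    except ValueError:
        # Raised when some axis has two different sizes that are both not 1.
        return False
    return True
```

This can be used to validate input shapes before running a computation that relies on broadcasting.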