jax.numpy.broadcast_shapes
- jax.numpy.broadcast_shapes(*shapes)
Broadcast input shapes to a common output shape.
JAX implementation of numpy.broadcast_shapes(). JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.
- Parameters:
shapes – zero or more shapes, each specified as a sequence of integers.
- Returns:
The broadcasted shape as a tuple of integers.
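The NumPy-style rule applied here can be sketched in a few lines of plain Python. This is a minimal illustration, assuming the standard NumPy rule: align shapes on the right, treat missing leading dimensions as size 1, and require every size other than 1 along an axis to agree. manual_broadcast_shapes is a hypothetical helper, not part of JAX, and it only handles concrete integer dimensions.

```python
import jax.numpy as jnp

def manual_broadcast_shapes(*shapes):
    """Hypothetical sketch of NumPy broadcasting; not a JAX API."""
    # Normalize every shape to a tuple of Python ints.
    shapes = [tuple(int(d) for d in s) for s in shapes]
    if not shapes:
        return ()
    ndim = max(len(s) for s in shapes)
    # Left-pad shorter shapes with 1s so all shapes align on the right.
    padded = [(1,) * (ndim - len(s)) + s for s in shapes]
    result = []
    for dims in zip(*padded):
        # Along each axis, every size that is not 1 must be the same.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"Incompatible shapes for broadcasting: shapes={shapes}")
        # A size-1 axis stretches to match the other size (or stays 1).
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)

# The sketch agrees with jnp.broadcast_shapes on compatible inputs.
for shapes in [((1,), (4,)), ((3, 1), (4,)), ((3, 1), (1, 4), (5, 1, 1))]:
    assert manual_broadcast_shapes(*shapes) == jnp.broadcast_shapes(*shapes)
```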
See also
jax.numpy.broadcast_arrays(): broadcast arrays to a common shape.
jax.numpy.broadcast_to(): broadcast an array to a specified shape.
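One way these functions relate: jnp.broadcast_shapes works on shapes alone, so it can compute a target shape that is then passed to jnp.broadcast_to, which reproduces what jnp.broadcast_arrays does in a single call. The arrays x and y below are illustrative inputs chosen for this sketch.

```python
import jax.numpy as jnp

x = jnp.arange(3).reshape(3, 1)   # shape (3, 1)
y = jnp.arange(4)                 # shape (4,)

# Compute the common shape from the shapes only; no array data is touched.
target = jnp.broadcast_shapes(x.shape, y.shape)

# Broadcast each array explicitly to that common shape.
x_b = jnp.broadcast_to(x, target)
y_b = jnp.broadcast_to(y, target)

# jnp.broadcast_arrays performs both steps in one call.
x_c, y_c = jnp.broadcast_arrays(x, y)
```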
Examples
Some compatible shapes:
>>> import jax.numpy as jnp
>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)
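Because elementwise operations follow the same broadcasting rules, the returned shape predicts the shape of an elementwise result without running the computation. A minimal sketch, using arrays of ones purely as placeholders; it also shows that a scalar shape () broadcasts against any shape.

```python
import jax.numpy as jnp

a = jnp.ones((3, 1))
b = jnp.ones((1, 4))
c = jnp.ones((5, 1, 1))

# Shape predicted from the input shapes alone.
predicted = jnp.broadcast_shapes(a.shape, b.shape, c.shape)

# Shape actually produced by an elementwise sum of the three arrays.
actual = (a + b + c).shape

# A scalar shape () is compatible with every shape.
scalar_case = jnp.broadcast_shapes((), (2, 3))
```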
Incompatible shapes:
>>> jnp.broadcast_shapes((3, 1), (4, 1))
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]
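Since incompatible shapes raise ValueError, a compatibility check can be written by catching that exception. are_broadcastable is a hypothetical helper for this sketch, not a JAX function.

```python
import jax.numpy as jnp

def are_broadcastable(*shapes):
    """Hypothetical helper: True if the shapes broadcast to a common shape."""
    try:
        jnp.broadcast_shapes(*shapes)
    except ValueError:
        # Raised for incompatible shapes such as (3, 1) and (4, 1).
        return False
    return True
```

For example, are_broadcastable((3, 1), (4,)) is True, while are_broadcastable((3, 1), (4, 1)) is False.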