jax.lax.broadcast_shapes

jax.lax.broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]
jax.lax.broadcast_shapes(*shapes: tuple[int | Tracer, ...]) -> tuple[int | Tracer, ...]

Returns the shape that results from applying NumPy broadcasting rules to the given shapes.
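
As a brief illustration, the sketch below calls the function on a few hand-picked shapes. The shapes and variable names are chosen only for this example, and the printed results assume the usual NumPy rules: dimensions are aligned from the right, and a dimension of size 1 stretches to match the other operand.

```python
import jax

# Same rank: the size-1 leading dimension stretches to 2.
out = jax.lax.broadcast_shapes((2, 3), (1, 3))
print(out)  # (2, 3)

# Different ranks: shorter shapes are treated as if padded with 1s on the left.
out_rank = jax.lax.broadcast_shapes((4, 1), (3,), (2, 1, 1))
print(out_rank)  # (2, 4, 3)

# Incompatible shapes (3 vs. 4 in the last dimension) raise ValueError.
incompatible = False
try:
    jax.lax.broadcast_shapes((2, 3), (4,))
except ValueError:
    incompatible = True
print(incompatible)  # True
```

Because only shapes are passed in, no arrays are allocated, which makes the function convenient for validating operand shapes before running a computation.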