jax.lax.broadcast_in_dim

jax.lax.broadcast_in_dim(operand, shape, broadcast_dimensions)

Wraps XLA’s BroadcastInDim operator.
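One way to see the wrapping in practice is to lower a jitted call and inspect the emitted module text. The following is a minimal sketch, assuming a reasonably recent JAX where `jax.jit(...).lower(...).as_text()` returns the StableHLO text of the computation; the function name `f` is hypothetical. The exact printed text varies between JAX versions, so no expected output is shown. It should, however, contain a `broadcast_in_dim` op.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # Hypothetical helper: broadcast a (3,) vector into a (2, 3) matrix,
    # mapping operand dimension 0 to result dimension 1.
    return lax.broadcast_in_dim(x, (2, 3), (1,))


# Lower (trace and convert, without compiling or running) f for a
# float32 input of shape (3,), then get the module as text.
lowered = jax.jit(f).lower(jnp.zeros((3,), jnp.float32))
hlo_text = lowered.as_text()
print(hlo_text)
```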

Parameters:
  • operand (jax.typing.ArrayLike) – the array to broadcast

  • shape (Sequence[int | Any]) – the shape of the target array

  • broadcast_dimensions (Sequence[int]) – the dimension of the target shape to which each dimension of the operand corresponds. That is, dimension i of the operand becomes dimension broadcast_dimensions[i] of the result. Each operand dimension must either have size 1 or match the size of the target dimension it maps to.
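The mapping described by broadcast_dimensions can be illustrated with a small sketch. It assumes only the three parameters documented above; the variable names are illustrative. The same (3,) operand produces either repeated rows or repeated columns depending on which result dimension its single dimension maps to. A size-1 operand dimension is stretched to the target size.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3])  # operand of shape (3,)

# Operand dimension 0 -> result dimension 1: every row is a copy of x.
rows = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
# rows -> [[1, 2, 3],
#          [1, 2, 3]]

# Operand dimension 0 -> result dimension 0: every column is a copy of x.
cols = lax.broadcast_in_dim(x, shape=(3, 2), broadcast_dimensions=(0,))
# cols -> [[1, 1],
#          [2, 2],
#          [3, 3]]

# A size-1 operand dimension (here dimension 1) is stretched to size 4.
y = jnp.array([[10], [20], [30]])  # shape (3, 1)
stretched = lax.broadcast_in_dim(y, shape=(3, 4), broadcast_dimensions=(0, 1))
# stretched has shape (3, 4); row i is filled with y[i, 0]
```

Unlike NumPy-style implicit broadcasting, which always aligns trailing dimensions, the explicit mapping lets the operand's dimensions land anywhere in the target shape.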

Returns:

An array of the target shape, with the same dtype as operand, containing the broadcast values.

Return type:

Array

See also

jax.lax.broadcast : simpler interface to add new leading dimensions.
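To relate the two functions, the following sketch expresses a jax.lax.broadcast call (which prepends the given sizes as new leading dimensions) as an equivalent broadcast_in_dim call. It assumes the operand's dimensions map onto the trailing dimensions of the target shape. The names sizes and dims are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # operand of shape (2, 3)
sizes = (4,)                     # new leading dimension(s) to add

# lax.broadcast prepends `sizes` to the operand's shape.
a = lax.broadcast(x, sizes)

# Equivalent broadcast_in_dim call: the operand's dimensions become the
# trailing dimensions of the target shape (4, 2, 3).
dims = tuple(range(len(sizes), len(sizes) + x.ndim))  # (1, 2)
b = lax.broadcast_in_dim(x, sizes + x.shape, dims)
```

broadcast_in_dim is the more general operation, since it can also insert new dimensions in the middle or at the end of the result.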