jax.lax.broadcast_in_dim

jax.lax.broadcast_in_dim(operand, shape, broadcast_dimensions)

Wraps XLA’s BroadcastInDim operator.

Parameters:
  • operand (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array

  • shape (Sequence[Union[int, Any]]) – the shape of the target array

  • broadcast_dimensions (Sequence[int]) – the dimension in the target shape to which each dimension of the operand corresponds. That is, dimension i of the operand becomes dimension broadcast_dimensions[i] of the result.

Return type:

Array

Returns:

An array containing the result.
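Example. The sketch below is a minimal illustration of how broadcast_dimensions places the operand's dimensions inside the target shape; the arrays are illustrative. The last call assumes, as in XLA's BroadcastInDim, that an operand dimension of size 1 may be expanded to the corresponding target size.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3])  # shape (3,)

# Operand dimension 0 becomes result dimension 1, so every row is a copy of x.
rows = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
# rows: [[1, 2, 3],
#        [1, 2, 3]]

# Operand dimension 0 becomes result dimension 0, so every column is a copy of x.
cols = lax.broadcast_in_dim(x, shape=(3, 2), broadcast_dimensions=(0,))
# cols: [[1, 1],
#        [2, 2],
#        [3, 3]]

# Assumption: a size-1 operand dimension is stretched to the target size
# (here operand dim 1, of size 1, maps to result dim 1, of size 4).
col_vec = jnp.array([[1], [2], [3]])  # shape (3, 1)
tiled = lax.broadcast_in_dim(col_vec, shape=(3, 4), broadcast_dimensions=(0, 1))
```

The same vector yields different results depending only on which result dimension its single axis is assigned to.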

See also

jax.lax.broadcast : simpler interface to add new leading dimensions.
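As a sketch of how the two functions relate: jax.lax.broadcast prepends the given sizes as new leading dimensions, which corresponds to a broadcast_in_dim call whose broadcast_dimensions map the operand's dimensions onto the trailing dimensions of the result. The operand and sizes below are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # shape (2, 3)

# lax.broadcast adds new leading dimensions of the given sizes.
a = lax.broadcast(x, (4,))  # shape (4, 2, 3)

# Equivalent broadcast_in_dim: operand dims 0 and 1 become result dims 1 and 2,
# and result dim 0 is the new leading dimension.
b = lax.broadcast_in_dim(x, shape=(4, 2, 3), broadcast_dimensions=(1, 2))
```

broadcast_in_dim is the more general of the two: it can also insert new dimensions in the middle or at the end of the result, which lax.broadcast cannot.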