jax.lax.expand_dims



jax.lax.expand_dims(array, dimensions)

Insert any number of size-1 dimensions into an array. The array's values are unchanged; only its shape gains new axes of length 1.
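The following is a minimal usage sketch, assuming a working JAX installation. The variable names are illustrative only. It shows how the shape changes when one or two size-1 axes are inserted, and that the underlying values are untouched.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # shape (2, 3)

# Insert one size-1 axis at the front of the output shape.
y = lax.expand_dims(x, (0,))
print(y.shape)  # (1, 2, 3)

# Insert two size-1 axes; positions refer to the output shape.
z = lax.expand_dims(x, (0, 2))
print(z.shape)  # (1, 2, 1, 3)

# Removing the size-1 axes again recovers the original array.
print(jnp.array_equal(z.reshape(2, 3), x))  # True
```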

Parameters:
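  • array (jax.typing.ArrayLike): the input array.

  • dimensions (Sequence[int]): positions in the output shape at which the new size-1 axes are inserted. Negative values count from the end of the output shape, and each position may appear only once.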

Return type:

Array
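The entries of dimensions index into the output shape, whose rank is the input rank plus len(dimensions), so negative entries are resolved against that output rank. The sketch below illustrates this under the assumption of a recent JAX release; the variable names are hypothetical. It also compares the result with a plain reshape, because inserting size-1 axes does not move any data.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 5))

# The output has rank 3, so -1 resolves to position 2 (a trailing axis).
a = lax.expand_dims(x, (-1,))
print(a.shape)  # (4, 5, 1)

# The output has rank 4, so -1 resolves to position 3.
b = lax.expand_dims(x, (1, -1))
print(b.shape)  # (4, 1, 5, 1)

# Equivalent to reshaping to the expanded shape.
print(jnp.array_equal(b, x.reshape(4, 1, 5, 1)))  # True
```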