jax.numpy.expand_dims

jax.numpy.expand_dims(a, axis)

Expand the shape of an array.

LAX-backend implementation of numpy.expand_dims().

Original docstring below.

Insert a new axis that will appear at the axis position in the expanded array shape.

Parameters:
  • a (array_like) – Input array.

  • axis (int or tuple of ints) – Position in the expanded axes where the new axis (or axes) is placed.

    Deprecated since NumPy version 1.13.0: Passing an axis where axis > a.ndim will be treated as axis == a.ndim, and passing axis < -a.ndim - 1 will be treated as axis == 0. This behavior is deprecated.

    Changed in NumPy version 1.18.0: A tuple of axes is now supported. Out-of-range axes as described above are now forbidden and raise an AxisError.

Returns:
  result – View of a with the number of dimensions increased.

Return type:
  ndarray
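
The following is a minimal usage sketch rather than part of the original docstring. It assumes a JAX version whose jnp.expand_dims accepts a tuple for axis, as the NumPy docstring above describes; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)          # shape (2, 3)

# A single integer inserts one new length-1 axis at that position.
front = jnp.expand_dims(x, axis=0)       # shape (1, 2, 3)
back = jnp.expand_dims(x, axis=-1)       # shape (2, 3, 1)

# A tuple inserts several axes; positions refer to the expanded shape.
both = jnp.expand_dims(x, axis=(0, 2))   # shape (1, 2, 1, 3)

# Indexing with None is an equivalent way to add a leading axis.
same_as_front = x[None, ...]

print(front.shape, back.shape, both.shape, same_as_front.shape)
```

Negative values of axis count from the end of the expanded shape, so axis=-1 appends a trailing axis.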