jax.numpy.expand_dims
- jax.numpy.expand_dims(a, axis)
Expand the shape of an array.
LAX-backend implementation of numpy.expand_dims(). Original docstring below.
Insert a new axis that will appear at the axis position in the expanded array shape.
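The basic behavior is easiest to see on a 1-D input. The following is a minimal sketch; the variable names are illustrative only. It shows that the inserted length-1 axis lands at index axis of the result's shape, with negative values counted from the end of the expanded shape.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])  # shape (3,)

# The new length-1 axis is placed at index `axis` of the output shape.
row = jnp.expand_dims(x, 0)    # shape (1, 3)
col = jnp.expand_dims(x, 1)    # shape (3, 1)
last = jnp.expand_dims(x, -1)  # shape (3, 1): -1 refers to the last axis of the expanded shape

print(row.shape, col.shape, last.shape)
```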
- Parameters:
a (array_like) – Input array.
axis (int or tuple of ints) –
Position in the expanded axes where the new axis (or axes) is placed.
Deprecated since version 1.13.0: Passing an axis where axis > a.ndim will be treated as axis == a.ndim, and passing axis < -a.ndim - 1 will be treated as axis == 0. This behavior is deprecated.
Changed in version 1.18.0: A tuple of axes is now supported. Out of range axes as described above are now forbidden and raise an AxisError.
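The tuple form and the out-of-range rule can be sketched as follows. Each entry of a tuple axis indexes the output shape, whose rank is a.ndim plus the number of new axes, and negative entries are resolved against that output rank. One assumption to flag: the docstring above is NumPy's and mentions AxisError, while the JAX implementation is expected to raise a ValueError; catching ValueError covers both, since NumPy's AxisError subclasses ValueError.

```python
import jax.numpy as jnp

x = jnp.zeros((2, 3))  # ndim == 2

# Two new axes: the output has 2 + 2 = 4 dimensions, and positions 0 and 2
# become length-1 axes while the original (2, 3) fills the remaining slots.
y = jnp.expand_dims(x, (0, 2))
print(y.shape)  # (1, 2, 1, 3)

# Negative entries are also counted against the 4-D output shape.
z = jnp.expand_dims(x, (0, -1))
print(z.shape)  # (1, 2, 3, 1)

# An axis outside the valid range of the output rank is rejected, not clamped.
rejected = False
try:
    jnp.expand_dims(x, 5)
except ValueError as err:  # assumption: JAX reports this as a ValueError
    rejected = True
    print("rejected:", err)
```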
- Returns:
result – View of a with the number of dimensions increased.
- Return type:
ndarray
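The return value can be compared against NumPy directly. Because JAX arrays are immutable, the NumPy notion of a view that shares writable memory with a does not carry over; what this sketch checks instead is that the shape and values match numpy.expand_dims and indexing with None. It also assumes axis is passed as a concrete Python int when the call is traced under jax.jit, since axis determines the output shape.

```python
import jax
import jax.numpy as jnp
import numpy as np

x_np = np.arange(6).reshape(2, 3)
x = jnp.asarray(x_np)

out = jnp.expand_dims(x, 1)

# Same shape and values as NumPy's implementation.
ref = np.expand_dims(x_np, 1)
assert out.shape == ref.shape == (2, 1, 3)
assert np.array_equal(np.asarray(out), ref)

# Equivalent to inserting None at the same position when indexing.
assert np.array_equal(np.asarray(out), np.asarray(x[:, None, :]))

# Under jit, keep `axis` a static Python int by closing over it.
f = jax.jit(lambda a: jnp.expand_dims(a, 1))
print(f(x).shape)  # (2, 1, 3)
```

Closing over the axis in the lambda keeps it out of the traced arguments; passing it as a traced value would prevent JAX from computing the output shape at trace time.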