# jax.numpy.expand_dims

jax.numpy.expand_dims(a, axis)

Expand the shape of an array.

LAX-backend implementation of numpy.expand_dims(). The original NumPy docstring follows.

Insert a new axis that will appear at the axis position in the expanded array shape.

Parameters
• a (array_like) – Input array.

• axis (int or tuple of ints) – Position in the expanded axes where the new axis (or axes) is placed.
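As a minimal sketch of how the axis argument behaves, the snippet below applies jax.numpy.expand_dims to a small array. It assumes JAX is installed and imported as jnp; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.ones((3, 4))

# A single int inserts one length-1 axis at that position of the output shape.
front = jnp.expand_dims(x, axis=0)        # shape (1, 3, 4)
middle = jnp.expand_dims(x, axis=1)       # shape (3, 1, 4)

# Negative positions count from the end of the expanded shape.
back = jnp.expand_dims(x, axis=-1)        # shape (3, 4, 1)

# A tuple inserts several axes at once; each entry is a position
# in the output shape, not in the input shape.
both = jnp.expand_dims(x, axis=(0, -1))   # shape (1, 3, 4, 1)

print(front.shape, middle.shape, back.shape, both.shape)
```

In every case the number of elements is unchanged; only length-1 axes are added.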

Returns

result – View of a with the number of dimensions increased.

Return type

ndarray

See also

squeeze()

The inverse operation, removing singleton dimensions

reshape()

Insert, remove, and combine dimensions, and resize existing ones

doc.indexing(), atleast_1d(), atleast_2d(), atleast_3d()
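The relationship to squeeze() and reshape() can be sketched as follows. This is an illustrative example using the jax.numpy counterparts of those functions, assuming JAX is imported as jnp; the array contents are arbitrary.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

# Insert a length-1 axis in the middle of the shape.
y = jnp.expand_dims(x, axis=1)            # shape (2, 1, 3)

# squeeze() removes that singleton axis again, recovering the original shape.
restored = jnp.squeeze(y, axis=1)         # shape (2, 3)

# reshape() can produce the same result when the target shape is spelled out.
via_reshape = jnp.reshape(x, (2, 1, 3))   # shape (2, 1, 3)

print(y.shape, restored.shape, via_reshape.shape)
```

expand_dims is usually the clearer choice when only singleton axes are being added, since the target shape does not need to be written out by hand.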

Examples

>>> x = np.array([1, 2])
>>> x.shape
(2,)


The following is equivalent to x[np.newaxis, :] or x[np.newaxis]:

>>> y = np.expand_dims(x, axis=0)
>>> y
array([[1, 2]])
>>> y.shape
(1, 2)


The following is equivalent to x[:, np.newaxis]:

>>> y = np.expand_dims(x, axis=1)
>>> y
array([[1],
       [2]])
>>> y.shape
(2, 1)


axis may also be a tuple:

>>> y = np.expand_dims(x, axis=(0, 1))
>>> y
array([[[1, 2]]])

>>> y = np.expand_dims(x, axis=(2, 0))
>>> y
array([[[1],
        [2]]])


Note that some examples may use None instead of np.newaxis. These are the same objects:

>>> np.newaxis is None
True
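The same equivalences can be checked with JAX arrays. The sketch below assumes JAX is imported as jnp and compares expand_dims against indexing with jnp.newaxis and with None.

```python
import jax.numpy as jnp

x = jnp.array([1, 2])

# Three ways to add a leading axis.
a = jnp.expand_dims(x, axis=0)
b = x[jnp.newaxis, :]
c = x[None]

# Two ways to add a trailing axis.
d = jnp.expand_dims(x, axis=1)
e = x[:, None]

print(a.shape, b.shape, c.shape)   # leading-axis results
print(d.shape, e.shape)            # trailing-axis results
print(jnp.newaxis is None)
```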