jax.numpy.expand_dims
jax.numpy.expand_dims(a, axis)
Expand the shape of an array.
LAX-backend implementation of numpy.expand_dims(). Original docstring below.
Insert a new axis that will appear at the axis position in the expanded array shape.
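The following is a minimal sketch of the behavior described above. It assumes both NumPy and JAX are installed, applies numpy.expand_dims and jax.numpy.expand_dims to the same data, and checks that they agree. Because axis determines the output shape, the sketch keeps it a concrete Python int under jax.jit by closing over it in a lambda rather than passing it as a traced argument. That wrapping is an illustrative choice and not part of this function's API.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.array([1, 2, 3])
x_jnp = jnp.asarray(x_np)

# Insert a new axis at position 1 with both libraries.
y_np = np.expand_dims(x_np, axis=1)
y_jnp = jnp.expand_dims(x_jnp, axis=1)
print(y_np.shape, y_jnp.shape)  # (3, 1) (3, 1)

# Under jit, keep axis static by closing over a Python int.
add_col_axis = jax.jit(lambda v: jnp.expand_dims(v, axis=1))
y_jit = add_col_axis(x_jnp)

# All three results hold the same values.
print(np.array_equal(y_np, np.asarray(y_jnp)))  # True
print(np.array_equal(y_np, np.asarray(y_jit)))  # True
```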
Parameters
a (array_like) – Input array.
axis (int or tuple of ints) – Position in the expanded axes where the new axis (or axes) is placed.
Returns
result – View of a with the number of dimensions increased.
Return type
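The axis parameter counts positions in the expanded shape, not in the input shape. Negative values are therefore resolved against a.ndim plus the number of new axes. The sketch below illustrates this on a 2-D input. Note that JAX arrays are immutable: although the original docstring calls the result a view, it cannot be used to modify a in place, and a itself is left unchanged.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3))

# axis=0 prepends a new axis; axis=-1 appends one, because -1 refers to
# the last axis of the 3-D result rather than of the 2-D input.
front = jnp.expand_dims(a, axis=0)
back = jnp.expand_dims(a, axis=-1)

# With a tuple, each entry is normalized against a.ndim + len(axis) = 4,
# so (0, -1) means "axes 0 and 3 of the 4-D result are new".
both = jnp.expand_dims(a, axis=(0, -1))

print(front.shape, back.shape, both.shape)  # (1, 2, 3) (2, 3, 1) (1, 2, 3, 1)
print(a.shape)  # (2, 3): the input is unchanged
```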
See also
squeeze() – The inverse operation, removing singleton dimensions.
reshape() – Insert, remove, and combine dimensions, and resize existing ones.
doc.indexing, atleast_1d(), atleast_2d(), atleast_3d()
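The sketch below makes the See also relationships concrete for a 1-D JAX array. It undoes expand_dims with squeeze and then reproduces the same shape with reshape and with atleast_2d.

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.expand_dims(x, axis=0)  # shape (1, 2)

# squeeze() with the same axis removes the inserted singleton dimension.
x_back = jnp.squeeze(y, axis=0)

# reshape() can insert the same axis by spelling out the full target shape.
y_reshaped = jnp.reshape(x, (1, 2))

# For a 1-D input, atleast_2d() also prepends a length-1 axis.
y_2d = jnp.atleast_2d(x)

print(x_back.shape, y_reshaped.shape, y_2d.shape)  # (2,) (1, 2) (1, 2)
```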
Examples
>>> x = np.array([1, 2])
>>> x.shape
(2,)
The following is equivalent to x[np.newaxis, :] or x[np.newaxis]:
>>> y = np.expand_dims(x, axis=0)
>>> y
array([[1, 2]])
>>> y.shape
(1, 2)
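The same equivalence holds for JAX arrays. Here is a minimal sketch that uses jnp.newaxis in place of np.newaxis:

```python
import jax.numpy as jnp

x = jnp.array([1, 2])

# Three ways to add a leading axis of length 1.
a = jnp.expand_dims(x, axis=0)
b = x[jnp.newaxis, :]
c = x[jnp.newaxis]

print(a.shape, b.shape, c.shape)  # (1, 2) (1, 2) (1, 2)
print(bool(jnp.array_equal(a, b)) and bool(jnp.array_equal(a, c)))  # True
```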
The following is equivalent to x[:, np.newaxis]:
>>> y = np.expand_dims(x, axis=1)
>>> y
array([[1],
       [2]])
>>> y.shape
(2, 1)
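More generally, for a non-negative k, expand_dims(x, axis=k) should match indexing x with k full slices followed by None. The loop below checks this for every valid insertion point of a 2-D JAX array.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

shapes = []
matches = []
# Valid insertion points for a 2-D array are 0, 1, and 2.
for k in range(x.ndim + 1):
    # Build the index (:, ..., :, None) with k full slices before None.
    index = (slice(None),) * k + (None,)
    via_indexing = x[index]
    via_expand = jnp.expand_dims(x, axis=k)
    shapes.append(via_expand.shape)
    matches.append(bool(jnp.array_equal(via_expand, via_indexing)))
    print(k, via_expand.shape, via_indexing.shape)
```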
axis may also be a tuple:
>>> y = np.expand_dims(x, axis=(0, 1))
>>> y
array([[[1, 2]]])
>>> y = np.expand_dims(x, axis=(2, 0))
>>> y
array([[[1],
        [2]]])
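When axis is a tuple, its entries name positions in the final result, so their order does not matter. The sketch below, written against jax.numpy, checks that axis=(2, 0) gives the same result as each of the following:

- axis=(0, 2)
- inserting one axis at a time in ascending order
- the equivalent None-indexing

```python
import jax.numpy as jnp

x = jnp.array([1, 2])

# Both tuples mark axes 0 and 2 of the 3-D result as new.
y1 = jnp.expand_dims(x, axis=(2, 0))
y2 = jnp.expand_dims(x, axis=(0, 2))

# Same result built one axis at a time, in ascending order.
y3 = jnp.expand_dims(jnp.expand_dims(x, axis=0), axis=2)

# Same result via indexing: None at positions 0 and 2.
y4 = x[None, :, None]

print(y1.shape)  # (1, 2, 1)
```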
Note that some examples may use None instead of np.newaxis. These are the same objects:
>>> np.newaxis is None
True
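jax.numpy exposes the same alias: jnp.newaxis is None, so either spelling works when indexing JAX arrays. A minimal check:

```python
import jax.numpy as jnp

x = jnp.array([1, 2])

print(jnp.newaxis is None)  # True

# Both spellings add a trailing axis, giving shape (2, 1).
a = x[:, jnp.newaxis]
b = x[:, None]
print(a.shape, b.shape)  # (2, 1) (2, 1)
```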