jax.numpy.expand_dims
jax.numpy.expand_dims(a, axis)
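Insert dimensions of length 1 into an array.

JAX implementation of numpy.expand_dims(), implemented via jax.lax.expand_dims().

Parameters:
- a (ArrayLike): input array.
- axis (int | Sequence[int]): integer or sequence of integers specifying the positions of the axes to add, counted in the output array.

Returns:
Copy of a with added dimensions.

Return type:
Array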
Notes

Unlike numpy.expand_dims(), jax.numpy.expand_dims() returns a copy rather than a view of the input array. Because JAX arrays are immutable, this difference does not change any results. Under JIT, the compiler optimizes away such copies when possible, so in practice this has no meaningful performance impact.

See also
- jax.numpy.squeeze(): inverse of this operation, i.e. removes length-1 dimensions.
- jax.lax.expand_dims(): XLA version of this functionality.
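The copy-versus-view difference described in the Notes can be observed directly. The sketch below is illustrative only and assumes NumPy and JAX are both installed; the variable names are arbitrary. It writes through a NumPy view (which also changes the original array), shows that the JAX result can only be updated functionally with .at[].set(), and runs the same expansion under jax.jit, where any copy elision is left to the compiler and does not affect the result.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: np.expand_dims returns a view that shares memory with its input.
x_np = np.array([1, 2, 3])
v_np = np.expand_dims(x_np, 0)
shares = np.shares_memory(v_np, x_np)  # the view and the input share one buffer
v_np[0, 0] = 100                       # writing through the view also modifies x_np

# JAX: arrays are immutable, so the expanded array behaves as an independent value.
x = jnp.array([1, 2, 3])
y = jnp.expand_dims(x, 0)
y_updated = y.at[0, 0].set(100)        # functional update returns a new array; x is untouched

# Under jax.jit the compiler may elide the copy; the computed values are the same.
f = jax.jit(lambda a: jnp.expand_dims(a, 0) * 2)
out = f(x)
```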
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3])
>>> x.shape
(3,)
Expand the leading dimension:
>>> jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> _.shape
(1, 3)
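A common reason to add a leading axis is to join several arrays along a new dimension. The following is a minimal sketch of that pattern using made-up example arrays; it compares the result with jnp.stack, which produces the same array in a single call.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# Give each vector a leading length-1 axis (shape (1, 3)), then join along that axis.
stacked = jnp.concatenate([jnp.expand_dims(a, 0), jnp.expand_dims(b, 0)], axis=0)

# jnp.stack produces the same (2, 3) result directly.
reference = jnp.stack([a, b])
```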
Expand the trailing dimension:
>>> jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> _.shape
(3, 1)
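For the trailing dimension, a negative axis avoids hard-coding the input's rank. The sketch below assumes the NumPy convention that each entry of axis refers to a position in the output array, with negative values counted from the output's end. expanded_shape is a hypothetical helper written only to spell out that rule; the comparison against jnp.expand_dims covers a few cases on an arbitrary 2-D array.

```python
import jax.numpy as jnp


def expanded_shape(shape, axis):
    """Hypothetical helper: output shape of expand_dims under NumPy's axis convention."""
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    out_ndim = len(shape) + len(axes)
    positions = set()
    for a in axes:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of range for an output of rank {out_ndim}")
        positions.add(a % out_ndim)  # negative axes count from the end of the output
    if len(positions) != len(axes):
        raise ValueError("repeated axis")
    dims = iter(shape)
    # New positions get length 1; the input's dimensions fill the rest in order.
    return tuple(1 if i in positions else next(dims) for i in range(out_ndim))


x = jnp.array([1, 2, 3])
trailing = jnp.expand_dims(x, -1)  # for a 1-D input this matches axis=1: shape (3, 1)

m = jnp.zeros((2, 3))
cases = [0, 1, -1, (0, -1), (0, 1, 3)]
# Map each axis argument to (actual shape from JAX, shape predicted by the helper).
results = {ax: (jnp.expand_dims(m, ax).shape, expanded_shape(m.shape, ax)) for ax in cases}
```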
Expand multiple dimensions:
>>> jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)
>>> _.shape
(1, 1, 3, 1)
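Since jnp.squeeze() is the inverse of this operation, squeezing the same axes restores the original shape, and jax.lax.expand_dims() provides the lower-level form, which takes a sequence of dimensions. A brief sketch of both, reusing the input and axes from the example above:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3])
axes = (0, 1, 3)

y = jnp.expand_dims(x, axes)      # shape (1, 1, 3, 1)
back = jnp.squeeze(y, axis=axes)  # removing the same axes restores shape (3,)

# lax.expand_dims expects a sequence of dimensions to insert.
y_lax = lax.expand_dims(x, axes)
```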
Dimensions can also be expanded more succinctly by indexing with None:

>>> x[None]  # equivalent to jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> x[:, None]  # equivalent to jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> x[None, None, :, None]  # equivalent to jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)
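Indexing with None also composes with Ellipsis and with jnp.newaxis, which is an alias for None. The sketch below checks two such spellings against jnp.expand_dims on an arbitrary 2-D array, then shows a typical broadcasting use, pairwise differences, where a (3, 1) array is combined with a (1, 3) array.

```python
import jax.numpy as jnp

m = jnp.arange(6).reshape(2, 3)

# jnp.newaxis is None, so this is the same as m[:, None] and jnp.expand_dims(m, 1).
c = m[:, jnp.newaxis]
# An Ellipsis appends a trailing axis regardless of rank, like jnp.expand_dims(m, -1).
t = m[..., None]

# Pairwise differences via broadcasting: diffs[i, j] == x[i] - x[j].
x = jnp.array([1, 2, 3])
diffs = x[:, None] - x[None, :]
```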