jax.numpy.squeeze

jax.numpy.squeeze(a, axis=None)

Remove one or more length-1 axes from an array.

JAX implementation of numpy.squeeze(), implemented via jax.lax.squeeze().
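The sketch below illustrates that relationship. It is not the library's actual source. When `axis` is None, it collects the indices of all length-1 axes and passes them to `jax.lax.squeeze()`, which removes exactly the listed `dimensions`. The helper name `squeeze_like_jnp` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def squeeze_like_jnp(a, axis=None):
    """Hypothetical helper mirroring jnp.squeeze on top of lax.squeeze."""
    a = jnp.asarray(a)
    if axis is None:
        # Collect every axis whose length is 1.
        axis = tuple(i for i, d in enumerate(a.shape) if d == 1)
    elif isinstance(axis, int):
        axis = (axis,)
    # lax.squeeze removes exactly the listed dimensions.
    return lax.squeeze(a, tuple(axis))

x = jnp.zeros((3, 1, 1))
print(squeeze_like_jnp(x).shape)          # (3,)
print(squeeze_like_jnp(x, axis=1).shape)  # (3, 1)
```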

Parameters:
  • a (jax.typing.ArrayLike) – input array

  • axis (int | Sequence[int] | None) – integer or sequence of integers specifying axes to remove. If any specified axis does not have a length of 1, an error is raised. If not specified, squeeze all length-1 axes in a.
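The following sketch shows the three accepted forms of `axis` on an arbitrary example shape. A length-1 axis that is not listed is kept. Negative indices count from the end, as in NumPy.

```python
import jax.numpy as jnp

x = jnp.ones((1, 4, 1, 2))

# No axis: every length-1 axis is removed.
print(jnp.squeeze(x).shape)                # (4, 2)

# A single integer removes just that axis; the other length-1 axis stays.
print(jnp.squeeze(x, axis=0).shape)        # (4, 1, 2)

# A sequence removes several axes; -2 refers to axis 2 here.
print(jnp.squeeze(x, axis=(0, -2)).shape)  # (4, 2)
```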

Returns:

copy of a with length-1 axes removed.

Return type:

Array

Notes

Unlike numpy.squeeze(), jax.numpy.squeeze() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this typically carries no performance cost.
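The difference between a copy and a view is also not observable from user code. JAX arrays are immutable, and updates via `.at[]` return new arrays. The minimal illustration below also wraps the squeeze in `jax.jit`, where it becomes part of the compiled computation. The function name `squeeze_and_sum` is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3).reshape(3, 1)
y = jnp.squeeze(x)

# "Modifying" y with .at[] returns a new array, so x is unaffected
# regardless of whether y shares memory with x.
y2 = y.at[0].set(100)
print(x[0, 0], y2[0])  # 0 100

# Under jax.jit the squeeze is traced into the computation, and the
# compiler may avoid materializing a separate copy.
@jax.jit
def squeeze_and_sum(v):
    return jnp.squeeze(v, axis=1).sum()

print(squeeze_and_sum(x))  # 3
```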

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

Squeeze all length-1 dimensions:

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

Equivalently, specifying the axes explicitly:

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

Attempting to squeeze a non-unit axis results in an error:

>>> jnp.squeeze(x, axis=0)  
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
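To drop only those requested axes that actually have length 1, without raising an error, the axes can be filtered before calling `jnp.squeeze()`. The helper `squeeze_if_unit` below is a hypothetical sketch and is not part of the JAX API.

```python
import jax.numpy as jnp

def squeeze_if_unit(a, axis):
    """Hypothetical helper: squeeze only the listed axes that have length 1."""
    a = jnp.asarray(a)
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    # Keep only length-1 axes, normalizing negative indices.
    unit = tuple(ax % a.ndim for ax in axes if a.shape[ax] == 1)
    return jnp.squeeze(a, axis=unit)

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)
print(squeeze_if_unit(x, axis=(0, 1)).shape)  # (3, 1)
```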

For convenience, this functionality is also available via the jax.Array.squeeze() method:

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)
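The method form also accepts the `axis` argument. The short check below compares it against the function form for a single axis.

```python
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

# The method takes the same axis argument as the function.
print(x.squeeze(axis=2).shape)                                     # (3, 1)
print(bool((x.squeeze(axis=2) == jnp.squeeze(x, axis=2)).all()))  # True
```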