# jax.numpy.squeeze

jax.numpy.squeeze(a, axis=None)

Remove one or more length-1 axes from an array.

JAX implementation of `numpy.squeeze()`, implemented via `jax.lax.squeeze()`.
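Since the function is described as implemented via `jax.lax.squeeze()`, a minimal sketch comparing the two entry points may help. The lower-level `jax.lax.squeeze` takes the dimensions to remove as a required argument, while `jax.numpy.squeeze` can infer them when `axis` is omitted. The input array `x` here is an illustrative assumption.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input with shape (3, 1, 1).
x = jnp.arange(3).reshape(3, 1, 1)

# jax.lax.squeeze requires the axes to drop to be listed explicitly.
y_lax = lax.squeeze(x, dimensions=(1, 2))

# jax.numpy.squeeze finds every length-1 axis itself when axis is None.
y_np = jnp.squeeze(x)

print(y_lax.shape, y_np.shape)  # (3,) (3,)
```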

Parameters:
• a (jax.typing.ArrayLike) – input array

• axis (int | Sequence[int] | None) – integer or sequence of integers specifying the axes to remove. If any specified axis does not have length 1, an error is raised. If not specified, all length-1 axes in `a` are squeezed.
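A short sketch of the accepted forms of `axis`. The array shape is an arbitrary illustrative choice; the negative-index case assumes NumPy-style counting from the end, which is how the underlying `jax.lax.squeeze` normalizes dimensions.

```python
import jax.numpy as jnp

x = jnp.ones((1, 4, 1, 5))

# A single integer removes exactly one axis.
a = jnp.squeeze(x, axis=0)        # shape (4, 1, 5)

# A sequence of integers removes several axes at once.
b = jnp.squeeze(x, axis=(0, 2))   # shape (4, 5)

# Negative indices count from the last axis.
c = jnp.squeeze(x, axis=-2)       # shape (1, 4, 5)

# axis=None (the default) removes every length-1 axis.
d = jnp.squeeze(x)                # shape (4, 5)
```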

Returns:

copy of `a` with length-1 axes removed.

Return type:

Array

Notes

Unlike `numpy.squeeze()`, `jax.numpy.squeeze()` returns a copy rather than a view of the input array. However, under JIT the compiler will optimize away such copies when possible, so in practice this doesn't have a performance impact.
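A minimal sketch of the situation described above: `jnp.squeeze` used inside a `jax.jit`-compiled function. The function name `row_sums` is hypothetical. Whether XLA actually avoids materializing the squeezed intermediate depends on the compiler and backend, so the comment describes what the compiler is permitted to do, not a guarantee. Here `axis` is a plain Python integer, so it is fixed at trace time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def row_sums(m):
    # m has shape (n, 1, k): drop the unit middle axis, then sum over k.
    # Under jit, XLA may fuse the squeeze into the reduction, so the
    # squeezed "copy" need not exist as a separate buffer.
    return jnp.squeeze(m, axis=1).sum(axis=-1)

m = jnp.arange(6.0).reshape(3, 1, 2)   # rows [0, 1], [2, 3], [4, 5]
print(row_sums(m))  # [1. 5. 9.]
```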

Examples

```
>>> import jax.numpy as jnp
>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)
```

Squeeze all length-1 dimensions:

```
>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)
```
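Two edge cases of the default `axis=None`, sketched with illustrative inputs: if every axis has length 1 the result is a zero-dimensional array, and if no axis has length 1 the shape is unchanged.

```python
import jax.numpy as jnp

# Every axis has length 1, so nothing but a 0-d (scalar-shaped) array remains.
s = jnp.squeeze(jnp.array([[[7]]]))
print(s.shape, s.ndim)  # () 0

# No axis has length 1, so squeezing leaves the shape as it was.
t = jnp.squeeze(jnp.ones((2, 3)))
print(t.shape)  # (2, 3)
```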

The same result, specifying the axes explicitly:

```
>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)
```
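Naming only some of the unit axes removes just those, leaving the others in place. This sketch reuses the same `x`; the reversed tuple in the last call assumes axis order does not matter, since the dimensions are sorted before squeezing.

```python
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])   # shape (3, 1, 1)

# Remove only axis 1; axis 2 (also length 1) is kept.
y1 = jnp.squeeze(x, axis=1)            # shape (3, 1)

# Remove only axis 2; axis 1 is kept.
y2 = jnp.squeeze(x, axis=2)            # shape (3, 1)

# Listing the axes in a different order gives the same result as (1, 2).
y3 = jnp.squeeze(x, axis=(2, 1))       # shape (3,)
```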

Attempting to squeeze a non-unit axis results in an error:

```
>>> jnp.squeeze(x, axis=0)
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
```
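If you cannot be sure in advance that every requested axis has length 1, one option is to filter the axes before calling `jnp.squeeze`. The helper `squeeze_unit_axes` below is hypothetical and not part of JAX; it assumes that silently keeping non-unit axes is acceptable for your use case. Because array shapes are normally static in JAX, the check is ordinary Python and should also work inside `jax.jit`.

```python
import jax.numpy as jnp

def squeeze_unit_axes(a, axes):
    """Hypothetical helper: drop only the requested axes whose length is 1.

    Requested axes with length != 1 are kept instead of raising an error.
    """
    a = jnp.asarray(a)
    ndim = a.ndim
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of range for ndim {ndim}")
    # Normalize negative indices so they can be compared with shape positions.
    unit = tuple(ax % ndim for ax in axes if a.shape[ax] == 1)
    if not unit:
        return a  # nothing to remove
    return jnp.squeeze(a, axis=unit)

x = jnp.zeros((3, 1, 1))
print(squeeze_unit_axes(x, (0, 1)).shape)  # (3, 1): axis 0 is kept
```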

For convenience, this functionality is also available via the `jax.Array.squeeze()` method:

```
>>> x.squeeze()
Array([0, 1, 2], dtype=int32)
```
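The method form also accepts an `axis` argument. A brief sketch, using the same illustrative `x`, showing that it agrees with the function form:

```python
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])   # shape (3, 1, 1)

# Method and function forms take the same axis argument.
m = x.squeeze(axis=2)
f = jnp.squeeze(x, axis=2)
print(m.shape, f.shape)  # (3, 1) (3, 1)
```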