jax.numpy.squeeze

jax.numpy.squeeze(a, axis=None)

Remove single-dimensional entries from the shape of an array.

LAX-backend implementation of squeeze().

The JAX version of this function may in some cases return a copy rather than a view of the input.

Original docstring below.

Parameters
  • a (array_like) – Input data.

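  • axis (None or int or tuple of ints, optional) – Selects a subset of the length-1 entries in the shape. If an axis is selected whose shape entry is greater than one, an error is raised.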
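
The effect of the axis parameter can be illustrated with a minimal sketch. The array x and the result names below are made up for illustration, and the exception type caught in the last step is what current JAX releases appear to raise, so treat it as an assumption rather than a documented guarantee.

```python
import jax.numpy as jnp

# Hypothetical input with two length-1 dimensions.
x = jnp.zeros((1, 3, 1))

# axis=None (the default) removes every length-1 dimension.
all_squeezed = jnp.squeeze(x)
print(all_squeezed.shape)

# An int removes only that one dimension.
first_only = jnp.squeeze(x, axis=0)
print(first_only.shape)

# A tuple of ints removes each listed dimension.
both = jnp.squeeze(x, axis=(0, 2))
print(both.shape)

# Selecting a dimension whose length is not 1 is rejected.
# Assumption: the error surfaces as a ValueError.
try:
    jnp.squeeze(x, axis=1)
    raised = False
except ValueError:
    raised = True
print(raised)
```

Passing axis explicitly is usually the safer choice when a dimension such as a batch size may or may not be 1, because the default silently removes every length-1 dimension.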

Returns

squeezed – The input array, but with all or a subset of the dimensions of length 1 removed. In NumPy this is always the array a itself or a view into a; the JAX version may return a copy instead. Note that if all axes are squeezed, the result is a 0d array and not a scalar.

Return type

ndarray
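
As a small sketch of the return value when every axis is squeezed, the following uses a made-up single-element array (the names x, y and value are illustrative only). It shows that the result is still an array with an empty shape, and that a Python number has to be extracted explicitly.

```python
import jax.numpy as jnp

# Hypothetical input: a single element wrapped in three length-1 dimensions.
x = jnp.array([[[7.0]]])

# Squeezing every axis leaves a 0d array, not a Python scalar.
y = jnp.squeeze(x)
print(y.shape)
print(y.ndim)

# Convert explicitly when a plain Python float is needed.
value = float(y)
print(value)
```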