jax.numpy.squeeze

jax.numpy.squeeze(a, axis=None)

Remove axes of length one from a.

LAX-backend implementation of numpy.squeeze().
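A minimal sketch of the default behaviour, assuming an illustrative input of shape (1, 3, 1, 2). With axis left as None, every length-one axis is dropped, and the resulting shape should match what numpy.squeeze gives for the same input.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.zeros((1, 3, 1, 2))  # illustrative input with two length-one axes

# axis=None (the default): remove every axis of length one.
y = jnp.squeeze(x)
print(y.shape)  # (3, 2)

# Compare against NumPy's result on an equivalent NumPy array.
print(y.shape == np.squeeze(np.zeros((1, 3, 1, 2))).shape)  # True
```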

The JAX version of this function may in some cases return a copy rather than a view of the input.
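Because JAX arrays are immutable, whether squeeze returns a copy or a view is not observable from user code. The short sketch below (input values are illustrative) shows that an indexed "update" through .at[...].set(...) produces a new array and leaves the original input unchanged.

```python
import jax.numpy as jnp

x = jnp.ones((1, 3))
y = jnp.squeeze(x)  # shape (3,)

# .at[...].set(...) returns a new array; neither x nor y is modified in place.
y2 = y.at[0].set(5.0)

print(x.tolist())   # [[1.0, 1.0, 1.0]]
print(y2.tolist())  # [5.0, 1.0, 1.0]
```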

Original docstring below.

Parameters:
  • a (array_like) – Input data.

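  • axis (None or int or tuple of ints, optional) – Selects a subset of the entries of length one in the shape. If an axis is selected whose length is greater than one, an error is raised.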
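The sketch below illustrates the axis parameter on an illustrative array of shape (1, 3, 1, 2): a single int removes one axis, a tuple removes several (negative indices count from the end), and naming an axis whose length is not one raises an error, which current JAX versions report as a ValueError.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(1, 3, 1, 2)

# Remove only axis 0; the other length-one axis (axis 2) is kept.
a0 = jnp.squeeze(x, axis=0)
print(a0.shape)  # (3, 1, 2)

# A tuple removes several axes; -2 refers to axis 2 of this 4-d array.
a02 = jnp.squeeze(x, axis=(0, -2))
print(a02.shape)  # (3, 2)

# Axis 1 has length 3, so squeezing it is rejected.
try:
    jnp.squeeze(x, axis=1)
except ValueError as err:
    print("ValueError:", err)
```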

Returns:

squeezed – The input array, but with all or a subset of the dimensions of length 1 removed. In NumPy this is always a itself or a view into a; as noted above, the JAX version may return a copy instead. Note that if all axes are squeezed, the result is a 0-d array and not a scalar.

Return type:

ndarray
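A small sketch of the 0-d case described under Returns, using an illustrative (1, 1, 1) input: squeezing every axis leaves an array with shape (), not a Python float, and .item() can be used when an actual Python scalar is needed.

```python
import jax.numpy as jnp

x = jnp.array([[[7.0]]])  # shape (1, 1, 1)
s = jnp.squeeze(x)

# All axes had length one, so the result is a 0-d array.
print(s.shape, s.ndim)       # () 0
print(isinstance(s, float))  # False

# Convert explicitly when a Python scalar is required.
value = s.item()
print(value)  # 7.0
```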