jax.numpy.squeeze
- jax.numpy.squeeze(a, axis=None)
Remove axes of length one from a.
LAX-backend implementation of numpy.squeeze(). The JAX version of this function may in some cases return a copy rather than a view of the input.
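A minimal usage sketch, assuming JAX is installed and jax.numpy is imported as jnp. With the default axis=None, every axis of length one is removed. Because JAX arrays are immutable, the sketch cannot observe whether the result shares memory with the input; it only checks shapes and values.

```python
import jax.numpy as jnp

# A (1, 3, 1) array: axes 0 and 2 have length one.
x = jnp.arange(3).reshape(1, 3, 1)

# axis=None (the default) removes every length-1 axis.
y = jnp.squeeze(x)
print(x.shape, "->", y.shape)  # (1, 3, 1) -> (3,)
print(y)                       # [0 1 2]
```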
Original docstring below.
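- Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) – Selects a subset of the entries of length one in the shape. If an axis is selected with shape entry greater than one, an error is raised.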
- Returns
squeezed – The input array, but with all or a subset of the dimensions of length 1 removed. In NumPy, this is always the input a itself or a view into a. Note that if all axes are squeezed, the result is a 0-d array and not a scalar.
- Return type
ndarray
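The sketch below exercises the axis argument from the signature above (an int or a tuple of ints, with negative values counting from the end) and the 0-d result described under Returns. It assumes a recent JAX version, which raises ValueError when a selected axis does not have length one; the exact error message may differ between versions.

```python
import jax.numpy as jnp

x = jnp.zeros((1, 3, 1))  # axes 0 and 2 have length one

# Select a subset of the length-1 axes with an int or a tuple of ints.
only_first = jnp.squeeze(x, axis=0)
both = jnp.squeeze(x, axis=(0, 2))
# Negative axes count from the end.
only_last = jnp.squeeze(x, axis=-1)
print(only_first.shape, both.shape, only_last.shape)  # (3, 1) (3,) (1, 3)

# Selecting an axis whose length is not one raises an error.
raised = False
try:
    jnp.squeeze(x, axis=1)
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Squeezing every axis of a one-element array yields a 0-d array, not a scalar.
s = jnp.squeeze(jnp.array([[5.0]]))
print(s.shape, s.ndim)  # () 0
```

To obtain a Python scalar from the 0-d result, call s.item() or float(s).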