jax.numpy.squeeze
jax.numpy.squeeze(a, axis=None)
Remove single-dimensional entries from the shape of an array.
LAX-backend implementation of numpy.squeeze(). Original docstring below.
- Parameters
- Returns
squeezed – The input array, but with all or a subset of the dimensions of length 1 removed. This is always a itself or a view into a. Note that if all axes are squeezed, the result is a 0d array and not a scalar.
- Return type
- Raises
ValueError – If axis is not None, and an axis being squeezed is not of length 1.
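The return and error behavior described above can be checked with a short sketch. It assumes `jax` is installed and imported as `jnp`; the variable names (`x`, `zero_d`, `raised`) are illustrative and not part of the API.

```python
import jax.numpy as jnp

# Illustrative input with shape (1, 3, 1).
x = jnp.zeros((1, 3, 1))

# Squeezing a (1, 1) array removes every axis and leaves a 0-d array,
# not a Python scalar.
zero_d = jnp.squeeze(jnp.array([[1234]]))
print(zero_d.ndim, zero_d.shape)  # 0 ()

# Requesting an axis whose length is not 1 raises ValueError.
try:
    jnp.squeeze(x, axis=1)  # axis 1 has length 3
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```

Because array shapes are known up front, the error surfaces as soon as the call is made rather than at some later point in the computation.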
See also
expand_dims()
The inverse operation, adding singleton dimensions
reshape()
Insert, remove, and combine dimensions, and resize existing ones
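As a rough illustration of how the two related functions fit together, the sketch below removes one singleton axis, restores it with `expand_dims`, and expresses the same removal with `reshape`. It assumes `jax` is installed; the variable names are hypothetical.

```python
import jax.numpy as jnp

# Illustrative input with shape (1, 3, 1).
x = jnp.arange(3).reshape(1, 3, 1)

# Remove only the leading singleton axis: (1, 3, 1) -> (3, 1).
y = jnp.squeeze(x, axis=0)

# expand_dims re-inserts a length-1 axis at the same position.
restored = jnp.expand_dims(y, axis=0)

# reshape can express the same removal by spelling out the target shape.
via_reshape = jnp.reshape(x, (3, 1))

print(y.shape, restored.shape, via_reshape.shape)
```

`squeeze` is usually the clearer choice when the only goal is dropping length-1 axes, since the target shape does not have to be written out by hand.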
Examples
>>> import numpy as np
>>> x = np.array([[[0], [1], [2]]])
>>> x.shape
(1, 3, 1)
>>> np.squeeze(x).shape
(3,)
>>> np.squeeze(x, axis=0).shape
(3, 1)
>>> np.squeeze(x, axis=1).shape
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one
>>> np.squeeze(x, axis=2).shape
(1, 3)
>>> x = np.array([[1234]])
>>> x.shape
(1, 1)
>>> np.squeeze(x)
array(1234)  # 0d array
>>> np.squeeze(x).shape
()
>>> np.squeeze(x)[()]
1234
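The examples above come from the original NumPy docstring. A minimal sketch of axis selection through `jax.numpy` might look like the following. It assumes `jax` is installed. The tuple and negative `axis` values are not covered by the text above; they rely on `jax.numpy.squeeze` accepting the same forms as NumPy.

```python
import jax.numpy as jnp

# Same input as the docstring example, shape (1, 3, 1).
x = jnp.array([[[0], [1], [2]]])

# Squeeze a single named axis.
first = jnp.squeeze(x, axis=0)
last = jnp.squeeze(x, axis=2)

# Assumption: a negative index counts from the end, as in NumPy.
last_negative = jnp.squeeze(x, axis=-1)

# Assumption: a tuple squeezes several singleton axes in one call.
both = jnp.squeeze(x, axis=(0, 2))

print(first.shape, last.shape, last_negative.shape, both.shape)
```

Naming the axes explicitly, instead of relying on `axis=None`, guards against silently dropping a dimension that only happens to have length 1 for a particular input.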