jax.numpy.squeeze
jax.numpy.squeeze(a, axis=None)
Remove one or more length-1 axes from an array.
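JAX implementation of numpy.squeeze(), implemented via jax.lax.squeeze().
- Parameters:
  - a (ArrayLike): input array.
  - axis (int | Sequence[int] | None): integer or sequence of integers specifying the axes to remove. If any specified axis does not have length 1, an error is raised. If not specified, all length-1 axes of a are removed.
- Returns:
  copy of a with length-1 axes removed.
- Return type:
  Array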
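A minimal sketch of how the `axis` argument maps input shapes to output shapes, assuming a standard JAX install. The array `y` and its shape are illustrative choices, not part of the reference above; negative indices are shown following NumPy's convention of counting from the last axis.

```python
import jax.numpy as jnp

# A (1, 3, 1, 2) array: axes 0 and 2 have length 1.
y = jnp.zeros((1, 3, 1, 2))

# axis=None (the default): every length-1 axis is removed.
print(jnp.squeeze(y).shape)               # (3, 2)

# A single integer removes just that axis.
print(jnp.squeeze(y, axis=0).shape)       # (3, 1, 2)

# Negative indices count from the end: for a 4-D array, -2 refers to axis 2.
print(jnp.squeeze(y, axis=-2).shape)      # (1, 3, 2)

# A tuple removes several axes at once.
print(jnp.squeeze(y, axis=(0, 2)).shape)  # (3, 2)
```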
Notes
Unlike numpy.squeeze(), jax.numpy.squeeze() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this has no performance impact.
See also
- jax.numpy.expand_dims(): the inverse of squeeze; adds dimensions of length 1.
- jax.Array.squeeze(): equivalent functionality via an array method.
- jax.lax.squeeze(): equivalent XLA API.
- jax.numpy.ravel(): flatten an array into a 1D shape.
- jax.numpy.reshape(): general array reshape.
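The relationships listed under See also can be checked on a small input. This sketch reuses the `x` from the Examples section and checks that the function, the array method, and `jax.lax.squeeze()` agree, that `jnp.reshape()` and `jnp.ravel()` coincide with squeeze for this particular input, and that `jnp.expand_dims()` restores the removed axes. Note that `jax.lax.squeeze()` has no default and needs the axes spelled out via `dimensions`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

# Three spellings of the same operation.
a = jnp.squeeze(x)
b = x.squeeze()                             # jax.Array method
c = jax.lax.squeeze(x, dimensions=(1, 2))  # lax API: axes given explicitly

# reshape and ravel match squeeze here only because axis 0 is the sole
# non-unit axis; in general ravel flattens everything.
d = jnp.reshape(x, (3,))
e = jnp.ravel(x)

# expand_dims undoes squeeze: re-inserting axes 1 and 2 restores the shape.
restored = jnp.expand_dims(a, axis=(1, 2))
print(restored.shape)  # (3, 1, 1)
```

For an input such as a (2, 1, 3) array, `squeeze` yields shape (2, 3) while `ravel` yields (6,), which is why the two are listed separately.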
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)
Squeeze all length-1 dimensions:
>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)
Equivalently, specifying the axes explicitly:
>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)
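Since jax.numpy.squeeze() is described as the JAX implementation of numpy.squeeze(), a quick parity check against NumPy is a reasonable sanity test. This sketch loops over a few `axis` choices (the particular choices are illustrative) and asserts that shapes and values match.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)
x_np = np.asarray(x)                   # NumPy array with the same data

for axis in (None, 1, 2, (1, 2), -1):
    out_jax = jnp.squeeze(x, axis=axis)
    out_np = np.squeeze(x_np, axis=axis)
    # Shapes and values should agree with NumPy for each choice of axis.
    assert out_jax.shape == out_np.shape
    np.testing.assert_array_equal(np.asarray(out_jax), out_np)

print("shapes and values match NumPy")
```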
Attempting to squeeze a non-unit axis results in an error:
>>> jnp.squeeze(x, axis=0)
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
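Because requesting a non-unit axis raises `ValueError`, code that wants "remove these axes if they have length 1" has to filter the axes first. The helper `squeeze_unit_axes` below is hypothetical and not part of JAX; it is a sketch that normalizes negative indices, drops any requested axis whose length is not 1, and then calls `jnp.squeeze()` on what remains.

```python
import jax.numpy as jnp

def squeeze_unit_axes(a, axes):
    """Hypothetical helper (not a JAX API): squeeze only those of `axes`
    that actually have length 1, silently skipping the others."""
    a = jnp.asarray(a)
    ndim = a.ndim
    # Normalize negative indices and deduplicate so each axis is removed once.
    keep = tuple(sorted({ax % ndim for ax in axes if a.shape[ax] == 1}))
    # If nothing qualifies, return the input unchanged.
    return jnp.squeeze(a, axis=keep) if keep else a

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

try:
    jnp.squeeze(x, axis=0)  # axis 0 has length 3, so this raises
except ValueError as err:
    print("ValueError:", err)

print(squeeze_unit_axes(x, (0, 1)).shape)  # (3, 1): axis 0 skipped, axis 1 removed
```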
For convenience, this functionality is also available via the jax.Array.squeeze() method:
>>> x.squeeze()
Array([0, 1, 2], dtype=int32)
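The copy-versus-view remark in the Notes can be illustrated as follows. This is a sketch: NumPy's `np.shares_memory` shows that `np.squeeze` returns a view, while JAX arrays are immutable, so any copy is not observable from Python and updates go through `.at[...]`. The jitted function `row_sums` is a hypothetical example of squeeze used inside a compiled computation, where the compiler is free to optimize the copy away; the sketch does not inspect the compiled program.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: squeeze returns a view that shares memory with its input.
a_np = np.zeros((3, 1))
print(np.shares_memory(a_np, np.squeeze(a_np)))  # True

# JAX: arrays are immutable; "modifying" the squeezed result creates a new array.
a = jnp.zeros((3, 1))
s = jnp.squeeze(a)
s2 = s.at[0].set(5.0)  # a is left unchanged, s2 holds the update

# Under jit, squeeze is one op in the traced program (hypothetical example).
@jax.jit
def row_sums(m):
    # sum with keepdims=True gives shape (n, 1); squeeze drops the unit axis.
    return jnp.squeeze(m.sum(axis=1, keepdims=True), axis=1)

print(row_sums(jnp.ones((3, 4))))  # each row of ones((3, 4)) sums to 4
```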