jax.Array.squeeze

abstract Array.squeeze(axis=None)

Remove one or more length-1 axes from the array.

Refer to jax.numpy.squeeze() for full documentation.

Parameters:
    self (Array)
    axis (reductions.Axis): the axis or axes to remove. If None (the default), all length-1 axes are removed.

Return type:
    Array
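The following is a minimal sketch of the method in use. It assumes a standard JAX installation, and the array shape (1, 3, 1, 2) is only illustrative. It shows the default behavior (axis=None), removing a single named axis, removing several axes with a tuple, that the method form matches jax.numpy.squeeze(), and that naming an axis whose length is not 1 is rejected.

```python
import jax.numpy as jnp

# Illustrative array with two length-1 axes (axes 0 and 2).
x = jnp.zeros((1, 3, 1, 2))

# axis=None (the default): every length-1 axis is removed.
y = x.squeeze()
print(y.shape)  # (3, 2)

# Naming one axis removes only that axis; axis 2 is kept.
z = x.squeeze(axis=0)
print(z.shape)  # (3, 1, 2)

# A tuple of axes removes several at once.
w = x.squeeze(axis=(0, 2))
print(w.shape)  # (3, 2)

# The method is a thin wrapper around jax.numpy.squeeze.
same = bool(jnp.array_equal(x.squeeze(axis=2), jnp.squeeze(x, axis=2)))

# Selecting an axis whose length is not 1 (axis 1 has length 3) is an error.
try:
    x.squeeze(axis=1)
    rejected = False
except ValueError:
    rejected = True
```

Because the array keeps its data and only its shape changes, squeezing is typically used to drop singleton dimensions left over from indexing or reductions with keepdims=True.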