jax.numpy.atleast_2d
jax.numpy.atleast_2d(*arys)
Convert inputs to arrays with at least 2 dimensions.
JAX implementation of numpy.atleast_2d().
- Parameters:
arys (ArrayLike): zero or more arraylike arguments.
- Returns:
an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1), 1D arrays of shape (N,) are converted to shape (1, N), and arrays of all other shapes are returned unchanged.
- Return type:
Array | list[Array]
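The shape rule in the Returns description can be sketched for a single input as follows. `my_atleast_2d` is a hypothetical helper written only for illustration (it is not part of JAX); the loop compares its output shapes with those of `jnp.atleast_2d`.

```python
import jax.numpy as jnp


def my_atleast_2d(x):
    """Hypothetical re-implementation of the shape rule for one input."""
    x = jnp.asarray(x)
    if x.ndim == 0:
        # Scalar of shape () becomes shape (1, 1).
        return x.reshape(1, 1)
    if x.ndim == 1:
        # Shape (N,) becomes (1, N) by prepending a unit axis.
        return x[jnp.newaxis, :]
    # Two or more dimensions: returned unchanged.
    return x


for shape in [(), (4,), (2, 3), (2, 3, 4)]:
    x = jnp.zeros(shape)
    got = my_atleast_2d(x)
    # Check the sketch against the real function.
    assert got.shape == jnp.atleast_2d(x).shape
    print(shape, "->", got.shape)
# () -> (1, 1)
# (4,) -> (1, 4)
# (2, 3) -> (2, 3)
# (2, 3, 4) -> (2, 3, 4)
```

For a single input, `jnp.array(x, ndmin=2)` produces the same shapes, since `ndmin` prepends unit axes until the array has at least two dimensions.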
Examples
Scalar arguments are converted to 2D, size-1 arrays:
>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)
>>> jnp.atleast_2d(x)
Array([[1.]], dtype=float32)
One-dimensional arguments have a unit dimension prepended to the shape:
>>> y = jnp.arange(4)
>>> jnp.atleast_2d(y)
Array([[0, 1, 2, 3]], dtype=int32)
Higher dimensional inputs are returned unchanged:
>>> z = jnp.ones((2, 3))
>>> jnp.atleast_2d(z)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
Multiple arguments can be passed to the function at once, in which case a list of results is returned:
>>> jnp.atleast_2d(x, y)
[Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
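A short sketch of the calling conventions described above: several inputs give a list that can be unpacked in order, while a single input gives one array rather than a one-element list. Because the output shape depends only on the input shape, the call can also be used inside `jax.jit`; `row_sums` is a hypothetical function written for illustration.

```python
import jax
import jax.numpy as jnp

# Several inputs: one result per input, returned as a list in order.
a, b, c = jnp.atleast_2d(jnp.float32(2.0), jnp.arange(3), jnp.zeros((2, 2, 2)))
print(a.shape, b.shape, c.shape)  # (1, 1) (1, 3) (2, 2, 2)

# A single input returns an Array, not a list.
single = jnp.atleast_2d(jnp.arange(3))
print(isinstance(single, list))  # False


# Hypothetical helper: treat a 1D input as a single row and sum each row.
@jax.jit
def row_sums(v):
    return jnp.atleast_2d(v).sum(axis=1)


print(row_sums(jnp.arange(4.0)))  # [6.]
```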