jax.numpy.atleast_2d

jax.numpy.atleast_2d(*arys)

Convert inputs to arrays with at least 2 dimensions.

JAX implementation of numpy.atleast_2d().

Parameters:
  • arys (ArrayLike): zero or more arraylike arguments.

Returns:

An array when a single argument is passed, otherwise a list of arrays, one per input value. Arrays of shape () are converted to shape (1, 1), 1D arrays of shape (N,) are converted to shape (1, N), and arrays of all other shapes are returned unchanged.

Return type:

Array | list[Array]

Examples

Scalar arguments are converted to 2D, size-1 arrays:

>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)
>>> jnp.atleast_2d(x)
Array([[1.]], dtype=float32)

One-dimensional arguments have a unit dimension prepended to the shape:

>>> y = jnp.arange(4)
>>> jnp.atleast_2d(y)
Array([[0, 1, 2, 3]], dtype=int32)
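
As a small sketch of what "prepended" means here, the 1D case can be compared against adding a leading axis by hand. The variable names (`row`, `via_index`, `via_expand`) are illustrative only and are not part of the API:

```python
import jax.numpy as jnp

y = jnp.arange(4)                        # shape (4,)

row = jnp.atleast_2d(y)                  # shape (1, 4): unit axis added in front
via_index = y[None, :]                   # explicit new leading axis via indexing
via_expand = jnp.expand_dims(y, axis=0)  # same thing with expand_dims

print(row.shape)                                # (1, 4)
print(bool(jnp.array_equal(row, via_index)))    # True
print(bool(jnp.array_equal(row, via_expand)))   # True
```

Because the new axis always goes in front, a 1D input becomes a row of shape (1, N). If a column of shape (N, 1) is wanted instead, indexing with `y[:, None]` is one way to get it.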

Inputs with two or more dimensions are returned unchanged:

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_2d(z)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
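
The example above uses a 2D input. As a quick check of the shape rules from the Returns section across ranks, the following sketch compares input and output shapes for 0D through 3D inputs; the `inputs` list and its contents are made up for illustration:

```python
import jax.numpy as jnp

# Illustrative inputs of rank 0, 1, 2 and 3.
inputs = [
    jnp.float32(1.0),       # shape ()
    jnp.arange(4),          # shape (4,)
    jnp.ones((2, 3)),       # shape (2, 3)
    jnp.zeros((2, 3, 4)),   # shape (2, 3, 4)
]

# Apply atleast_2d to each input separately and record the result shape.
shapes = [jnp.atleast_2d(a).shape for a in inputs]
print(shapes)  # [(1, 1), (1, 4), (2, 3), (2, 3, 4)]
```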

Multiple arguments can be passed to the function at once, in which case a list of results is returned:

>>> jnp.atleast_2d(x, y)
[Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
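
Since the multi-argument form returns a list, the results can be unpacked directly. This is a minimal sketch; the names `x2d` and `y2d` are illustrative:

```python
import jax.numpy as jnp

x = jnp.float32(1.0)   # shape ()
y = jnp.arange(4)      # shape (4,)

# One result per argument, in the same order as the inputs.
x2d, y2d = jnp.atleast_2d(x, y)

print(x2d.shape)  # (1, 1)
print(y2d.shape)  # (1, 4)

# Both results are now 2D, so they broadcast against each other as matrices.
print((x2d * y2d).shape)  # (1, 4)
```

Note that a single argument yields a bare array rather than a one-element list, so unpacking only applies when two or more arguments are passed.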