jax.numpy.atleast_1d

jax.numpy.atleast_1d(*arys)

Convert inputs to arrays with at least one dimension.

LAX-backend implementation of numpy.atleast_1d(). Original docstring below.

Scalar inputs are converted to 1-dimensional arrays, whilst higher-dimensional inputs are preserved.

arys1, arys2, … : array_like

One or more input arrays.

ret : ndarray

An array, or list of arrays, each with a.ndim >= 1. Copies are made only if necessary.

See also: atleast_2d, atleast_3d

>>> np.atleast_1d(1.0)
array([1.])
>>> x = np.arange(9.0).reshape(3,3)
>>> np.atleast_1d(x)
array([[0., 1., 2.],
       [3., 4., 5.],
       [6., 7., 8.]])
>>> np.atleast_1d(x) is x
True
>>> np.atleast_1d(1, [3, 4])
[array([1]), array([3, 4])]
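
The examples above come from the original NumPy docstring and use the `np` namespace. As a minimal sketch of the same calls through `jax.numpy`, assuming JAX is installed and imported as `jnp` (the variable names are illustrative only), the following checks shapes rather than printed output, since array reprs differ between NumPy and JAX. The second input is wrapped in `jnp.array` because some JAX versions warn about or reject plain Python lists as array inputs.

```python
import jax.numpy as jnp

# A scalar is promoted to a 1-dimensional array with a single element.
scalar = jnp.atleast_1d(1.0)
print(scalar.shape)  # (1,)

# An input that already has ndim >= 1 keeps its shape.
matrix = jnp.arange(9.0).reshape(3, 3)
same = jnp.atleast_1d(matrix)
print(same.shape)  # (3, 3)

# Several inputs give back one array per input, in the same order.
a, b = jnp.atleast_1d(1, jnp.array([3, 4]))
print(a.shape, b.shape)  # (1,) (2,)
```

Unlike the NumPy example, this sketch makes no claim that `jnp.atleast_1d(matrix) is matrix`; object identity of the result is not something to rely on here.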