jax.numpy.atleast_3d

jax.numpy.atleast_3d(*arys)

Convert inputs to arrays with at least 3 dimensions.

JAX implementation of numpy.atleast_3d().

Parameters:
  • arys (ArrayLike): zero or more array-like arguments.

Returns:

an array (when a single argument is passed) or a list of arrays (when multiple arguments are passed) corresponding to the input values. Arrays of shape () are converted to shape (1, 1, 1), 1D arrays of shape (N,) are converted to shape (1, N, 1), 2D arrays of shape (M, N) are converted to shape (M, N, 1), and arrays with three or more dimensions are returned unchanged.

Return type:

Array | list[Array]

Examples

Scalar arguments are converted to 3D, size-1 arrays:

>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)
>>> jnp.atleast_3d(x)
Array([[[1.]]], dtype=float32)

1D arrays have a unit dimension prepended and appended:

>>> y = jnp.arange(4)
>>> jnp.atleast_3d(y).shape
(1, 4, 1)

2D arrays have a unit dimension appended:

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_3d(z).shape
(2, 3, 1)
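
Arrays that already have three or more dimensions should keep their shape, as stated in the Returns section. The following is a minimal sketch of that case; the input shapes and the variable names a, b, shape_a and shape_b are arbitrary choices made for illustration:

```python
import jax.numpy as jnp

# Illustrative inputs of rank 3 and rank 4 (shapes chosen arbitrarily).
a = jnp.zeros((2, 3, 4))
b = jnp.zeros((2, 3, 4, 5))

# No unit dimensions are added, because both inputs already have ndim >= 3.
shape_a = jnp.atleast_3d(a).shape
shape_b = jnp.atleast_3d(b).shape

print(shape_a)  # (2, 3, 4)
print(shape_b)  # (2, 3, 4, 5)
```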

Multiple arguments can be passed to the function at once, in which case a list of results is returned:

>>> x3, y3 = jnp.atleast_3d(x, y)
>>> print(x3)
[[[1.]]]
>>> print(y3)
[[[0]
  [1]
  [2]
  [3]]]
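
The shape rules for 1D and 2D inputs can also be written out by hand with indexing or jnp.expand_dims. The sketch below compares the two approaches for illustration only; it is not a description of how atleast_3d is implemented internally, and the names y_manual, z_manual, same_y and same_z are hypothetical:

```python
import jax.numpy as jnp

y = jnp.arange(4)      # shape (4,)
z = jnp.ones((2, 3))   # shape (2, 3)

# 1D case: add unit axes at positions 0 and 2, giving shape (1, 4, 1).
y_manual = y[None, :, None]

# 2D case: add a unit axis at position 2, giving shape (2, 3, 1).
z_manual = jnp.expand_dims(z, axis=2)

# Compare against atleast_3d; array_equal checks both shape and values.
same_y = bool(jnp.array_equal(jnp.atleast_3d(y), y_manual))
same_z = bool(jnp.array_equal(jnp.atleast_3d(z), z_manual))

print(same_y, same_z)
```

Calling atleast_3d is mainly a convenience when the rank of the input is not known in advance, since the manual forms above each assume a specific input rank.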