jax.numpy.atleast_3d
- jax.numpy.atleast_3d(*arys)
Convert inputs to arrays with at least 3 dimensions.
JAX implementation of numpy.atleast_3d().
- Parameters:
arys (ArrayLike) – zero or more arraylike arguments.
- Returns:
an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1, 1), 1D arrays of shape (N,) are converted to shape (1, N, 1), 2D arrays of shape (M, N) are converted to shape (M, N, 1), and arrays with three or more dimensions are returned unchanged.
- Return type:
Array | list[Array]
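The shape rule above can be written out as a small pure-Python function and checked against both jnp.atleast_3d and numpy.atleast_3d. This is a minimal sketch: expected_3d_shape is a hypothetical helper written only for illustration (it is not part of JAX), and the check assumes both jax and numpy are installed.

```python
import numpy as np
import jax.numpy as jnp


def expected_3d_shape(shape):
    """Hypothetical helper: output shape under the documented rule."""
    shape = tuple(shape)
    if len(shape) == 0:
        return (1, 1, 1)         # () -> (1, 1, 1)
    if len(shape) == 1:
        return (1, shape[0], 1)  # (N,) -> (1, N, 1)
    if len(shape) == 2:
        return shape + (1,)      # (M, N) -> (M, N, 1)
    return shape                 # 3 or more dimensions: unchanged


for shape in [(), (4,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
    out = jnp.atleast_3d(jnp.zeros(shape))
    # The JAX result should follow the rule and agree with NumPy.
    assert out.shape == expected_3d_shape(shape)
    assert out.shape == np.atleast_3d(np.zeros(shape)).shape
    print(shape, "->", out.shape)
```

Note the asymmetry for 1D inputs: the length-N axis ends up in the middle, with unit dimensions on both sides.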
Examples
Scalar arguments are converted to 3D, size-1 arrays:
>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)
>>> jnp.atleast_3d(x)
Array([[[1.]]], dtype=float32)
1D arrays have a unit dimension prepended and appended:
>>> y = jnp.arange(4)
>>> jnp.atleast_3d(y).shape
(1, 4, 1)
2D arrays have a unit dimension appended:
>>> z = jnp.ones((2, 3))
>>> jnp.atleast_3d(z).shape
(2, 3, 1)
Multiple arguments can be passed to the function at once, in which case a list of results is returned:
>>> x3, y3 = jnp.atleast_3d(x, y)
>>> print(x3)
[[[1.]]]
>>> print(y3)
[[[0]
  [1]
  [2]
  [3]]]
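Because multiple arguments come back as a list, the results can be unpacked and combined directly, for example to stack 1D arrays along a third axis. The sketch below assumes a hypothetical helper, dstack_sketch, that is not part of JAX; the existing jnp.dstack is used only as a cross-check. It is wrapped in jax.jit to show that atleast_3d can be traced, since its output shape depends only on the input shapes.

```python
import jax
import jax.numpy as jnp


@jax.jit
def dstack_sketch(a, b):
    """Hypothetical helper: stack two arrays along axis 2.

    With two arguments, jnp.atleast_3d returns a list of two arrays,
    each with at least 3 dimensions, which are then concatenated
    along the third axis.
    """
    a3, b3 = jnp.atleast_3d(a, b)
    return jnp.concatenate([a3, b3], axis=2)


y = jnp.arange(4)
w = jnp.arange(4, 8)
out = dstack_sketch(y, w)
print(out.shape)  # (1, 4, 2)

# Cross-check against the library's own depth-wise stacking function.
assert (out == jnp.dstack([y, w])).all()
```

The sketch takes exactly two arguments on purpose: with a single argument, jnp.atleast_3d returns an array rather than a list, so a general variadic version would need to handle that case separately.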