jax.numpy.cbrt

jax.numpy.cbrt(x, /)

Calculates the element-wise cube root of the input array.

JAX implementation of numpy.cbrt.
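A minimal sketch comparing `jnp.cbrt` with NumPy's `numpy.cbrt` on the same float32 data, and with the naive power `x ** (1 / 3)`. The sketch assumes, as with real-valued floating-point `pow`, that the naive power yields NaN for negative bases, whereas the real cube root of a negative number is itself negative. The input values are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# Mixed-sign float32 inputs, including perfect cubes.
x = jnp.array([-8.0, -1.0, 0.0, 2.0, 27.0], dtype=jnp.float32)

roots = jnp.cbrt(x)                 # real cube root; the sign of x is preserved
reference = np.cbrt(np.asarray(x))  # NumPy's element-wise cube root
naive = x ** (1 / 3)                # real power with a non-integer exponent

print(roots)
print(np.allclose(roots, reference, rtol=1e-6, atol=1e-6))
print(naive)  # negative entries become NaN here
```

The comparison with `x ** (1 / 3)` is the main practical reason to reach for `cbrt`: it is defined for the whole real line.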

Parameters:

x (ArrayLike) – input array or scalar. Complex dtypes are not supported.
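The sketch below illustrates the dtype behavior of the parameter. It assumes the default configuration (64-bit mode disabled), under which integer inputs are promoted to float32, and it assumes that a complex input surfaces as a `TypeError` raised by the underlying `lax.cbrt` primitive; the exact error message may differ between JAX versions.

```python
import jax.numpy as jnp

# Integer inputs are promoted to a floating dtype
# (float32 unless 64-bit mode is enabled via jax_enable_x64).
ints = jnp.array([27, -64, 1000])
out = jnp.cbrt(ints)
print(out.dtype)

# Complex inputs are rejected rather than mapped to a principal complex root.
z = jnp.array([1 + 1j, -8 + 0j])
try:
    jnp.cbrt(z)
    complex_rejected = False
except TypeError as err:  # assumed error type, see the lead-in
    complex_rejected = True
    print(f"rejected: {err}")
```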

Returns:

An array containing the cube root of the elements of x.

Return type:

Array
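A short sketch of what the returned `Array` looks like: a Python scalar produces a zero-dimensional array, array inputs keep their shape, and cubing the result recovers the input up to floating-point rounding. The tolerance `atol=1e-5` is an illustrative choice for float32, not a documented guarantee.

```python
import jax
import jax.numpy as jnp

# A Python scalar input produces a zero-dimensional Array.
s = jnp.cbrt(27.0)
print(isinstance(s, jax.Array), s.shape)

# Array inputs keep their shape; each element is replaced by its cube root.
x = jnp.arange(-4.0, 4.0).reshape(2, 4)
r = jnp.cbrt(x)
print(r.shape)

# Cubing the result recovers the input up to float32 rounding.
print(jnp.allclose(r ** 3, x, atol=1e-5))
```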

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[216, 125, 64],
...                [-27, -8, -1]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.cbrt(x)
Array([[ 6.,  5.,  4.],
       [-3., -2., -1.]], dtype=float32)
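Because `jnp.cbrt` is an ordinary JAX function, it can be composed with transformations such as `jax.jit`, `jax.grad`, and `jax.vmap`. The following sketch checks the gradient against the closed form d/dx x^(1/3) = 1 / (3 x^(2/3)), which is finite only for nonzero x, so the sample points deliberately avoid 0. The names `fast_cbrt` and `d_cbrt` are illustrative.

```python
import jax
import jax.numpy as jnp

# JIT-compiled cube root applied to a batch.
fast_cbrt = jax.jit(jnp.cbrt)
y = fast_cbrt(jnp.array([1.0, 8.0, 27.0]))

# Derivative of cbrt at x = 8: 1 / (3 * 8**(2/3)) = 1 / 12.
d_cbrt = jax.grad(jnp.cbrt)
slope = d_cbrt(8.0)

# Element-wise derivatives over a batch of nonzero inputs via vmap,
# compared with the closed form 1 / (3 * cbrt(x)**2).
xs = jnp.array([-27.0, -1.0, 1.0, 8.0])
slopes = jax.vmap(d_cbrt)(xs)
expected = 1.0 / (3.0 * jnp.cbrt(xs) ** 2)
print(slopes, expected)
```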