jax.numpy.cbrt
jax.numpy.cbrt(x, /)
Calculates the element-wise cube root of the input array.
JAX implementation of numpy.cbrt.

Parameters:
    x (ArrayLike) – input array or scalar. complex dtypes are not supported.
Returns:
    An array containing the cube root of the elements of x.
Return type:
    Array
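The returned Array always has a floating-point dtype, so integer input is promoted. A minimal sketch that checks this, assuming JAX's default configuration (64-bit mode disabled), under which the promoted dtype is float32; with 64-bit mode enabled the dtype may differ. The variable names are illustrative only.

```python
import jax.numpy as jnp

# Integer input: cbrt promotes it to an inexact (floating-point) dtype.
ints = jnp.array([8, 27, -64])
out = jnp.cbrt(ints)
print(out.dtype)  # float32 under the default (x64 disabled) configuration
print(out)        # approximately [ 2.  3. -4.]

# Floating-point input keeps its floating dtype.
halves = jnp.array([0.125, 1.0], dtype=jnp.float32)
print(jnp.cbrt(halves).dtype)
```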
See also
- jax.numpy.sqrt(): Calculates the element-wise non-negative square root of the input.
- jax.numpy.square(): Calculates the element-wise square of the input.
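Unlike jax.numpy.sqrt(), which returns the non-negative root and therefore gives nan for negative real inputs, the real cube root is defined for every real number and keeps the sign of its input. A short sketch contrasting the two on the same values, and checking that cubing the result recovers the input up to floating-point rounding (the tolerance used is an assumption, not a documented guarantee):

```python
import jax.numpy as jnp

x = jnp.array([-8.0, -1.0, 0.0, 27.0])

roots = jnp.cbrt(x)  # real cube roots; negative inputs give negative roots
sq = jnp.sqrt(x)     # nan where x < 0, since sqrt returns the non-negative root

# Cubing the cube root recovers the input, up to rounding error.
recovered = roots ** 3
print(roots)  # approximately [-2. -1.  0.  3.]
print(sq)     # nan for the two negative entries
print(jnp.allclose(recovered, x, atol=1e-4))
```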
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[216, 125, 64],
...                [-27, -8, -1]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.cbrt(x)
Array([[ 6.,  5.,  4.],
       [-3., -2., -1.]], dtype=float32)
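Since jax.numpy.cbrt is the JAX implementation of numpy.cbrt, its results on real inputs should agree with NumPy's within floating-point tolerance, and as a JAX function it composes with transformations such as jax.jit and jax.grad. The sketch below assumes float32 inputs and a relative tolerance of 1e-5 for the comparison; it also shows why a naive fractional power is not a drop-in replacement, since a negative base raised to a non-integer exponent yields nan.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = np.array([216.0, -27.0, 0.5, -0.001], dtype=np.float32)

# jnp.cbrt mirrors numpy.cbrt on real inputs.
print(np.allclose(jnp.cbrt(x), np.cbrt(x), rtol=1e-5))

# A naive fractional power gives nan for the negative entries.
naive = jnp.power(jnp.asarray(x), 1.0 / 3.0)
print(naive)

# cbrt composes with JAX transformations.
fast_cbrt = jax.jit(jnp.cbrt)
d_cbrt = jax.grad(jnp.cbrt)
print(fast_cbrt(jnp.float32(-8.0)))  # approximately -2.0
print(d_cbrt(8.0))  # derivative 1 / (3 * 8**(2/3)), approximately 1/12
```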