jax.numpy.square

jax.numpy.square(x)

Return the element-wise square of the input.
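A minimal sketch of basic usage follows. It assumes only the public `jax` and `jax.numpy` modules. It checks that the result equals `x * x` entry by entry. Because the function is a LAX-backend implementation, it also composes with `jax.jit` and `jax.grad` like other JAX operations.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])

# Element-wise square: each entry is multiplied by itself.
y = jnp.square(x)

# jnp.square is built from LAX operations, so it can be compiled with jax.jit.
y_jit = jax.jit(jnp.square)(x)

# It is also differentiable: the gradient of sum(v**2) with respect to v is 2*v.
dy = jax.grad(lambda v: jnp.sum(jnp.square(v)))(x)
```

`jax.grad` needs a scalar-valued function, so the sketch sums the squares before differentiating.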

LAX-backend implementation of numpy.square(). The original NumPy docstring follows.

square(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])

Parameters

x (array_like) – Input data.

Returns

out – Element-wise x*x, of the same shape and dtype as x. This is a scalar if x is a scalar.
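The following sketch checks that the output keeps the input's shape and dtype. The array values are arbitrary and chosen for illustration. One caveat, as we understand JAX's behavior: a scalar input produces a zero-dimensional JAX array rather than a Python scalar, so "scalar" here means an array with shape `()`.

```python
import jax.numpy as jnp

# A 2x3 integer array: [[0, 1, 2], [3, 4, 5]].
x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)
y = jnp.square(x)

# The output shape and dtype match the input.
print(y.shape, y.dtype)

# A Python scalar input yields a zero-dimensional JAX array.
s = jnp.square(3.0)
print(s.shape)
```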

Return type

ndarray or scalar

Examples

>>> np.square([-1j, 1])
array([-1.-0.j,  1.+0.j])
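The example above calls NumPy's `np.square`. A JAX counterpart is sketched below. As far as we know, `jax.numpy` functions generally reject Python lists and accept only arrays or scalars, so the sketch wraps the input in `jnp.array` first. The sign of the zero imaginary part may differ from NumPy's printout, so the sketch compares values rather than printed text.

```python
import jax.numpy as jnp

# Wrap the list in an array; passing a raw list may raise a TypeError in JAX.
z = jnp.array([-1j, 1])

# (-1j)**2 == -1 and 1**2 == 1, so both results are real-valued complex numbers.
w = jnp.square(z)
print(w)
```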