jax.numpy.multiply

jax.numpy.multiply = <jnp.ufunc 'multiply'>

Multiply two arrays element-wise.

JAX implementation of numpy.multiply. This is a universal function and supports the additional APIs described at jax.numpy.ufunc. This function also provides the implementation of the * operator for JAX arrays.

Parameters:
  • x – arrays to multiply. Must be broadcastable to a common shape.

  • y – arrays to multiply. Must be broadcastable to a common shape.

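  • args (ArrayLike) – the operands, passed positionally (x and y above).

  • out (None) – not supported; must be None.

  • where (None) – not supported; must be None.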
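The requirement that x and y be broadcastable to a common shape means their shapes need not match exactly. A short sketch, using hypothetical names col and row, in which a column of shape (3, 1) and a row of shape (4,) broadcast to a (3, 4) result:

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape(3, 1)  # shape (3, 1)
row = jnp.arange(4)                # shape (4,)

# (3, 1) and (4,) broadcast to a common shape of (3, 4),
# so grid[i, j] == col[i, 0] * row[j]
grid = jnp.multiply(col, row)

print(grid.shape)  # (3, 4)
```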

Returns:

Array containing the result of the element-wise multiplication.
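The dtype of the returned array follows JAX's type promotion rules. As an illustration, multiplying an integer array by a Python float scalar yields a floating-point array. This sketch checks only that the result is floating point rather than asserting a specific width, since the exact dtype depends on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp

ints = jnp.arange(3)              # integer array: 0, 1, 2
scaled = jnp.multiply(ints, 2.5)  # Python float scalar is weakly typed

# The integer operand is promoted, so the result is floating point
print(jnp.issubdtype(scaled.dtype, jnp.floating))  # True
```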

Return type:

Any

Examples

Calling multiply explicitly:

>>> import jax.numpy as jnp
>>> x = jnp.arange(4)
>>> jnp.multiply(x, 10)
Array([ 0, 10, 20, 30], dtype=int32)

Calling multiply via the * operator:

>>> x * 10
Array([ 0, 10, 20, 30], dtype=int32)
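Since multiply is an ordinary JAX operation, it should also compose with JAX transformations. The sketch below is illustrative rather than part of this API: the helper names scale, scale_jit, and dscale_dx are hypothetical. It wraps multiply in jax.jit and uses jax.grad to show that the derivative of x * y with respect to x is y.

```python
import jax
import jax.numpy as jnp

def scale(x, y):
    # Equivalent to x * y for JAX arrays
    return jnp.multiply(x, y)

# Compiled version; should match the eager * operator
scale_jit = jax.jit(scale)
x = jnp.arange(4.0)
compiled = scale_jit(x, 10.0)

# Gradient of a scalar product with respect to its first argument
dscale_dx = jax.grad(scale)
print(dscale_dx(3.0, 5.0))  # 5.0
```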