jax.numpy.product

jax.numpy.product(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)

Return the product of array elements over a given axis.

LAX-backend implementation of prod().

Original docstring below.

Parameters
  • a (array_like) – Input data.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which a product is performed. The default, axis=None, will calculate the product of all the elements in the input array. If axis is negative it counts from the last to the first axis.

  • dtype (dtype, optional) – The type of the returned array, as well as of the accumulator in which the elements are multiplied. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed, the platform integer is used; if a is unsigned, an unsigned integer of the same precision as the platform integer is used.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, then keepdims will not be passed through to the prod method of sub-classes of ndarray; any non-default value will be. If the sub-class's method does not implement keepdims, an exception will be raised.

  • initial (scalar, optional) – The starting value for this product. See numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to include in the product. See numpy.ufunc.reduce for details.
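Example

The parameters above can be illustrated with a minimal sketch. It calls jnp.prod, the function this page describes product as implementing, on the assumption that the product alias may not be available in every JAX release. The array a, the mask, and all variable names are illustrative choices, not part of the original docstring.

```python
import jax.numpy as jnp

# Illustrative 2x3 input; the values are arbitrary.
a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# axis=None (the default) multiplies every element together.
total = jnp.prod(a)                          # 720.0

# A single axis reduces along that axis only.
per_column = jnp.prod(a, axis=0)             # [4., 10., 18.]

# A negative axis counts from the last axis, so -1 is axis 1 here.
per_row = jnp.prod(a, axis=-1)               # [6., 120.]

# keepdims=True leaves the reduced axis in place with size one,
# so the result broadcasts against the input.
kept = jnp.prod(a, axis=1, keepdims=True)    # shape (2, 1)
normalized = a / kept                        # shape (2, 3)

# where selects the elements that take part in the product.
mask = a > 2.0
masked = jnp.prod(a, axis=1, where=mask)     # [3., 120.]

# initial is the starting value that each product is multiplied into.
scaled = jnp.prod(a, axis=1, initial=10.0)   # [60., 1200.]
```

Elements excluded by where behave as if they were 1, the multiplicative identity, which is why the first row of masked reduces to 3.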

Returns

product_along_axis – An array shaped as a but with the specified axis removed. Returns a reference to out if specified.

Return type

ndarray, see dtype parameter above.
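
Example

The dtype rule described above can be checked with a short sketch. It again uses jnp.prod and assumes a default JAX configuration; the exact width of the promoted integer depends on whether 64-bit mode is enabled, so the comments do not name a specific width. The array small is an illustrative input.

```python
import jax.numpy as jnp

# An int8 input whose product (1,000,000) does not fit in int8.
small = jnp.array([100, 100, 100], dtype=jnp.int8)

# With dtype left unset, a narrow signed integer input is accumulated
# in the wider default integer type, so the product does not wrap.
promoted = jnp.prod(small)
print(promoted.dtype, int(promoted))

# An explicit dtype sets both the accumulator and the result type.
as_float = jnp.prod(small, dtype=jnp.float32)
print(as_float.dtype, float(as_float))
```

Passing a narrow dtype explicitly, such as dtype=jnp.int8, keeps the accumulator narrow, so a product of this size would not be representable in the result.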