jax.numpy.cumproduct

jax.numpy.cumproduct(a, axis=None, dtype=None)

Return the cumulative product of elements along a given axis.
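A cumulative product is an inclusive prefix scan whose combining operation is multiplication. The sketch below writes one with `jax.lax.associative_scan` and compares it against `jnp.cumprod`. It is only an illustration and is not necessarily how JAX computes `cumprod` internally. The helper name `cumprod_via_scan` is hypothetical. The example calls `jnp.cumprod`, the function this page implements, rather than the `cumproduct` alias, because the alias may not exist in every JAX release.

```python
import jax
import jax.numpy as jnp


def cumprod_via_scan(x, axis=0):
    """Hypothetical helper: inclusive cumulative product of x along `axis`.

    Multiplication is associative, so the running product can be computed
    as a parallel prefix scan with jnp.multiply as the combining function.
    """
    return jax.lax.associative_scan(jnp.multiply, x, axis=axis)


x = jnp.array([1, 2, 3, 4])
m = jnp.array([[1, 2, 3], [4, 5, 6]])

# The scan-based sketch should agree with the library function.
print(cumprod_via_scan(x))          # running products 1, 1*2, 1*2*3, 1*2*3*4
print(jnp.cumprod(x))
print(cumprod_via_scan(m, axis=1))  # running products within each row
```

Because the combining function is associative, the scan can be evaluated in logarithmic depth instead of as a strictly sequential loop.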

LAX-backend implementation of numpy.cumprod(). The original NumPy docstring follows.

Parameters
  • a (array_like) – Input array.

  • axis (int, optional) – Axis along which the cumulative product is computed. By default the input is flattened.

  • dtype (dtype, optional) – Type of the returned array, as well as of the accumulator in which the elements are multiplied. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used instead.
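The following minimal sketch shows the two parameters in JAX. With the default `axis=None`, the input is flattened in row-major order before the running product is taken. `dtype` selects the output and accumulator type. It uses `jnp.cumprod` and assumes JAX's default 32-bit mode.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3], [4, 5, 6]])

# axis=None (the default): the input is flattened first, so the result
# is 1-D even though `a` is 2-D.
flat = jnp.cumprod(a)
# Equivalent: take the running product of the explicitly flattened array.
same = jnp.cumprod(a.ravel(), axis=0)

# dtype sets both the type of the result and the type of the accumulator.
as_float = jnp.cumprod(a, dtype=jnp.float32)

print(flat)            # running products 1, 2, 6, 24, 120, 720
print(as_float.dtype)  # float32
```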

Returns

cumprod – A new array holding the result. This signature has no out parameter, so the result is always returned as a new array.

Return type

ndarray
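The shape of the result depends on axis: it is 1-D with a.size elements when axis is None, and it matches the input's shape otherwise. The sketch below checks this with `jnp.cumprod`. It assumes a JAX version recent enough to expose `jax.Array`, the type JAX functions return in place of a NumPy ndarray.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3], [4, 5, 6]])

r_flat = jnp.cumprod(a)          # axis=None: 1-D result with a.size elements
r_rows = jnp.cumprod(a, axis=1)  # explicit axis: same shape as the input

print(r_flat.shape)  # (6,)
print(r_rows.shape)  # (2, 3)

# The result is a jax.Array; JAX arrays are immutable, so the input
# is never overwritten.
print(isinstance(r_rows, jax.Array))  # True
print(a.tolist())                     # [[1, 2, 3], [4, 5, 6]]
```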

See also

Output type determination (the ufuncs-output-type section of the NumPy ufunc documentation)

Notes

Arithmetic is modular when using integer types, and no error is raised on overflow.
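The following sketch demonstrates the modular behaviour. It passes `dtype=jnp.int32` explicitly so that the accumulator is pinned to 32 bits even if 64-bit mode is enabled. Passing a floating-point dtype is one way to avoid the wraparound when the magnitudes fit.

```python
import jax.numpy as jnp

x = jnp.array([2**16, 2**16, 3], dtype=jnp.int32)

# 2**16 * 2**16 == 2**32 does not fit in int32, so the product wraps
# modulo 2**32 to 0; every later running product is then 0 as well.
wrapped = jnp.cumprod(x, dtype=jnp.int32)

# Accumulating in float32 instead avoids the wraparound for these values
# (2**32 and 3 * 2**32 are exactly representable in float32).
widened = jnp.cumprod(x, dtype=jnp.float32)

print(wrapped)  # 65536, 0, 0 with no error or warning
print(widened)  # 65536, 2**32, 3 * 2**32
```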

Examples

>>> a = np.array([1,2,3])
>>> np.cumprod(a) # intermediate results 1, 1*2
...               # total product 1*2*3 = 6
array([1, 2, 6])
>>> a = np.array([[1, 2, 3], [4, 5, 6]])
>>> np.cumprod(a, dtype=float) # specify type of output
array([  1.,   2.,   6.,  24., 120., 720.])

The cumulative product for each column (i.e., over the rows) of a:

>>> np.cumprod(a, axis=0)
array([[ 1,  2,  3],
       [ 4, 10, 18]])

The cumulative product for each row (i.e., over the columns) of a:

>>> np.cumprod(a, axis=1)
array([[  1,   2,   6],
       [  4,  20, 120]])
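The same per-axis examples carry over directly to `jax.numpy`. The sketch below also checks two properties. A negative axis counts from the end. The last entry along the scanned axis equals the full product along that axis, as computed by `jnp.prod`.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3], [4, 5, 6]])

down_columns = jnp.cumprod(a, axis=0)     # running product over the rows
along_rows = jnp.cumprod(a, axis=1)       # running product over the columns
along_rows_neg = jnp.cumprod(a, axis=-1)  # axis=-1 is the last axis, here 1

print(down_columns.tolist())  # [[1, 2, 3], [4, 10, 18]]
print(along_rows.tolist())    # [[1, 2, 6], [4, 20, 120]]

# The final running product in each row is that row's full product.
print(along_rows[:, -1].tolist())    # [6, 120]
print(jnp.prod(a, axis=1).tolist())  # [6, 120]
```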