jax.Array.prod
- abstract Array.prod(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)
Return product of the array elements over a given axis.
Refer to jax.numpy.prod() for the full documentation.
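The following is a minimal usage sketch of the method form, which mirrors jax.numpy.prod(). It shows `axis`, `keepdims`, `where` (masked-out elements contribute the multiplicative identity 1), `initial`, and `promote_integers`. The array values are illustrative. Because JAX arrays are immutable, `out` appears to exist only for NumPy signature compatibility and should be left as `None`.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

total = x.prod()                       # product of all elements: 720.0
rows = x.prod(axis=1)                  # one product per row: [6., 120.]
cols = x.prod(axis=0, keepdims=True)   # reduced axis kept with size 1: shape (1, 3)

# The method is equivalent to calling the function form directly.
rows_fn = jnp.prod(x, axis=1)

# `where` excludes elements; excluded entries act as 1: 3*4*5*6 = 360.0
masked = x.prod(where=x > 2)

# `initial` is an extra starting factor multiplied into the result: 2*720 = 1440.0
scaled = x.prod(initial=2.0)

# Integer inputs: by default the result is promoted to the default integer
# type, so 100*2 = 200 does not overflow int8. With promote_integers=False
# the result keeps the input dtype (int8).
x8 = jnp.array([100, 2], dtype=jnp.int8)
promoted = x8.prod()
unpromoted = x8.prod(promote_integers=False)
```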