jax.Array.max
- abstract Array.max(axis=None, out=None, keepdims=False, initial=None, where=None)
Return the maximum of array elements along a given axis.
Refer to jax.numpy.max() for the full documentation.
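A minimal usage sketch, assuming a recent JAX release. Because the method defers to jax.numpy.max(), `x.max(...)` behaves like `jnp.max(x, ...)`. The array values are illustrative only, and `out` is left at its default of `None`. When a `where` mask is passed, an `initial` value is also supplied, since maximum has no identity element to fill masked-out positions.

```python
import jax.numpy as jnp

# Illustrative 2x3 integer array
x = jnp.array([[1, 5, 3],
               [4, 2, 6]])

# axis=None (the default): reduce over all elements -> scalar 6
total_max = x.max()

# axis=0: maximum down each column -> [4, 5, 6]
col_max = x.max(axis=0)

# axis=1 with keepdims=True: the reduced axis is kept with size 1 -> [[5], [6]]
row_max = x.max(axis=1, keepdims=True)

# where: consider only the odd elements (1, 5, 3); initial=-1 is the
# starting value for the reduction, so the result here is 5
odd_max = x.max(where=(x % 2 == 1), initial=-1)

print(total_max, col_max, row_max, odd_max)
```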