jax.numpy.trunc

jax.numpy.trunc(x)

Return the truncated value of the input, element-wise.

LAX-backend implementation of trunc(). Original docstring below.
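Because the wrapper is built on LAX, it can run eagerly or be compiled with `jax.jit`. The sketch below assumes a recent JAX release. `trunc_jit` is only an illustrative name. Both paths should produce the same values.

```python
import jax
import jax.numpy as jnp

# jnp.trunc is composed of LAX operations, so it can be traced and compiled.
trunc_jit = jax.jit(jnp.trunc)

x = jnp.array([-2.5, -0.5, 0.5, 2.5], dtype=jnp.float32)
eager = jnp.trunc(x)      # op-by-op execution
compiled = trunc_jit(x)   # XLA-compiled execution

print(eager)
print(compiled)
```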

trunc(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])

The truncated value of the scalar x is the nearest integer i that is no farther from zero than x is; that is, x is rounded toward zero. In short, the fractional part of the signed number x is discarded.
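Rounding toward zero means negative inputs round up and non-negative inputs round down. The sketch below expresses this with `jnp.where`, `jnp.ceil`, and `jnp.floor`. `trunc_toward_zero` is a hypothetical helper for illustration and is not necessarily how JAX implements `trunc`.

```python
import jax.numpy as jnp

def trunc_toward_zero(x):
    """Hypothetical helper: round each element of x toward zero."""
    x = jnp.asarray(x)
    # Negative values round up (ceil); non-negative values round down (floor).
    return jnp.where(x < 0, jnp.ceil(x), jnp.floor(x))

x = jnp.array([-3.9, -0.5, 0.0, 0.5, 3.9])
print(trunc_toward_zero(x))
print(jnp.trunc(x))  # should match the helper element-wise
```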

Parameters

x (array_like) – Input data.

Returns

y – The truncated value of each element in x. This is a scalar if x is a scalar.

Return type

ndarray or scalar
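A short check of the return type, assuming a JAX version that exposes `jax.Array`. Unlike NumPy, JAX returns a 0-d `jax.Array` for scalar input instead of a NumPy scalar. For array input, the output keeps the input's shape and floating dtype.

```python
import jax
import jax.numpy as jnp

# A 0-d input yields a 0-d jax.Array (JAX has no separate scalar result type).
s = jnp.trunc(jnp.asarray(2.7, dtype=jnp.float32))
print(isinstance(s, jax.Array), s.shape, s.dtype)

# Element-wise over a 2-D array: shape and dtype are preserved.
m = jnp.trunc(jnp.array([[1.9, -1.9], [0.5, -0.5]], dtype=jnp.float32))
print(m.shape, m.dtype)
```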

See also

ceil(), floor(), rint()
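The sketch below applies `trunc` and the related functions listed above to the example array from this page to show how they differ. It assumes only the standard `jax.numpy` rounding functions. Note that `rint` rounds halfway cases to the nearest even integer.

```python
import jax.numpy as jnp

a = jnp.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])

# Each function returns integral values, but rounds in a different direction.
results = {
    "trunc": jnp.trunc(a),  # toward zero
    "floor": jnp.floor(a),  # toward -inf
    "ceil": jnp.ceil(a),    # toward +inf
    "rint": jnp.rint(a),    # to nearest, ties to even
}
for name, values in results.items():
    print(f"{name:>5}: {values}")
```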

Notes

New in version 1.3.0.

Examples

>>> import numpy as np
>>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
>>> np.trunc(a)
array([-1., -1., -0.,  0.,  1.,  1.,  2.])
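The example above uses NumPy. A JAX counterpart, given here as a sketch, calls `jnp.trunc` on the same data and compares the result with `np.trunc`. It assumes JAX's default 32-bit mode, so the JAX result is float32 while NumPy's is float64. The values still match because every truncated value here is exactly representable in both types.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
expected = np.trunc(a)  # NumPy reference result (float64)
result = jnp.trunc(a)   # JAX result (float32 unless 64-bit mode is enabled)

print(result)
# Compare values only; the dtypes may differ.
print(np.array_equal(np.asarray(result), expected))
```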