jax.numpy.fft.fftn
- jax.numpy.fft.fftn(a, s=None, axes=None, norm=None)
Compute the N-dimensional discrete Fourier Transform.
LAX-backend implementation of numpy.fft.fftn(). Original docstring below.
This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional array by means of the Fast Fourier Transform (FFT).
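An N-dimensional DFT is separable: it equals a sequence of 1-D DFTs, one per transformed axis. The sketch below checks this for a small real array, and also compares against NumPy's reference implementation. The array shape and tolerances are arbitrary choices for illustration, assuming default float32 precision.

```python
import jax.numpy as jnp
import numpy as np

# A small real-valued 3-D array; any shape would do.
x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)

# N-dimensional FFT over all axes (s and axes left as None).
full = jnp.fft.fftn(x)

# The same result built from one 1-D FFT per axis, applied in turn.
stepwise = x
for ax in range(x.ndim):
    stepwise = jnp.fft.fft(stepwise, axis=ax)

# NumPy's fftn as a reference (computed in double precision).
reference = np.fft.fftn(np.asarray(x))

print(jnp.allclose(full, stepwise, atol=1e-3))
print(np.allclose(np.asarray(full), reference, atol=1e-3))
```

The loosened `atol` absorbs float32 round-off; with 64-bit mode enabled the results would agree more tightly.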
- Parameters:
a (array_like) – Input array, can be complex.
s (sequence of ints, optional) – Shape (length of each transformed axis) of the output (s[0] refers to axis 0, s[1] to axis 1, etc.). This corresponds to n for fft(x, n). Along any axis, if the given shape is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. If s is not given, the shape of the input along the axes specified by axes is used.
axes (sequence of ints, optional) – Axes over which to compute the FFT. If not given, the last len(s) axes are used, or all axes if s is also not specified. Repeated indices in axes mean that the transform over that axis is performed multiple times.
norm ({"backward", "ortho", "forward"}, optional) – Normalization mode. Default is "backward".
- Returns:
out – The truncated or zero-padded input, transformed along the axes indicated by axes, or by a combination of s and a, as explained in the parameters section above.
- Return type:
complex ndarray
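The result is complex even when the input is real. A minimal sketch, assuming the default 32-bit mode: a float32 input yields a complex64 output, and jnp.fft.ifftn (with the same norm) recovers the original data up to round-off, with an imaginary part close to zero.

```python
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 12, dtype=jnp.float32).reshape(3, 4)

X = jnp.fft.fftn(x)
print(X.dtype)  # complex64: a real float32 input gives a complex result

# ifftn inverts fftn; the imaginary part of the round trip is ~0.
x_back = jnp.fft.ifftn(X)
print(jnp.allclose(x_back.real, x, atol=1e-5))
print(jnp.max(jnp.abs(x_back.imag)) < 1e-5)
```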