jax.numpy.fft.fft2
- jax.numpy.fft.fft2(a, s=None, axes=(-2, -1), norm=None)
Compute the 2-dimensional discrete Fourier Transform.
LAX-backend implementation of numpy.fft.fft2(). Original docstring below.
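Because the function is a LAX-backed counterpart of numpy.fft.fft2(), one way to sanity-check it is to compare the two directly. This is a minimal sketch. The 3-by-4 test array and the 1e-4 tolerance are arbitrary choices, and the tolerance assumes JAX's default float32/complex64 precision, while NumPy computes in complex128.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A small real-valued test array (float32, JAX's default precision).
x = np.arange(12, dtype=np.float32).reshape(3, 4)

# jnp.fft.fft2 follows the numpy.fft.fft2 signature but lowers to XLA,
# so it can be jit-compiled like other jax.numpy functions.
fft2_jit = jax.jit(jnp.fft.fft2)
out_jax = fft2_jit(x)

# NumPy reference result (computed in complex128).
out_np = np.fft.fft2(x)

# The two agree up to float32 round-off.
print(out_jax.dtype)
print(np.allclose(out_jax, out_np, atol=1e-4))
```

Wrapping the call in jax.jit is optional; it only illustrates that the function traces like any other jax.numpy operation.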
This function computes the n-dimensional discrete Fourier Transform over any axes in an M-dimensional array by means of the Fast Fourier Transform (FFT). By default, the transform is computed over the last two axes of the input array, i.e., a 2-dimensional FFT.
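The claim that the default transform is the n-dimensional FFT restricted to the last two axes can be checked numerically. The sketch below uses an arbitrary random complex batch of shape (2, 4, 5) and also relies on the separability of the 2-D DFT: a 1-D FFT along axis -1 followed by a 1-D FFT along axis -2 gives the same result.

```python
import numpy as np
import jax.numpy as jnp

# A batch of two 4x5 complex "images"; fft2 transforms only the last two axes.
rng = np.random.default_rng(0)
a = jnp.asarray(
    rng.standard_normal((2, 4, 5)) + 1j * rng.standard_normal((2, 4, 5)),
    dtype=jnp.complex64,
)

out = jnp.fft.fft2(a)

# The 2-D DFT is separable: 1-D FFT along axis -1, then along axis -2.
separable = jnp.fft.fft(jnp.fft.fft(a, axis=-1), axis=-2)

# It is also the n-dimensional FFT restricted to the last two axes.
via_fftn = jnp.fft.fftn(a, axes=(-2, -1))

print(out.shape)  # (2, 4, 5): the leading batch axis is untouched
print(jnp.allclose(out, separable, atol=1e-4))
print(jnp.allclose(out, via_fftn, atol=1e-4))
```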
- Parameters:
a (array_like) – Input array, can be complex
s (sequence of ints, optional) – Shape (length of each transformed axis) of the output (s[0] refers to axis 0, s[1] to axis 1, etc.). This corresponds to n for fft(x, n). Along each axis, if the given shape is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. If s is not given, the shape of the input along the axes specified by axes is used.
axes (sequence of ints, optional) – Axes over which to compute the FFT. If not given, the last two axes are used. In NumPy, a repeated index in axes means the transform over that axis is performed multiple times, and a one-element sequence means that a one-dimensional FFT is performed. The JAX implementation instead requires exactly two distinct axes and raises a ValueError otherwise.
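norm ({"backward", "ortho", "forward"}, optional) – Normalization mode (see numpy.fft). Default is "backward". Indicates which direction of the forward/backward pair of transforms is scaled and with what normalization factor.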
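The sketch below exercises s, axes, and norm on small arrays chosen only for illustration. It assumes, as in NumPy, that s[i] gives the output length along axes[i], so with the default axes=(-2, -1) the entry s[0] applies to axis -2. For norm, N is taken to be the product of the transformed axis lengths: "backward" leaves the forward transform unscaled, "ortho" divides it by sqrt(N), and "forward" divides it by N.

```python
import jax.numpy as jnp

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

# s crops or zero-pads each transformed axis before the FFT:
# axis -2 is zero-padded from 3 to 5, axis -1 is cropped from 4 to 2.
y = jnp.fft.fft2(x, s=(5, 2))
print(y.shape)  # (5, 2)

# The same result via an explicit crop and zero-pad, then a plain fft2.
x_manual = jnp.pad(x[:, :2], ((0, 2), (0, 0)))
print(jnp.allclose(y, jnp.fft.fft2(x_manual), atol=1e-4))

# axes picks the two transformed axes; s[i] sets the length along axes[i].
b = jnp.ones((3, 4, 6), dtype=jnp.float32)
z = jnp.fft.fft2(b, s=(8, 2), axes=(0, 2))
print(z.shape)  # (8, 4, 2): axis 1 is not transformed

# norm scaling, with N = number of elements in one transformed plane.
n = x.shape[0] * x.shape[1]
ref = jnp.fft.fft2(x)  # default (None) behaves like "backward": no scaling
ortho = jnp.fft.fft2(x, norm="ortho")
forward = jnp.fft.fft2(x, norm="forward")
print(jnp.allclose(ortho, ref / n**0.5, atol=1e-4))
print(jnp.allclose(forward, ref / n, atol=1e-4))
```

Passing both s and axes explicitly, as in the middle case, avoids any ambiguity about which axis each length applies to.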
- Returns:
out – The truncated or zero-padded input, transformed along the axes indicated by axes, or the last two axes if axes is not given.
- Return type:
complex ndarray
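As the return type states, the output is complex even for real input; a float32 input yields complex64. This minimal sketch, with an arbitrary 2-by-3 input, also checks that the [0, 0] (DC) entry of the unnormalized transform equals the sum of the inputs, and that jnp.fft.ifft2 recovers the input up to float32 round-off.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 0.0],
               [0.0, -1.0, 3.0]], dtype=jnp.float32)

X = jnp.fft.fft2(x)
print(X.dtype)  # complex64: real float32 input is promoted to complex

# The [0, 0] (DC) entry of an unnormalized DFT is the sum of all inputs.
print(jnp.allclose(X[0, 0], x.sum(), atol=1e-5))

# ifft2 undoes fft2; the round trip matches x up to float32 round-off.
x_back = jnp.fft.ifft2(X)
print(jnp.allclose(x_back.real, x, atol=1e-5))
print(jnp.allclose(x_back.imag, 0.0, atol=1e-5))
```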