jax.scipy.signal.istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2)

Perform the inverse short-time Fourier transform (ISTFT).

JAX implementation of scipy.signal.istft(); computes the inverse of jax.scipy.signal.stft().

Parameters:

  • Zxx (Array) – STFT of the signal to be reconstructed.

  • fs (jax.typing.ArrayLike) – Sampling frequency of the time series (default: 1.0)

  • window (str) – Data tapering window to apply to each segment. Can be a window function name, a tuple giving a window function name and its parameters, or an array of window values (default: 'hann').

  • nperseg (int | None) – Number of data points per segment in the STFT. If None (default), the value is determined from the size of Zxx.

  • noverlap (int | None) – Number of points to overlap between segments (default: nperseg // 2).

  • nfft (int | None) – Number of FFT points used in the STFT. If None (default), the value is determined from the size of Zxx.

  • input_onesided (bool) – If True (default), interpret the input as a one-sided STFT (positive frequencies only). If False, interpret the input as a two-sided STFT.

  • boundary (bool) – If True (default), it is assumed that the input signal was extended at its boundaries by stft. If False, the input signal is assumed to have been truncated at the boundaries by stft.

  • time_axis (int) – Axis in Zxx corresponding to time segments (default: -1).

  • freq_axis (int) – Axis in Zxx corresponding to frequency bins (default: -2).
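The parameters above generally need to match the ones passed to stft() for the reconstruction to be exact. The following is a minimal sketch, assuming a 5 Hz sine sampled at 100 Hz and a Hann window with 75% overlap; the signal, the sampling rate, and the variable names are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp

# Illustrative test signal (assumption): 5 Hz sine sampled at 100 Hz.
fs = 100.0
n_samples = 256
x = jnp.sin(2 * jnp.pi * 5.0 * jnp.arange(n_samples) / fs)

# Segment settings chosen for illustration; the same values are passed
# to both the forward and the inverse transform.
nperseg, noverlap = 64, 48

f, seg_times, Zxx = jax.scipy.signal.stft(
    x, fs=fs, window='hann', nperseg=nperseg, noverlap=noverlap)

t, x_rec = jax.scipy.signal.istft(
    Zxx, fs=fs, window='hann', nperseg=nperseg, noverlap=noverlap)

# Compare only the first n_samples values, in case the inverse returns
# extra trailing samples introduced by padding in stft().
max_err = jnp.max(jnp.abs(x_rec[:n_samples] - x))
print(max_err)
```

If nperseg or noverlap passed to istft() differ from the values used in stft(), the overlap-add step no longer lines up with the original segmentation and the output should not be expected to match the input.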


Returns:

A length-2 tuple of arrays (t, x). t is the Array of signal times, and x is the reconstructed time series.
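As a small sketch of how the two returned arrays relate, the snippet below assumes a sampling frequency of 4.0 (an arbitrary illustrative value) and checks that t has one entry per reconstructed sample, spaced 1 / fs apart and starting at zero.

```python
import jax
import jax.numpy as jnp

# Illustrative sampling frequency (assumption): 4 samples per time unit.
fs = 4.0
x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.])

_, _, Zxx = jax.scipy.signal.stft(x, fs=fs, nperseg=4)
t, x_rec = jax.scipy.signal.istft(Zxx, fs=fs)

# t and x_rec are unpacked from the returned tuple, in that order.
print(t.shape, x_rec.shape)
print(t)
```

Changing fs only rescales the time values in t; the reconstructed samples themselves are not affected by it.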

Return type:

tuple[Array, Array]

See also

jax.scipy.signal.stft(): short-time Fourier transform.


Examples

Demonstrate that this gives the inverse of stft():

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.])
>>> f, t, Zxx = jax.scipy.signal.stft(x, nperseg=4)
>>> print(Zxx)  
[[ 1. +0.j   2.5+0.j   1. +0.j   1. +0.j   0.5+0.j ]
 [-0.5+0.5j -1.5+0.j  -0.5-0.5j -0.5+0.5j  0. -0.5j]
 [ 0. +0.j   0.5+0.j   0. +0.j   0. +0.j  -0.5+0.j ]]
>>> t, x_reconstructed = jax.scipy.signal.istft(Zxx)
>>> print(x_reconstructed)
[1. 2. 3. 2. 1. 0. 1. 2.]
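A two-sided STFT can be inverted in the same way by setting input_onesided=False. The sketch below assumes the forward transform was computed with return_onesided=False, and it takes the real part of the result on the assumption that the two-sided inverse may come back as a complex array even for real input.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.])

# Forward transform keeping both positive and negative frequencies.
_, _, Zxx_two = jax.scipy.signal.stft(x, nperseg=4, return_onesided=False)

# Inverse transform, told that the input is two-sided.
_, x_two = jax.scipy.signal.istft(Zxx_two, input_onesided=False)

# For a real input signal the imaginary part should be negligible.
x_two_real = jnp.real(x_two)
print(x_two_real)
```

Here the frequency axis of Zxx_two has nperseg entries rather than nperseg // 2 + 1, which is why the input_onesided flag has to agree with how the STFT was produced.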