jax.scipy.fft.dctn

jax.scipy.fft.dctn(x, type=2, s=None, axes=None, norm=None)

Computes the multidimensional discrete cosine transform (DCT) of the input.

JAX implementation of scipy.fft.dctn().
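
To make the definition concrete, the sketch below rebuilds the unnormalized (norm=None) DCT-II from its cosine-sum definition, y[k] = 2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N)), which is the convention used by scipy.fft.dct, and applies it along each axis of a 2-D array as two matrix products. The helper dct2_matrix is hypothetical and only for illustration; the reference is computed in float64 NumPy while dctn runs in float32, so the comparison uses a loose tolerance.

```python
import numpy as np
import jax
from jax.scipy.fft import dctn

def dct2_matrix(n):
    # Hypothetical helper: unnormalized DCT-II matrix,
    # C[k, j] = 2 * cos(pi * k * (2j + 1) / (2n)).
    k = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return 2.0 * np.cos(np.pi * k * (2 * j + 1) / (2 * n))

x = jax.random.normal(jax.random.key(0), (3, 4))
x_np = np.asarray(x, dtype=np.float64)

# Transform axis 0 (length 3) and axis 1 (length 4): Y = C3 @ X @ C4.T
reference = dct2_matrix(3) @ x_np @ dct2_matrix(4).T
result = np.asarray(dctn(x))

print(np.allclose(result, reference, rtol=1e-4, atol=1e-4))
```

The matrix form is only meant to show what is being computed; it costs O(N^2) per axis, whereas an FFT-based implementation is much cheaper for large inputs.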

Parameters:
  • x (Array) – input array.

  • type (int) – integer, default=2. The DCT type; currently only type 2 is supported.

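  • s (Sequence[int] | None) – integer or sequence of integers. Specifies the shape of the result along the transformed axes. If not specified, it defaults to the shape of x along the specified axes. Along each axis, the input is truncated if it is longer than the requested length and zero-padded if it is shorter.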

  • axes (Sequence[int] | None) – integer or sequence of integers. Specifies the axes along which the transform is computed. If not specified, the transform is computed along all axes.

  • norm (str | None) – string. The normalization mode. Currently only None (the default, no normalization) and "ortho" are supported.
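
As a sketch of what norm="ortho" changes, the check below rescales the norm=None output by the per-index factors f[k] = sqrt(1/(4N)) for k = 0 and sqrt(1/(2N)) otherwise (the scipy.fft.dct convention for the orthonormal DCT-II), applied along each transformed axis. It also confirms that the orthonormal result preserves the sum of squares. The helper ortho_scale is hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn

def ortho_scale(n):
    # Hypothetical helper: f[0] = sqrt(1/(4n)), f[k] = sqrt(1/(2n)) for k > 0.
    f = jnp.full((n,), jnp.sqrt(1.0 / (2 * n)))
    return f.at[0].set(jnp.sqrt(1.0 / (4 * n)))

x = jax.random.normal(jax.random.key(1), (4, 5))

y_none = dctn(x)                 # unnormalized DCT-II
y_ortho = dctn(x, norm="ortho")  # orthonormal DCT-II

# Scale axis 0 by ortho_scale(4) and axis 1 by ortho_scale(5).
rescaled = y_none * ortho_scale(4)[:, None] * ortho_scale(5)[None, :]
print(jnp.allclose(rescaled, y_ortho, rtol=1e-4, atol=1e-5))

# An orthonormal transform preserves energy (Parseval's relation).
print(jnp.allclose(jnp.sum(x ** 2), jnp.sum(y_ortho ** 2), rtol=1e-4))
```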

Returns:

array containing the discrete cosine transform of x

Return type:

Array


Example

By default (axes=None), jax.scipy.fft.dctn computes the transform along all axes of the input; for the (3, 3) array below, that means both axes.

>>> import jax
>>> import jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x))
[[-5.04 -7.54 -3.26]
 [ 0.83  3.64 -4.03]
 [ 0.12 -0.73  3.74]]
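
Because the multidimensional DCT is separable, the default all-axes transform should match applying the one-dimensional jax.scipy.fft.dct along axis 0 and then along axis 1. A minimal check of that equivalence (it illustrates the math, not necessarily how dctn is implemented internally):

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct, dctn

x = jax.random.normal(jax.random.key(0), (3, 3))

# 1-D DCT-II along axis 0, then along axis 1.
stepwise = dct(dct(x, axis=0), axis=1)
print(jnp.allclose(dctn(x), stepwise, rtol=1e-4, atol=1e-4))
```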

With s=[2] (and axes=None), the output has length 2 along axis 0, while its length along axis 1 stays the same as the input's. The transform is still computed along both axes.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2]))
[[-2.92 -2.68 -5.74]
 [ 0.42  0.97  1.  ]]
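
The s=[2] case above amounts to truncating axis 0 to its first two rows and then transforming both axes. A short sketch that checks this against an explicitly sliced input:

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn

x = jax.random.normal(jax.random.key(0), (3, 3))

truncated = dctn(x, s=[2])  # s[0] applies to axis 0 when axes is None
explicit = dctn(x[:2])      # slice the first two rows, then transform both axes
print(truncated.shape)      # (2, 3)
print(jnp.allclose(truncated, explicit, rtol=1e-4, atol=1e-4))
```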

With s=[2] and axes=[1], the transform is computed only along axis 1, and the output has length 2 along that axis; its length along axis 0 stays the same as the input's.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2], axes=[1]))
[[-0.22 -0.9 ]
 [-0.57 -1.68]
 [-2.52 -0.11]]
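
With a single axis, the call above reduces to a one-dimensional DCT. A minimal sketch comparing it with jax.scipy.fft.dct using n=2 along axis 1, and with an explicitly truncated input:

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct, dctn

x = jax.random.normal(jax.random.key(0), (3, 3))

only_axis1 = dctn(x, s=[2], axes=[1])
one_d = dct(x, n=2, axis=1)      # 1-D DCT of length 2 along axis 1
sliced = dct(x[:, :2], axis=1)   # same result with explicit truncation
print(only_axis1.shape)          # (3, 2)
print(jnp.allclose(only_axis1, one_d, rtol=1e-4, atol=1e-4))
print(jnp.allclose(only_axis1, sliced, rtol=1e-4, atol=1e-4))
```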

With s=[2, 4], the output has shape (2, 4): axis 0 is truncated from 3 to 2, and axis 1 is zero-padded from 3 to 4.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2, 4]))
[[-2.92 -2.49 -4.21 -5.57]
 [ 0.42  0.79  1.16  0.8 ]]
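
A minimal sketch reproducing the s=[2, 4] result by slicing the input and padding it with jnp.pad before calling dctn with its default arguments:

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn

x = jax.random.normal(jax.random.key(0), (3, 3))

resized = dctn(x, s=[2, 4])
# Keep the first two rows, then append one column of zeros to reach shape (2, 4).
manual = dctn(jnp.pad(x[:2], ((0, 0), (0, 1))))
print(resized.shape)  # (2, 4)
print(jnp.allclose(resized, manual, rtol=1e-4, atol=1e-4))
```

With the default norm=None, the [0, 0] coefficient of the 2-D DCT-II is 4 times the sum of the (resized) input. Zero-padding does not change that sum, which is why the [0, 0] entry (-2.92) matches the s=[2] example.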