# jax.scipy.fft.idctn

jax.scipy.fft.idctn(x, type=2, s=None, axes=None, norm=None)

Computes the multidimensional inverse discrete cosine transform (DCT) of the input.

JAX implementation of `scipy.fft.idctn()`.

Parameters:
• x (Array) – array

• type (int) – integer, default = 2. Currently only type 2 is supported.

• s (Sequence[int] | None) – integer or sequence of integers. Specifies the shape of the result. If not specified, it will default to the shape of `x` along the specified `axes`.

• axes (Sequence[int] | None) – integer or sequence of integers. Specifies the axes along which the transform will be computed.

• norm (str | None) – string. The normalization mode: one of `None`, `"backward"`, or `"ortho"`. The default `None` is equivalent to `"backward"`.
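To illustrate what `norm` controls, the sketch below pairs `jax.scipy.fft.dctn` with `idctn`, both called with `norm="ortho"`. It assumes a JAX release in which both functions accept `norm="ortho"`, and the random input is arbitrary. With orthonormal scaling each 1-D DCT-II is an orthogonal matrix, so the two calls should invert each other and the sum of squares of the input should be preserved (Parseval's identity).

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn, idctn

# Illustrative random input (any real array works).
x = jax.random.normal(jax.random.key(0), (3, 3))

# Assumption: dctn/idctn accept norm="ortho" (orthonormal scaling).
x_ortho = dctn(x, norm="ortho")
x_back = idctn(x_ortho, norm="ortho")

# The orthonormal pair should invert each other.
print(jnp.allclose(x, x_back, atol=1e-5))

# An orthonormal transform preserves the sum of squares (Parseval's identity).
print(jnp.allclose(jnp.sum(x**2), jnp.sum(x_ortho**2), rtol=1e-4))
```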

Returns:

An array containing the inverse discrete cosine transform of `x`.

Return type:

Array

Examples

By default (`axes=None`), `jax.scipy.fft.idctn` computes the transform along every axis of the input; for the 2-D array below, that means both axes.

```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x))
[[-0.03 -0.08 -0.08]
 [ 0.05  0.12 -0.09]
 [-0.02 -0.04  0.08]]
```
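The multidimensional inverse DCT is separable, so the default call can be read as the 1-D inverse DCT applied along each axis in turn. The sketch below checks this against `jax.scipy.fft.idct`, the 1-D counterpart with an `axis` argument. It illustrates the mathematics only and is not a claim about how `idctn` is implemented internally.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idct, idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

full = idctn(x)                    # axes=None: every axis
explicit = idctn(x, axes=[0, 1])   # the same axes, listed explicitly
# 1-D inverse DCT along axis 0, then along axis 1.
separable = idct(idct(x, axis=0), axis=1)
# For a separable transform the order of the axes should not matter.
reversed_order = idct(idct(x, axis=1), axis=0)

print(jnp.allclose(full, explicit, atol=1e-6))
print(jnp.allclose(full, separable, atol=1e-6))
print(jnp.allclose(full, reversed_order, atol=1e-6))
```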

When `s=[2]`, the result has size `2` along axis 0, and its size along axis 1 is the same as the input's.

```python
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x, s=[2]))
[[-0.01 -0.03 -0.14]
 [ 0.    0.03  0.06]]
```
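Here `s=[2]` is smaller than the input along axis 0, so only the first two rows should contribute. A minimal sketch, assuming `idctn` truncates an axis whenever the requested size is smaller than the input (as the `(2, 3)` output above suggests), compares the result with transforming the sliced array directly.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

y = idctn(x, s=[2])        # size 2 along axis 0
y_sliced = idctn(x[:2])    # transform only the first two rows

print(y.shape)  # (2, 3)
print(jnp.allclose(y, y_sliced, atol=1e-6))
```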

When `s=[2]` and `axes=[1]`, the transform is computed only along axis 1: the result has size `2` along axis 1, and its size along axis 0 is the same as the input's.

```python
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x, s=[2], axes=[1]))
[[ 0.   -0.19]
 [-0.03 -0.34]
 [-0.38  0.04]]
```
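With a single axis, `idctn` should reduce to the 1-D transform. The sketch below compares the example above with `jax.scipy.fft.idct(x, n=2, axis=1)` and with `idct` applied to the first two columns of `x`. It assumes that `idct` accepts `n` and `axis` arguments mirroring `scipy.fft.idct`, and that a smaller `n` truncates the input.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idct, idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

y = idctn(x, s=[2], axes=[1])
y_1d = idct(x, n=2, axis=1)        # 1-D transform of length 2 along axis 1
y_cols = idct(x[:, :2], axis=1)    # same, after keeping the first two columns

print(y.shape)  # (3, 2)
print(jnp.allclose(y, y_1d, atol=1e-6))
print(jnp.allclose(y, y_cols, atol=1e-6))
```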

When `s=[2, 4]`, the result has shape `(2, 4)`.

```python
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.idctn(x, s=[2, 4]))
[[-0.01 -0.01 -0.05 -0.11]
 [ 0.    0.01  0.03  0.04]]
```
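A requested size can also be larger than the input. This sketch assumes the usual convention that a shorter axis is truncated and a longer one is zero-padded before the transform. It cuts `x` to 2 rows, appends one column of zeros by hand with `jnp.pad`, and compares the result with `idctn(x, s=[2, 4])`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

# Keep the first 2 rows, then append one column of zeros to reach 4 columns.
x_resized = jnp.pad(x[:2], ((0, 0), (0, 1)))

y = idctn(x, s=[2, 4])
y_manual = idctn(x_resized)

print(x_resized.shape)  # (2, 4)
print(jnp.allclose(y, y_manual, atol=1e-6))
```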

`jax.scipy.fft.idctn` can be used to reconstruct `x` from the output of `jax.scipy.fft.dctn`:

```python
>>> x_dctn = jax.scipy.fft.dctn(x)
>>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
Array(True, dtype=bool)
```
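The reconstruction relies on both calls using matching `axes` and `norm`. The following sketch uses an arbitrary 3-D input chosen for illustration. It round-trips over only the last two axes, then checks the opposite composition, `dctn(idctn(x))`, which should also recover the input because the two transforms are inverses of each other.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn, idctn

# Arbitrary 3-D input for illustration.
x = jax.random.normal(jax.random.key(1), (2, 3, 4))

# Forward and inverse over the same subset of axes.
roundtrip_axes = idctn(dctn(x, axes=[1, 2]), axes=[1, 2])
# The composition in the opposite order should also recover x.
roundtrip_reversed = dctn(idctn(x))

print(jnp.allclose(x, roundtrip_axes, atol=1e-5))
print(jnp.allclose(x, roundtrip_reversed, atol=1e-5))
```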