# Source code for jax._src.lax.fft

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial
import math
from typing import Union, Sequence

import numpy as np

from jax import lax
from jax.interpreters import mlir
from jax.interpreters import xla

from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax._src.core import Primitive, is_constant_shape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact

__all__ = [
  "fft",
  "fft_p",
]

def _str_to_fft_type(s: str) -> xla_client.FftType:
  if s in ("fft", "FFT"):
    return xla_client.FftType.FFT
  elif s in ("ifft", "IFFT"):
    return xla_client.FftType.IFFT
  elif s in ("rfft", "RFFT"):
    return xla_client.FftType.RFFT
  elif s in ("irfft", "IRFFT"):
    return xla_client.FftType.IRFFT
  else:
    raise ValueError(f"Unknown FFT type '{s}'")

@partial(jit, static_argnums=(1, 2))
def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]):
  """Bind the ``fft_p`` primitive to ``x`` after normalising the arguments.

  ``fft_type`` and ``fft_lengths`` are static arguments of the enclosing
  ``jit`` (``static_argnums=(1, 2)``).

  Args:
    x: operand. For an RFFT it must not be complex and is passed through
      ``promote_dtypes_inexact``; for every other transform it is passed
      through ``promote_dtypes_complex``.
    fft_type: either an ``xla_client.FftType`` member or one of the strings
      accepted by ``_str_to_fft_type``.
    fft_lengths: lengths of the transformed axes. When empty, the promoted
      ``x`` is returned without binding the primitive.

  Returns:
    The result of ``fft_p.bind``, or the dtype-promoted ``x`` when
    ``fft_lengths`` is empty.

  Raises:
    TypeError: if ``fft_type`` is neither a ``str`` nor an
      ``xla_client.FftType``.
    ValueError: if ``fft_type`` is an unrecognised string, or if an RFFT is
      requested for a complex input.
  """
  if isinstance(fft_type, str):
    typ = _str_to_fft_type(fft_type)
  elif isinstance(fft_type, xla_client.FftType):
    typ = fft_type
  else:
    raise TypeError(f"Unknown FFT type value '{fft_type}'")

  if typ == xla_client.FftType.RFFT:
    if np.iscomplexobj(x):
      raise ValueError("only real valued inputs supported for rfft")
    x, = promote_dtypes_inexact(x)
  else:
    x, = promote_dtypes_complex(x)
  if len(fft_lengths) == 0:
    # XLA FFT doesn't support 0-rank.
    return x
  # The primitive's rules compare and concatenate fft_lengths with shape
  # tuples, so normalise any other sequence type to a tuple here.
  fft_lengths = tuple(fft_lengths)
  return fft_p.bind(x, fft_type=typ, fft_lengths=fft_lengths)

def _fft_impl(x, fft_type, fft_lengths):
  """Eager implementation of ``fft_p``: dispatch through ``xla.apply_primitive``."""
  return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths)


# dtype obtained by adding a zero of `dtype` to a complex64 zero (NumPy
# promotion picks the result dtype).
_complex_dtype = lambda dtype: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
# dtype reported by np.finfo for `dtype`.
_real_dtype = lambda dtype: np.finfo(dtype).dtype
_is_even = lambda x: x % 2 == 0


def fft_abstract_eval(x, fft_type, fft_lengths):
  """Abstract evaluation rule for ``fft_p``: validate shapes, compute the output aval.

  Args:
    x: abstract value of the operand (uses ``ndim``, ``shape``, ``dtype`` and
      ``update``).
    fft_type: an ``xla_client.FftType`` member.
    fft_lengths: lengths of the transformed trailing axes; compared against
      and concatenated with slices of ``x.shape``.

  Returns:
    ``x`` updated with the output shape and dtype:
      * RFFT: last axis becomes ``fft_lengths[-1] // 2 + 1``, dtype is
        ``_complex_dtype(x.dtype)``.
      * IRFFT: trailing axes become ``fft_lengths``, dtype is
        ``_real_dtype(x.dtype)``.
      * otherwise: shape and dtype are unchanged.

  Raises:
    ValueError: if ``x`` has fewer dimensions than ``fft_lengths`` or its
      trailing dimensions are inconsistent with ``fft_lengths``.
  """
  if len(fft_lengths) > x.ndim:
    raise ValueError(f"FFT input shape {x.shape} must have at least as many "
                     f"input dimensions as fft_lengths {fft_lengths}.")
  if fft_type == xla_client.FftType.RFFT:
    if x.shape[-len(fft_lengths):] != fft_lengths:
      raise ValueError(f"RFFT input shape {x.shape} minor dimensions must "
                       f"be equal to fft_lengths {fft_lengths}")
    shape = (x.shape[:-len(fft_lengths)] + fft_lengths[:-1]
             + (fft_lengths[-1] // 2 + 1,))
    dtype = _complex_dtype(x.dtype)
  elif fft_type == xla_client.FftType.IRFFT:
    # Only the leading transformed axes are checked; the last input axis is
    # not compared against fft_lengths[-1].
    if x.shape[-len(fft_lengths):-1] != fft_lengths[:-1]:
      raise ValueError(f"IRFFT input shape {x.shape} minor dimensions must "
                       "be equal to all except the last fft_length, got "
                       f"{fft_lengths=}")
    shape = x.shape[:-len(fft_lengths)] + fft_lengths
    dtype = _real_dtype(x.dtype)
  else:
    if x.shape[-len(fft_lengths):] != fft_lengths:
      raise ValueError(f"FFT input shape {x.shape} minor dimensions must "
                       f"be equal to fft_lengths {fft_lengths}")
    shape = x.shape
    dtype = x.dtype
  return x.update(shape=shape, dtype=dtype)


def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
  """Default MLIR lowering: emit a single ``hlo.FftOp``."""
  return [
      hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
                mlir.dense_int_elements(fft_lengths)).result
  ]


def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
  """CPU lowering: delegate to ``ducc_fft.ducc_fft_hlo``.

  Raises:
    NotImplementedError: if any input or output aval has a non-constant shape.
  """
  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
    raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
  x_aval, = ctx.avals_in
  return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type,
                                fft_lengths=fft_lengths)]


def _naive_rfft(x, fft_lengths):
  """RFFT expressed as a full FFT followed by a slice of the last axis.

  Keeps the first ``fft_lengths[-1] // 2 + 1`` entries of the last axis.
  """
  y = fft(x, xla_client.FftType.FFT, fft_lengths)
  n = fft_lengths[-1]
  return y[..., : n//2 + 1]


@partial(jit, static_argnums=1)
def _rfft_transpose(t, fft_lengths):
  """Transpose of RFFT applied to the cotangent ``t``."""
  # The transpose of RFFT can't be expressed only in terms of irfft. Instead of
  # manually building up larger twiddle matrices (which would increase the
  # asymptotic complexity and is also rather complicated), we rely on JAX to
  # transpose a naive RFFT implementation.
  dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
  dummy_primal = ShapeDtypeStruct(dummy_shape, _real_dtype(t.dtype))
  transpose = linear_transpose(
      partial(_naive_rfft, fft_lengths=fft_lengths), dummy_primal)
  result, = transpose(t)
  assert result.dtype == _real_dtype(t.dtype), (result.dtype, t.dtype)
  return result


def _irfft_transpose(t, fft_lengths):
  """Transpose of IRFFT applied to the cotangent ``t``."""
  # The transpose of IRFFT is the RFFT of the cotangent times a scaling
  # factor and a mask. The mask scales the cotangent for the Hermitian
  # symmetric components of the RFFT by a factor of two, since these components
  # are de-duplicated in the RFFT.
  x = fft(t, xla_client.FftType.RFFT, fft_lengths)
  n = x.shape[-1]
  is_odd = fft_lengths[-1] % 2
  full = partial(lax.full_like, t, dtype=x.dtype)
  # Mask along the last axis: 1 for the first entry, 2 for the interior
  # entries, and a trailing 1 that is present only when the last FFT length
  # is even.
  mask = lax.concatenate(
      [full(1.0, shape=(1,)),
       full(2.0, shape=(n - 2 + is_odd,)),
       full(1.0, shape=(1 - is_odd,))],
      dimension=0)
  scale = 1 / math.prod(fft_lengths)
  out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)


def _fft_transpose_rule(t, operand, fft_type, fft_lengths):
  """Transpose rule for ``fft_p``; returns a 1-tuple with the operand cotangent.

  RFFT and IRFFT use dedicated helpers; for any other type the same
  transform is applied to the cotangent ``t``.
  """
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,


def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  """Batching rule: move the batch axis to position 0 and apply ``fft``.

  Returns the transformed array together with its output batch dimension (0).
  """
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0


# Define the primitive and register its implementation, abstract evaluation,
# lowerings (default plus a CPU-specific one), transpose and batching rules.
fft_p = Primitive('fft')
fft_p.def_impl(_fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
mlir.register_lowering(fft_p, _fft_lowering)
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')