jax.scipy.special.sph_harm
- jax.scipy.special.sph_harm(m, n, theta, phi, n_max=None)
Computes the spherical harmonics.
Compared with scipy.special.sph_harm, the JAX version takes one extra argument, n_max, the maximum value in n.
The spherical harmonic of degree n and order m can be written as \(Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)\), where \(N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!} {4 \pi \left(n+m\right)!}}\) is the normalization factor, \(P_n^m\) is the associated Legendre function, and \(\phi\) and \(\theta\) are the colatitude and longitude, respectively. \(N_n^m\) is chosen so that the spherical harmonics form an orthonormal basis of \(L^2(S^2)\).
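As a rough illustration of the formula above, the sketch below evaluates the normalization factor through `gammaln` and assembles \(Y_1^0\) and \(Y_1^1\) by hand. It assumes the closed forms \(P_1^0(x) = x\) and \(P_1^1(x) = -\sqrt{1 - x^2}\), that is, the Condon-Shortley phase convention. The helper names `normalization`, `y_1_0`, and `y_1_1` are hypothetical and not part of JAX.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def normalization(n, m):
    """N_n^m = sqrt((2n+1) (n-m)! / (4 pi (n+m)!)), using log-gamma for the factorials."""
    log_ratio = gammaln(n - m + 1.0) - gammaln(n + m + 1.0)
    return jnp.sqrt((2.0 * n + 1.0) / (4.0 * jnp.pi) * jnp.exp(log_ratio))


def y_1_0(theta, phi):
    # P_1^0(cos(phi)) = cos(phi), and exp(i * 0 * theta) = 1.
    return normalization(1.0, 0.0) * jnp.cos(phi) + 0.0j


def y_1_1(theta, phi):
    # Assumed convention: P_1^1(cos(phi)) = -sin(phi) for phi in [0, pi].
    return normalization(1.0, 1.0) * (-jnp.sin(phi)) * jnp.exp(1j * theta)


theta = jnp.array(0.3)  # longitude
phi = jnp.array(0.7)    # colatitude
print(y_1_0(theta, phi), y_1_1(theta, phi))
```

Here `theta` only enters through the phase factor \(\exp(i m \theta)\), so \(Y_1^0\) is purely real and independent of the longitude.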
- Parameters:
m (Array) – The order of the harmonic; must have |m| <= n. Return values for |m| > n are undefined.
n (Array) – The degree of the harmonic; must have n >= 0. The standard notation for degree in descriptions of spherical harmonics is l (lower case L). We use n here to be consistent with scipy.special.sph_harm. Return values for n < 0 are undefined.
theta (Array) – The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
phi (Array) – The polar (colatitudinal) coordinate; must be in [0, pi].
n_max (Optional[int]) – The maximum degree max(n). If the supplied n_max is not the true maximum value of n, the degrees are clipped to n_max. For example, sph_harm(jnp.array([2]), jnp.array([10]), theta, phi, n_max=6) actually returns sph_harm(jnp.array([2]), jnp.array([6]), theta, phi, n_max=6).
- Return type:
Array
- Returns:
A 1D array containing the spherical harmonics at (m, n, theta, phi).
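A minimal usage sketch follows, assuming an installed JAX release that exposes `jax.scipy.special.sph_harm` with the signature documented here. It passes 1D arrays of equal length for every input, compares the degree-1, order-0 result with the closed form \(\sqrt{3 / (4\pi)} \cos \phi\), and then checks the n_max clipping behavior described in the parameter list. The sample angles are arbitrary.

```python
import jax.numpy as jnp
from jax.scipy.special import sph_harm

# One (m, n, theta, phi) tuple per array element; all inputs are 1D.
m = jnp.array([0, 1])
n = jnp.array([1, 1])
theta = jnp.array([0.3, 0.3])  # longitude, in [0, 2*pi]
phi = jnp.array([0.7, 0.7])    # colatitude, in [0, pi]

# n_max is a plain Python int; here it equals the true maximum of n.
y = sph_harm(m, n, theta, phi, n_max=1)

# Closed form for the first element: Y_1^0 = sqrt(3 / (4 pi)) * cos(phi).
expected_y10 = jnp.sqrt(3.0 / (4.0 * jnp.pi)) * jnp.cos(phi[0])
print(y[0], expected_y10)

# Clipping: a degree above n_max is evaluated as if it were n_max.
one_theta = jnp.array([0.3])
one_phi = jnp.array([0.7])
clipped = sph_harm(jnp.array([2]), jnp.array([10]), one_theta, one_phi, n_max=6)
direct = sph_harm(jnp.array([2]), jnp.array([6]), one_theta, one_phi, n_max=6)
print(jnp.allclose(clipped, direct))
```

Passing n_max explicitly keeps it a static value, which matters if the call is later wrapped in `jax.jit`.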