jax.scipy.stats.sem

jax.scipy.stats.sem(a, axis=0, ddof=1, nan_policy='propagate', *, keepdims=False)

Compute the standard error of the mean.

JAX implementation of `scipy.stats.sem()`.

Parameters:
• a (jax.typing.ArrayLike) – array-like input.

• axis (int | None) – optional integer, default=0. The axis along which the SEM is computed. If None, the input array is flattened.

• ddof (int) – integer, default=1. The delta degrees of freedom used in the SEM computation.

• nan_policy (str) – str, default="propagate". JAX supports only "propagate" and "omit".

• keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

Returns:

An array containing the standard error of the mean along the given axis.

Return type:

Array

Examples

```
>>> import jax.scipy.stats
>>> import jax.numpy as jnp
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x)
Array(0.41, dtype=float32)
```
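As a rough cross-check, the value above can be reproduced from the usual definition of the standard error of the mean: the sample standard deviation divided by the square root of the sample size. The sketch below assumes that definition rather than quoting the JAX source, and `manual_sem` is a hypothetical helper name used only for illustration.

```python
import jax.numpy as jnp
from jax.scipy.stats import sem

def manual_sem(a, ddof=1):
    # Assumed definition: std with `ddof` correction over sqrt(sample size).
    a = jnp.asarray(a, dtype=jnp.float32)
    return jnp.std(a, ddof=ddof) / jnp.sqrt(a.size)

x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
sem_default = sem(x)             # uses ddof=1
sem_manual = manual_sem(x)       # expected to agree with sem_default
sem_population = sem(x, ddof=0)  # no correction, so slightly smaller
```

With `ddof=0` the standard deviation divides by the full sample size instead of one less, so the resulting standard error is smaller than the default.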

For multidimensional arrays, `sem` computes the standard error of the mean along `axis=0` by default:

```
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1)
Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)
```

If `axis=1`, the standard error of the mean is computed along axis 1:

```
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=1)
Array([0.33, 0.4 , 0.31], dtype=float32)
```

If `axis=None`, the input is flattened and the standard error of the mean is computed over all of its elements:

```
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=None)
Array(0.2, dtype=float32)
```

By default, `sem` drops the reduced axis from the result. To keep it as an axis of size 1, so that the result has the same number of dimensions as the input array, set `keepdims=True`:

```
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=1, keepdims=True)
Array([[0.33],
       [0.4 ],
       [0.31]], dtype=float32)
```

Because `nan_policy='propagate'` is the default, any `nan` value in the input propagates to the corresponding entry of the result:

```
>>> nan = jnp.nan
>>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
...                 [4, 5, 4, 3, nan, 1],
...                 [7, nan, 8, 7, 9, nan]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x2)
Array([1.73,  nan, 1.53,  nan,  nan,  nan], dtype=float32)
```

If `nan_policy='omit'`, `sem` omits the `nan` values and computes the error for the remaining values along the specified axis:

```
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x2, nan_policy='omit')
Array([1.73, 1.5 , 1.53, 2.  , 2.5 , 0.5 ], dtype=float32)
```
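The `'omit'` result can be cross-checked by hand, assuming it amounts to a NaN-aware sample standard deviation divided by the square root of the number of non-NaN entries in each column. That equivalence is an assumption of this sketch, not a statement about how JAX implements the function.

```python
import jax.numpy as jnp
from jax.scipy.stats import sem

nan = jnp.nan
x2 = jnp.array([[1, 2, 3, nan, 4, 2],
                [4, 5, 4, 3, nan, 1],
                [7, nan, 8, 7, 9, nan]])

# Number of valid (non-NaN) observations in each column.
count = jnp.sum(~jnp.isnan(x2), axis=0)

# NaN-aware sample standard deviation over sqrt(valid count).
manual = jnp.nanstd(x2, axis=0, ddof=1) / jnp.sqrt(count)

result = sem(x2, nan_policy='omit')
```

Because the divisor is the per-column count of valid values, columns with fewer observations are not penalised by the missing entries.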