jax.numpy.nanvar
- jax.numpy.nanvar(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, where=None)
Compute the variance of array elements along a given axis, ignoring NaNs.
JAX implementation of numpy.nanvar().
- Parameters:
a (ArrayLike) – input array.
axis (Axis) – optional, int or sequence of ints, default=None. Axis along which the variance is computed. If None, the variance is computed over the flattened array.
dtype (DTypeLike | None) – optional, default=None. The dtype of the output array.
ddof (int) – int, default=0. Degrees of freedom. The divisor in the variance computation is N - ddof, where N is the number of non-NaN elements along the given axis (counting only elements selected by where, if given).
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
where (ArrayLike | None) – optional, boolean array, default=None. The elements to be used in the variance. The array must be broadcast-compatible with the input.
out (None) – Unused by JAX.
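To show how axis, ddof, keepdims, and where combine, here is a minimal sketch of NaN-ignoring variance. The helper nanvar_sketch is hypothetical and written only for illustration; it is not JAX's actual implementation, and it ignores dtype and out. It masks out NaNs (and anything excluded by where), counts the remaining N elements, and divides the summed squared deviations from the masked mean by N - ddof.

```python
import jax.numpy as jnp

def nanvar_sketch(a, axis=None, ddof=0, keepdims=False, where=None):
    """Hypothetical illustration of NaN-ignoring variance (not JAX's own code)."""
    a = jnp.asarray(a, dtype=jnp.float32)
    # Participating elements: not NaN and, if a mask is given, selected by it.
    valid = ~jnp.isnan(a)
    if where is not None:
        valid = valid & jnp.asarray(where, dtype=bool)
    # N: number of participating elements along the reduced axes.
    n = jnp.sum(valid, axis=axis, keepdims=True)
    # Mean of the participating elements (masked entries contribute 0).
    mean = jnp.sum(jnp.where(valid, a, 0.0), axis=axis, keepdims=True) / n
    # Squared deviations, again zeroing out masked entries.
    sq_dev = jnp.where(valid, (a - mean) ** 2, 0.0)
    var = jnp.sum(sq_dev, axis=axis, keepdims=True) / (n - ddof)
    if not keepdims:
        # axis=None squeezes every reduced (size-1) axis, giving a scalar.
        var = jnp.squeeze(var, axis=axis)
    return var

x = jnp.array([[1, jnp.nan, 4, 3],
               [jnp.nan, 2, jnp.nan, 9],
               [4, 8, 6, jnp.nan]])
# Should agree with jnp.nanvar(x, axis=1, ddof=1) up to float rounding.
print(nanvar_sketch(x, axis=1, ddof=1))
```

Computing the mean with keepdims=True lets it broadcast back against the input when forming deviations; the final squeeze then applies keepdims=False.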
- Returns:
An array containing the variance of array elements along the specified axis. If all elements along the given axis are NaN, returns nan.
- Return type:
Array
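As a quick illustration of the all-NaN case described above, the sketch below uses a small made-up array in which one row has no valid entries; that row's result is nan, while the other row is the ordinary variance of its non-NaN values.

```python
import jax.numpy as jnp

# Row 0 has no non-NaN entries; row 1 has two (1.0 and 3.0).
y = jnp.array([[jnp.nan, jnp.nan, jnp.nan],
               [1.0, jnp.nan, 3.0]])
row_var = jnp.nanvar(y, axis=1)
print(row_var)  # first entry nan, second 1.0 (variance of [1.0, 3.0])
```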
See also
- jax.numpy.nanmean(): Compute the mean of array elements over a given axis, ignoring NaNs.
- jax.numpy.nanstd(): Compute the standard deviation along a given axis, ignoring NaNs.
- jax.numpy.var(): Compute the variance of array elements along a given axis.
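A short sketch of how these related functions line up with nanvar, using the same example array as the Examples section: nanstd squared should match nanvar (up to float rounding), whereas var does not skip NaNs, so every row here, each of which contains a NaN, yields nan.

```python
import jax.numpy as jnp

x = jnp.array([[1, jnp.nan, 4, 3],
               [jnp.nan, 2, jnp.nan, 9],
               [4, 8, 6, jnp.nan]])

v = jnp.nanvar(x, axis=1)
# nanstd is the square root of nanvar for the same axis and ddof.
print(jnp.allclose(jnp.nanstd(x, axis=1) ** 2, v))  # True
# var propagates NaNs: every row contains one, so every entry is nan.
print(jnp.var(x, axis=1))
# nanmean gives the per-row mean of the non-NaN values that nanvar centers on.
print(jnp.nanmean(x, axis=1))
```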
Examples
By default, jnp.nanvar computes the variance over all axes, that is, over the flattened array:

```python
>>> import jax.numpy as jnp
>>> nan = jnp.nan
>>> x = jnp.array([[1, nan, 4, 3],
...                [nan, 2, nan, 9],
...                [4, 8, 6, nan]])
>>> jnp.nanvar(x)
Array(6.984375, dtype=float32)
```
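Since axis=None treats the array as flattened, the result above should equal the ordinary variance of the non-NaN values alone. The check below relies on boolean-mask indexing, which produces a data-dependent shape and therefore only works eagerly (not inside jax.jit).

```python
import jax.numpy as jnp

x = jnp.array([[1, jnp.nan, 4, 3],
               [jnp.nan, 2, jnp.nan, 9],
               [4, 8, 6, jnp.nan]])

# Keep only the non-NaN entries as a 1-D array (eager mode only).
finite_vals = x[~jnp.isnan(x)]
# Both values should agree up to float rounding.
print(jnp.nanvar(x), jnp.var(finite_vals))
```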
With axis=1, the variance is computed along axis 1, i.e. per row:

```python
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.nanvar(x, axis=1))
[ 1.56 12.25  2.67]
```
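For comparison, a sketch of reducing along the other axis and over a tuple of axes. With axis=0 each column has exactly one NaN, so each variance is over the remaining two values; for a 2-D array, axis=(0, 1) should match the default axis=None.

```python
import jax.numpy as jnp

x = jnp.array([[1, jnp.nan, 4, 3],
               [jnp.nan, 2, jnp.nan, 9],
               [4, 8, 6, jnp.nan]])

# Per-column variance; expected values: 2.25, 9, 1, 9.
col_var = jnp.nanvar(x, axis=0)
print(col_var)

# Reducing over both axes is the same as reducing over the flattened array.
print(jnp.nanvar(x, axis=(0, 1)), jnp.nanvar(x))
```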
To preserve the dimensions of the input, set keepdims=True:

```python
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.nanvar(x, axis=1, keepdims=True))
[[ 1.56]
 [12.25]
 [ 2.67]]
```
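A common reason to keep the reduced axis is broadcasting the statistic back against the input. The sketch below standardizes each row while ignoring NaNs; the variable names are illustrative only. Each row of the result should have a NaN-ignoring mean near 0 and variance near 1, and NaN entries stay NaN.

```python
import jax.numpy as jnp

x = jnp.array([[1, jnp.nan, 4, 3],
               [jnp.nan, 2, jnp.nan, 9],
               [4, 8, 6, jnp.nan]])

# keepdims=True gives shape (3, 1), which broadcasts against x of shape (3, 4).
mean = jnp.nanmean(x, axis=1, keepdims=True)
std = jnp.sqrt(jnp.nanvar(x, axis=1, keepdims=True))
z = (x - mean) / std  # NaN entries remain NaN
print(z.shape)  # (3, 4)
print(jnp.nanmean(z, axis=1), jnp.nanvar(z, axis=1))
```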
With ddof=1, the divisor becomes N - 1:

```python
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.nanvar(x, axis=1, keepdims=True, ddof=1))
[[ 2.33]
 [24.5 ]
 [ 4.  ]]
```
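To make the role of N concrete: N here is the count of non-NaN entries in each row ([3, 2, 3] for this array), so rescaling the ddof=0 result by N / (N - 1) should reproduce the ddof=1 result. This is a quick numerical check, not a description of how JAX computes it internally.

```python
import jax.numpy as jnp

x = jnp.array([[1, jnp.nan, 4, 3],
               [jnp.nan, 2, jnp.nan, 9],
               [4, 8, 6, jnp.nan]])

# N: number of non-NaN entries per row.
n = jnp.sum(~jnp.isnan(x), axis=1)
v0 = jnp.nanvar(x, axis=1)          # divisor N
v1 = jnp.nanvar(x, axis=1, ddof=1)  # divisor N - 1
print(jnp.allclose(v0 * n / (n - 1), v1))  # True
```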
To compute the variance over only selected elements, pass a boolean mask via where:

```python
>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 1, 0],
...                    [1, 1, 0, 1]], dtype=bool)
>>> jnp.nanvar(x, axis=1, keepdims=True, where=where)
Array([[2.25],
       [0.  ],
       [4.  ]], dtype=float32)
```
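Because nanvar already skips NaNs, excluding an element with where should have the same effect as replacing it with NaN first; the sketch below checks this on the example mask (renamed mask here to avoid shadowing the parameter name). Since the mask is an ordinary array argument rather than a data-dependent shape, the masked reduction should also run under jax.jit; the lambda wrapper is just for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, jnp.nan, 4, 3],
               [jnp.nan, 2, jnp.nan, 9],
               [4, 8, 6, jnp.nan]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 1, 1, 0],
                  [1, 1, 0, 1]], dtype=bool)

# Masking via `where` versus masking by substituting NaN.
via_where = jnp.nanvar(x, axis=1, where=mask)
via_nan = jnp.nanvar(jnp.where(mask, x, jnp.nan), axis=1)
print(jnp.allclose(via_where, via_nan))  # True

# The same masked reduction compiled with jax.jit.
masked_nanvar = jax.jit(lambda a, m: jnp.nanvar(a, axis=1, where=m))
print(masked_nanvar(x, mask))
```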