jax.random.multivariate_normal

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=<class 'numpy.float64'>, method='cholesky')

Sample multivariate normal random values with given mean and covariance.
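The following is a minimal usage sketch. The 2-D mean and covariance are arbitrary values chosen only for illustration. It draws many samples and checks that their empirical mean and covariance approach the inputs. The key is created with jax.random.PRNGKey; newer JAX versions also provide jax.random.key.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mean = jnp.array([1.0, -2.0])          # illustrative mean vector, n = 2
cov = jnp.array([[2.0, 0.5],           # illustrative positive definite covariance
                 [0.5, 1.0]])

# shape=(100_000,) is the batch shape, so the result has shape (100_000, 2).
samples = jax.random.multivariate_normal(key, mean, cov, shape=(100_000,))
print(samples.shape)  # (100000, 2)

# With this many samples, the empirical statistics should be close to the inputs.
print(samples.mean(axis=0))
print(jnp.cov(samples, rowvar=False))
```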

Parameters
  • key (Union[Any, PRNGKeyArray]) – a PRNG key used to generate the random values.

  • mean (Any) – a mean vector of shape (..., n).

  • cov (Any) – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.

  • shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.

  • dtype (Any) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • method (str) – optional, the method used to compute a factor of cov, that is, a matrix A with A Aᵀ = cov. Must be one of ‘svd’, ‘eigh’, or ‘cholesky’. Default ‘cholesky’.
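Any matrix A with A Aᵀ = cov can serve as the factor. If z is a vector of independent standard normals, then mean + A z has covariance A Aᵀ = cov. The sketch below hand-rolls the Cholesky variant of this idea with a hypothetical helper, sample_via_cholesky, which is not part of JAX. It then compares the helper's empirical covariance against the library's three methods. It makes no claim that the library produces bit-identical samples; only the distributions are expected to agree.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
mean = jnp.array([0.0, 1.0, -1.0])
cov = jnp.array([[4.0, 1.0, 0.5],      # illustrative positive definite matrix
                 [1.0, 3.0, 0.2],
                 [0.5, 0.2, 2.0]])
num = 200_000

def sample_via_cholesky(key, mean, cov, num):
    """Hypothetical helper: x = mean + L z with L = cholesky(cov), z ~ N(0, I)."""
    factor = jnp.linalg.cholesky(cov)                   # lower-triangular L, L @ L.T == cov
    z = jax.random.normal(key, (num, mean.shape[-1]))   # i.i.d. standard normal rows
    return mean + z @ factor.T                          # row i equals mean + L @ z_i

manual = sample_via_cholesky(key, mean, cov, num)

# Library samples using each supported factorization method.
builtin = {
    m: jax.random.multivariate_normal(key, mean, cov, shape=(num,), method=m)
    for m in ("cholesky", "eigh", "svd")
}

# All four sample sets should have an empirical covariance close to cov.
for name, x in [("manual", manual), *builtin.items()]:
    print(name, jnp.round(jnp.cov(x, rowvar=False), 2))
```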

Return type

ndarray

Returns

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
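The shape rule above can be checked directly. In this sketch the shapes are arbitrary and chosen only for illustration. mean has batch shape (4, 1) and cov has batch shape (5,), so with shape=None the result batch shape is broadcast_shapes((4, 1), (5,)) == (4, 5). An explicit shape of (2, 4, 5) is broadcast-compatible with both batch shapes and adds a leading axis.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
n = 3
mean = jnp.zeros((4, 1, n))                      # batch shape (4, 1)
cov = jnp.broadcast_to(jnp.eye(n), (5, n, n))    # batch shape (5,)

# shape=None: batch shape is broadcast_shapes((4, 1), (5,)) == (4, 5).
a = jax.random.multivariate_normal(key, mean, cov)
print(a.shape)  # (4, 5, 3)

# Explicit batch shape, broadcast-compatible with (4, 1) and (5,).
b = jax.random.multivariate_normal(key, mean, cov, shape=(2, 4, 5))
print(b.shape)  # (2, 4, 5, 3)
```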