jax.scipy.cluster.vq.vq


jax.scipy.cluster.vq.vq(obs, code_book, check_finite=True)

Assign codes from a code book to a set of observations.

JAX implementation of scipy.cluster.vq.vq().

Assigns each observation vector in obs to the code in code_book that is nearest to it in Euclidean distance.
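The rule can be sketched with plain `jax.numpy` broadcasting. First compute the distance from every observation to every code, then take the index of the smallest distance. The helper `nearest_code` is hypothetical and exists only for illustration. It is not JAX's implementation, and it materializes the full (M, K) distance matrix. In this sketch, exact ties resolve to the lowest code index because `jnp.argmin` returns the first occurrence of the minimum.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq


def nearest_code(obs, code_book):
    """Hypothetical helper: brute-force nearest-code assignment for 2-D inputs."""
    # diffs has shape (M, K, N): every observation minus every code vector.
    diffs = obs[:, None, :] - code_book[None, :, :]
    # Euclidean distance from each observation to each code: shape (M, K).
    all_dists = jnp.linalg.norm(diffs, axis=-1)
    # Index of the closest code for each observation: shape (M,).
    code = jnp.argmin(all_dists, axis=-1)
    # Distance to that closest code: shape (M,).
    dist = jnp.min(all_dists, axis=-1)
    return code, dist


obs = jnp.array([[1.1, 2.1, 3.1],
                 [5.9, 4.8, 6.2]])
code_book = jnp.array([[1., 2., 3.],
                       [2., 3., 4.],
                       [3., 4., 5.],
                       [4., 5., 6.]])

code, dist = nearest_code(obs, code_book)
ref_code, ref_dist = vq(obs, code_book)
print(code, ref_code)  # [0 3] [0 3]
```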

Parameters:
  • obs (jax.typing.ArrayLike) – array of observation vectors of shape (M, N). Each row represents a single observation. If obs is one-dimensional, then each entry is treated as a length-1 observation.

  • code_book (jax.typing.ArrayLike) – array of codes with shape (K, N). Each row represents a single code vector. If code_book is one-dimensional, then each entry is treated as a length-1 code.

  • check_finite (bool) – unused in JAX
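With one-dimensional inputs, each entry is a scalar observation or code. The Euclidean distance between length-1 vectors therefore reduces to an absolute difference. The input values below are illustrative and were chosen so that no observation is equidistant from two codes.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

# One-dimensional inputs: three scalar observations and three scalar codes.
obs_1d = jnp.array([1.0, 4.0, 9.0])
code_book_1d = jnp.array([0.0, 5.0, 10.0])

codes_1d, dists_1d = vq(obs_1d, code_book_1d)
print(codes_1d)  # [0 1 2]

# For length-1 observations the Euclidean distance is just |obs - code|.
expected = jnp.abs(obs_1d - code_book_1d[codes_1d])
```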

Returns:

A tuple of arrays (code, dist)

  • code is an integer array of shape (M,) containing, for each observation in obs, the index 0 <= i < K of the closest code in code_book.

  • dist is a float array of shape (M,) containing the Euclidean distance between each observation and its nearest code.
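Because code holds row indices into code_book, indexing code_book with it reconstructs the quantized observations, and dist should equal the norm of the residual between each observation and its assigned code. The following short consistency check uses the same arrays as the Examples section. The variable names `quantized` and `residual_norm` are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

obs = jnp.array([[1.1, 2.1, 3.1],
                 [5.9, 4.8, 6.2]])    # shape (M, N) = (2, 3)
code_book = jnp.array([[1., 2., 3.],
                       [2., 3., 4.],
                       [3., 4., 5.],
                       [4., 5., 6.]])  # shape (K, N) = (4, 3)

code, dist = vq(obs, code_book)
print(code.shape, dist.shape)  # (2,) (2,)

# `code` indexes rows of `code_book`, giving the quantized observations.
quantized = code_book[code]  # shape (M, N)

# `dist` should match the Euclidean norm of each observation's residual.
residual_norm = jnp.linalg.norm(obs - quantized, axis=1)
```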

Return type:

tuple[Array, Array]

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> obs = jnp.array([[1.1, 2.1, 3.1],
...                  [5.9, 4.8, 6.2]])
>>> code_book = jnp.array([[1., 2., 3.],
...                        [2., 3., 4.],
...                        [3., 4., 5.],
...                        [4., 5., 6.]])
>>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
>>> print(codes)
[0 3]
>>> print(distances)
[0.17320499 1.9209373 ]