- jax.lax.approx_min_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)
Returns the min k values and their indices of the operand in an approximate manner.
See https://arxiv.org/abs/2206.14286 for the algorithm details.
- Parameters
operand (Any) – Array to search for min-k. Must be a floating-point type.
k (int) – Number of min-k values to return.
reduction_dimension (int) – Integer dimension along which to search. Default: -1.
recall_target (float) – Recall target for the approximation.
reduction_input_size_override (int) – When set to a positive value, it overrides the size determined by operand[reduction_dim] for evaluating the recall. This option is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be inferred from the operand shape.
aggregate_to_topk (bool) – When true, aggregates approximate results to the top-k in sorted order. When false, returns the approximate results unsorted. In this case, the number of approximate results is implementation-defined and is greater than or equal to the specified k.
- Return type
Tuple of two arrays. The arrays are the least k values and the corresponding indices along the reduction_dimension of the input operand. The arrays' dimensions are the same as the input operand except for the reduction_dimension: when aggregate_to_topk is true, the reduction dimension is k; otherwise, it is greater than or equal to k, with an implementation-defined size.
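The shape rules above can be checked with a small sketch. The array sizes, the seed, and the variable names below are arbitrary choices for illustration and are not part of the API; the exact size returned with aggregate_to_topk=False depends on the backend, so the sketch only relies on it being at least k.

```python
import jax
import jax.numpy as jnp

k = 5
# Arbitrary float32 input: 4 rows, 256 candidates per row (illustrative sizes).
x = jax.random.uniform(jax.random.PRNGKey(0), (4, 256))

# Default: aggregate_to_topk=True, so the reduction dimension has size k.
vals, idx = jax.lax.approx_min_k(x, k=k)
print(vals.shape, idx.shape)

# The returned indices point back into the operand along the reduction dimension.
gathered = jnp.take_along_axis(x, idx, axis=-1)

# aggregate_to_topk=False: unsorted candidates, implementation-defined size >= k.
raw_vals, raw_idx = jax.lax.approx_min_k(x, k=k, aggregate_to_topk=False)
print(raw_vals.shape)
```

Only the reduction dimension changes size; the leading dimension (4 here) is preserved in both calls.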
We encourage users to wrap approx_min_k with jit. See the following example for nearest neighbor search over the squared L2 distance:
>>> import functools
>>> import jax
>>> import numpy as np
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
...   dists = half_db_norms - jax.lax.dot(qy, db.transpose())
...   return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
>>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)
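In the example above, we compute db^2/2 - dot(qy, db^T) instead of qy^2 - 2 dot(qy, db^T) + db^2 for performance reasons. The former uses less arithmetic and produces the same set of neighbors: the qy^2 term is constant for a given query, and dividing by 2 does not change the ordering of the distances.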
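As a rough check of this claim, the sketch below computes both forms, confirms that they differ only by a per-query constant after scaling, and compares the approximate neighbors against exact ones obtained with jax.lax.top_k. The variable names (half_form, full_form, recall) and the random data are illustrative assumptions, and the measured recall will vary with the backend and the seed.

```python
import jax
import jax.numpy as jnp

key_q, key_db = jax.random.split(jax.random.PRNGKey(0))
qy = jax.random.uniform(key_q, (50, 64))     # 50 queries
db = jax.random.uniform(key_db, (1024, 64))  # 1024 database points
k = 10

dots = qy @ db.T                  # (50, 1024) inner products
db_sq = jnp.sum(db * db, axis=1)  # (1024,) squared norms of database points
qy_sq = jnp.sum(qy * qy, axis=1)  # (50,) squared norms of queries

half_form = db_sq / 2 - dots                   # cheaper ranking score
full_form = qy_sq[:, None] - 2 * dots + db_sq  # full squared L2 distance

# full_form = 2 * half_form + qy_sq, so each row is shifted by a constant
# and scaled by a positive factor: the per-query ordering is unchanged.
offset = full_form - 2 * half_form

# Exact neighbors from the full distance (top_k of the negated distances).
_, exact_idx = jax.lax.top_k(-full_form, k)
# Approximate neighbors from the cheaper form.
_, approx_idx = jax.lax.approx_min_k(half_form, k=k, recall_target=0.95)

# Fraction of approximate neighbors that appear among the exact ones.
hits = (approx_idx[:, :, None] == exact_idx[:, None, :]).any(axis=-1)
recall = hits.mean()
print("recall:", float(recall))
```

The recall printed here is an empirical estimate on one random dataset; recall_target is an expected value, so a single run may land slightly above or below it.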