- jax.lax.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)
Returns the max k values and their indices of the operand in an approximate manner.
See https://arxiv.org/abs/2206.14286 for the algorithm details.
- Parameters
operand (Any) – Array to search for max-k. Must be a floating-point type.
k (int) – Number of largest values to return.
reduction_dimension (int) – Integer dimension along which to search. Default: -1.
recall_target (float) – Recall target for the approximation.
reduction_input_size_override (int) – When set to a positive value, it overrides the size determined by operand.shape[reduction_dimension] for evaluating the recall. This option is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be inferred from the operand shape.
aggregate_to_topk (bool) – When true, aggregates the approximate results to the top-k in sorted order. When false, returns the approximate results unsorted. In this case, the number of approximate results is implementation-defined and is greater than or equal to the specified k.
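The parameter constraints above can be illustrated with a small sketch. It assumes integer scores that must first be cast to a floating-point type, and uses `jnp.take_along_axis` (not part of the original text) only to check that each returned index points at its returned value. With the default `aggregate_to_topk=True`, the values come back in sorted order, largest first.

```python
import jax
import jax.numpy as jnp

# Integer scores must be cast first: approx_max_k requires a floating-point operand.
scores = jnp.arange(20, dtype=jnp.int32).reshape(2, 10)
operand = scores.astype(jnp.float32)

# With aggregate_to_topk=True (the default), the k results come back sorted.
values, indices = jax.lax.approx_max_k(operand, k=3, reduction_dimension=1)

# Each returned value should equal the operand entry at the returned index.
gathered = jnp.take_along_axis(operand, indices, axis=1)

# Values along the reduction dimension are non-increasing (largest first).
is_sorted = bool(jnp.all(values[:, :-1] >= values[:, 1:]))
print(values.shape, indices.shape, indices.dtype, is_sorted)
```

Because the result is approximate, the sketch deliberately checks structural properties (shape, sortedness, value/index agreement) rather than exact values.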
- Returns
Tuple of two arrays: the max k values and their corresponding indices along the reduction_dimension of the input operand. Both arrays have the same dimensions as operand except along reduction_dimension: when aggregate_to_topk is true, that dimension has size k; otherwise, its size is implementation-defined and greater than or equal to k.
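The shape rule above can be sketched as follows, assuming the default `aggregate_to_topk=True` and an arbitrary random input of shape `(4, 100)`. Only the reduced axis changes size; every other axis is preserved.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (4, 100))

# Reduce along the last axis: shape (4, 100) -> (4, k).
v_last, i_last = jax.lax.approx_max_k(x, k=5)

# Reduce along axis 0: shape (4, 100) -> (k, 100); k must not exceed x.shape[0].
v_first, i_first = jax.lax.approx_max_k(x, k=2, reduction_dimension=0)

print(v_last.shape, i_last.shape)    # (4, 5) (4, 5)
print(v_first.shape, i_first.shape)  # (2, 100) (2, 100)
```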
We encourage users to wrap approx_max_k with jit. The following example performs maximum inner product search (MIPS):
>>> import functools
>>> import jax
>>> import numpy as np
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def mips(qy, db, k=10, recall_target=0.95):
...     dists = jax.lax.dot(qy, db.transpose())
...     # returns (f32[qy_size, k], i32[qy_size, k])
...     return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> dot_products, neighbors = mips(qy, db, k=10)
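One way to see what recall_target refers to is to compare the approximate neighbors against an exact search. The sketch below assumes recall is measured as the fraction of exact top-k indices (from `jax.lax.top_k`) that also appear in the approximate result, averaged over queries. The `recall` helper is hypothetical, written for illustration only. The measured value may differ from recall_target, and on some backends the result may simply be exact.

```python
import jax
import jax.numpy as jnp

def recall(approx_idx, exact_idx):
    """Hypothetical helper: fraction of exact top-k indices found in the approximate result."""
    # hits[i, j] is True if exact_idx[i, j] appears anywhere in approx_idx[i].
    hits = (exact_idx[:, :, None] == approx_idx[:, None, :]).any(axis=-1)
    return hits.mean()

key_q, key_db = jax.random.split(jax.random.PRNGKey(0))
qy = jax.random.uniform(key_q, (50, 64))
db = jax.random.uniform(key_db, (1024, 64))
dists = qy @ db.T  # (50, 1024) inner products

_, approx_idx = jax.lax.approx_max_k(dists, k=10, recall_target=0.95)
_, exact_idx = jax.lax.top_k(dists, k=10)  # exact top-k along the last axis

print(float(recall(approx_idx, exact_idx)))
```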