jax.lax.approx_max_k
- jax.lax.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)
Returns the max k values of operand, and their indices, in an approximate manner. See https://arxiv.org/abs/2206.14286 for the algorithm details.
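A minimal sketch of a direct call, assuming a small float32 input on whichever backend JAX selects. The checks rely only on properties the description implies: each returned value is the operand entry at the returned index, and with the default aggregate_to_topk=True the k values come back sorted from largest to smallest. The array x and its shape are illustrative.

```python
import jax
import jax.numpy as jnp

# Illustrative input: 2 rows of 8 float32 scores each.
x = jnp.array([[0.1, 0.9, 0.3, 0.7, 0.5, 0.2, 0.8, 0.4],
               [0.6, 0.0, 0.95, 0.1, 0.3, 0.85, 0.2, 0.7]], dtype=jnp.float32)

# Approximate top-3 along the last axis (the default reduction_dimension=-1).
values, indices = jax.lax.approx_max_k(x, k=3)

# Each returned value should equal the operand entry at the returned index.
gathered = jnp.take_along_axis(x, indices, axis=-1)

print(values.shape, indices.shape)  # (2, 3) (2, 3)
print(bool(jnp.all(gathered == values)))
```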
- Parameters:
  - operand (Array) – Array to search for max-k. Must be a floating-point type.
  - k (int) – The number of largest values (and their indices) to return.
  - reduction_dimension (int) – Integer dimension along which to search. Default: -1.
  - recall_target (float) – Recall target for the approximation. Default: 0.95.
  - reduction_input_size_override (int) – When set to a positive value, it overrides the size given by operand.shape[reduction_dimension] when evaluating the recall. This option is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be inferred from the operand shape. Default: -1.
  - aggregate_to_topk (bool) – When true, aggregates the approximate results to the top-k in sorted order. When false, returns the approximate results unsorted; in this case, the number of approximate results is implementation-defined and is greater than or equal to the specified k. Default: True.
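The sketch below shows how the keyword parameters fit together, under stated assumptions: a (16, 4) operand searched along axis 0 instead of the default last axis, a looser recall_target, and reduction_input_size_override set as if this array were one shard of a hypothetical 64-row global input (the value 64 is purely illustrative and is kept no smaller than the local size of 16). The integer data is cast to float32 first because operand must be floating point.

```python
import jax
import jax.numpy as jnp

# Integer scores are cast first: operand has to be a floating-point array.
scores = jnp.arange(64, dtype=jnp.int32).reshape(16, 4).astype(jnp.float32)

values, indices = jax.lax.approx_max_k(
    scores,
    k=2,
    reduction_dimension=0,  # search down each column instead of along the last axis
    recall_target=0.9,  # accept a lower recall target than the 0.95 default
    # Hypothetical: treat this block as one shard of a 64-row global input,
    # so the recall is evaluated for 64 rows rather than scores.shape[0] (16).
    reduction_input_size_override=64,
)

# Only the reduction dimension shrinks to k; the other axis keeps its size.
print(values.shape, indices.shape)  # (2, 4) (2, 4)
```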
- Return type: tuple[Array, Array]
- Returns:
  Tuple of two arrays: the max k values and their corresponding indices along the reduction_dimension of the input operand. Both arrays have the same dimensions as operand except along reduction_dimension: when aggregate_to_topk is true, that dimension has size k; otherwise, its size is implementation-defined and greater than or equal to k.
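A small sketch of the shape rule, assuming a 3-D operand searched along its middle axis. With aggregate_to_topk=True that axis becomes exactly k; with aggregate_to_topk=False it is only guaranteed to be at least k, so the sketch checks for ">= k" rather than a fixed number. The operand shape (4, 4096, 8) and k=16 are illustrative choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 4096, 8), dtype=jnp.float32)
k = 16

# Aggregated, sorted result: the middle axis becomes exactly k.
v_sorted, i_sorted = jax.lax.approx_max_k(x, k, reduction_dimension=1)

# Unaggregated result: the middle axis has an implementation-defined size >= k.
v_raw, i_raw = jax.lax.approx_max_k(
    x, k, reduction_dimension=1, aggregate_to_topk=False)

print(v_sorted.shape)  # (4, 16, 8)
print(v_raw.shape[1] >= k, v_raw.shape[0], v_raw.shape[2])
```

The non-reduced axes (4 and 8 here) pass through unchanged in both cases; only the searched axis changes size.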
We encourage users to wrap approx_max_k with jit. The following example performs maximum inner product search (MIPS):

>>> import functools
>>> import jax
>>> import numpy as np
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def mips(qy, db, k=10, recall_target=0.95):
...   dists = jax.lax.dot(qy, db.transpose())
...   # returns (f32[qy_size, k], i32[qy_size, k])
...   return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> dot_products, neighbors = mips(qy, db, k=10)
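One way to see what recall_target trades away is to compare the approximate neighbors from the MIPS example above against an exact search. This sketch assumes recall means the fraction of the exact top-k indices that also appear in the approximate result, averaged over queries; the helper recall_at_k is hypothetical, and jax.lax.top_k serves as the exact baseline.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def mips(qy, db, k=10, recall_target=0.95):
  dists = jax.lax.dot(qy, db.transpose())
  return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)


def recall_at_k(approx_idx, exact_idx):
  # Hypothetical helper: for each exact neighbor, check whether it appears
  # anywhere in that query's approximate set, then average over everything.
  hits = (approx_idx[:, :, None] == exact_idx[:, None, :]).any(axis=1)
  return hits.mean()


rng = np.random.default_rng(0)
qy = jnp.asarray(rng.random((50, 64), dtype=np.float32))
db = jnp.asarray(rng.random((1024, 64), dtype=np.float32))

_, approx_neighbors = mips(qy, db, k=10)
# Exact baseline: lax.top_k always reduces along the last axis.
_, exact_neighbors = jax.lax.top_k(jnp.dot(qy, db.T), 10)

print(float(recall_at_k(approx_neighbors, exact_neighbors)))
```

Measuring recall this way on representative data is a reasonable check before lowering recall_target for speed.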