jax.numpy.ix_


jax.numpy.ix_(*args)

Construct an open mesh from multiple sequences.

LAX-backend implementation of numpy.ix_().
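Because it is built on LAX, jnp.ix_ can be traced. The minimal sketch below calls it inside jax.jit. Under tracing the index values are unknown, but their lengths are static, and that is all ix_ needs to fix the output shapes. The helper name take_block and the sample data are illustrative assumptions, not part of the JAX API.

```python
import jax
import jax.numpy as jnp

@jax.jit
def take_block(a, rows, cols):
    # Hypothetical helper: rows and cols are traced 1-D integer arrays whose
    # lengths are static, so jnp.ix_ can build index arrays of shape
    # (len(rows), 1) and (1, len(cols)) at trace time.
    return a[jnp.ix_(rows, cols)]

a = jnp.arange(24).reshape(4, 6)
block = take_block(a, jnp.array([0, 2]), jnp.array([1, 4]))
# block holds [[a[0, 1], a[0, 4]], [a[2, 1], a[2, 4]]]
```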

Original docstring below.

This function takes N 1-D sequences and returns N outputs with N dimensions each. The i-th output has size 1 in every dimension except dimension i, whose size is the length of the i-th input sequence. The non-unit dimension therefore cycles through all N dimensions across the outputs.
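The shape rule can be checked directly. This short sketch passes three 1-D integer arrays of lengths 2, 3 and 4 (arbitrary sample values) and prints each output's shape.

```python
import jax.numpy as jnp

# Three 1-D index sequences of lengths 2, 3 and 4.
outs = jnp.ix_(jnp.array([0, 1]), jnp.array([0, 2, 4]), jnp.array([1, 2, 3, 5]))

for i, o in enumerate(outs):
    # Output i has the input's length along axis i and size 1 everywhere else.
    print(i, o.shape)
# 0 (2, 1, 1)
# 1 (1, 3, 1)
# 2 (1, 1, 4)
```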

Using ix_, one can quickly construct index arrays that index the cross product. For example, a[np.ix_([1,3],[2,5])] returns the array [[a[1,2], a[1,5]], [a[3,2], a[3,5]]].
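The sketch below reproduces the cross-product example with jax.numpy and checks it against NumPy. The 4×6 array is an arbitrary choice. The indices are wrapped in jnp.array because jax.numpy functions generally reject plain Python lists as array arguments.

```python
import numpy as np
import jax.numpy as jnp

a_np = np.arange(24).reshape(4, 6)
a = jnp.asarray(a_np)

# jnp.ix_ builds index arrays of shapes (2, 1) and (1, 2); indexing with them
# selects every (row, col) pair in the cross product {1, 3} x {2, 5}.
sub = a[jnp.ix_(jnp.array([1, 3]), jnp.array([2, 5]))]

expected = np.array([[a_np[1, 2], a_np[1, 5]],
                     [a_np[3, 2], a_np[3, 5]]])
assert np.array_equal(np.asarray(sub), expected)
# Same result as NumPy's own np.ix_ on the NumPy array.
assert np.array_equal(np.asarray(sub), a_np[np.ix_([1, 3], [2, 5])])
```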

Parameters:

args (1-D sequences) – Each sequence should be of integer or boolean type. Boolean sequences will be interpreted as boolean masks for the corresponding dimension (equivalent to passing in np.nonzero(boolean_sequence)).
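The boolean-mask equivalence described above can be made explicit by converting each mask to integer positions before calling jnp.ix_. This sketch assumes you want to avoid relying on jnp.ix_ accepting boolean inputs, because JAX's handling of booleans here may differ from NumPy's. The masks and the 3×4 array are illustrative.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(12).reshape(3, 4)
row_mask = jnp.array([True, False, True])
col_mask = jnp.array([False, True, True, False])

# Convert each 1-D mask to integer positions. jnp.nonzero returns a tuple with
# one array per dimension, so take element [0] for a 1-D mask.
rows = jnp.nonzero(row_mask)[0]   # positions 0 and 2
cols = jnp.nonzero(col_mask)[0]   # positions 1 and 2

sub = a[jnp.ix_(rows, cols)]

# NumPy's ix_ accepts the boolean masks directly and makes the same selection.
reference = np.asarray(a)[np.ix_(np.asarray(row_mask), np.asarray(col_mask))]
assert np.array_equal(np.asarray(sub), reference)
```

Under jax.jit, jnp.nonzero has a data-dependent output length, so it needs a static size argument (for example jnp.nonzero(mask, size=k)) before it can be traced.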

Returns:

out – N arrays with N dimensions each, with N the number of input sequences. Together these arrays form an open mesh.
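The following sketch illustrates what "open mesh" means. Broadcasting the ix_ outputs against each other yields the same dense grid that jnp.meshgrid produces with indexing="ij". The open form alone is often enough for elementwise arithmetic over the whole grid. The sample arrays x and y are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([0, 1, 2])
y = jnp.array([10, 20])

rows, cols = jnp.ix_(x, y)          # shapes (3, 1) and (1, 2)

# Broadcasting the open mesh recovers the dense "ij"-indexed grid.
dense_rows, dense_cols = jnp.broadcast_arrays(rows, cols)
mesh_rows, mesh_cols = jnp.meshgrid(x, y, indexing="ij")
assert bool((dense_rows == mesh_rows).all()) and bool((dense_cols == mesh_cols).all())

# Elementwise arithmetic broadcasts the open mesh to the full (3, 2) grid.
table = rows + cols
```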

Return type:

tuple of ndarrays