jax.numpy.cross

jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)

Compute the (batched) cross product of two arrays.

JAX implementation of numpy.cross().

This computes the 2-dimensional or 3-dimensional cross product,

\[c = a \times b\]

In 3 dimensions, c is a length-3 array. In 2 dimensions, c is a scalar.
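Written out in components, with vector components indexed from 0, the two cases are:

\[
\begin{aligned}
c &= (a_1 b_2 - a_2 b_1,\ a_2 b_0 - a_0 b_2,\ a_0 b_1 - a_1 b_0) && \text{(3 dimensions)} \\
c &= a_0 b_1 - a_1 b_0 && \text{(2 dimensions)}
\end{aligned}
\]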

Parameters:
  • a – N-dimensional array. a.shape[axisa] indicates the dimension of the cross product, and must be 2 or 3.

  • b – N-dimensional array. Must have b.shape[axisb] == a.shape[axisa], and the other dimensions of a and b must be broadcast compatible.

  • axisa (int) – specify the axis of a along which to compute the cross product.

  • axisb (int) – specify the axis of b along which to compute the cross product.

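  • axisc (int) – specify the axis of c along which the cross product result will be stored. Ignored when both inputs have dimension 2, since the result is then a scalar for each batch element.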

  • axis (int | None) – if specified, this overrides axisa, axisb, and axisc with a single value.
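As a sketch of how axisa, axisb, and axisc interact, the following uses two small hypothetical inputs whose vectors lie along different axes: in a they are the columns (axis 0), in b the rows (axis 1). The basis vectors are chosen only so the results are easy to check by hand.

```python
import jax.numpy as jnp

# Hypothetical inputs: the 3-vectors of `a` lie along axis 0 (shape (3, 2)),
# while the 3-vectors of `b` lie along axis 1 (shape (2, 3)).
a = jnp.array([[1, 0],
               [0, 1],
               [0, 0]])     # columns are e_x and e_y
b = jnp.array([[0, 1, 0],
               [0, 0, 1]])  # rows are e_y and e_z

# Read vectors from axis 0 of `a` and axis 1 of `b`; store results along axis 0.
c = jnp.cross(a, b, axisa=0, axisb=1, axisc=0)
print(c.shape)  # (3, 2): each column of c is one result vector
print(c)        # columns: e_x x e_y = e_z, e_y x e_z = e_x
```

Passing axisc=-1 instead would lay the same two result vectors out as rows, giving shape (2, 3).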

Returns:

The array c containing the (batched) cross product of a and b along the specified axes.
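A minimal sketch of the broadcasting requirement on b and of the shape of the returned c; the shapes used here are arbitrary choices for illustration. A batch of vectors is paired with a single vector, and the non-vector dimensions broadcast as usual:

```python
import jax.numpy as jnp

# A batch of four 3-vectors and a single 3-vector: the non-vector
# dimensions, (4,) and (), broadcast together to (4,).
a = jnp.arange(12.0).reshape(4, 3)
b = jnp.array([0.0, 0.0, 1.0])  # unit z vector
c = jnp.cross(a, b)             # each row (x, y, z) maps to (y, -x, 0)
print(c.shape)  # (4, 3)

# With 2-vectors the vector axis disappears: one scalar per batch element.
p = jnp.ones((5, 2))
q = jnp.ones((5, 2))
print(jnp.cross(p, q).shape)  # (5,)
```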

Examples

A 2-dimensional cross product returns a scalar:

>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)
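One way to read the 2-dimensional case: the scalar equals the z-component of the 3-dimensional cross product after padding both inputs with a zero third component. A small check, redefining the same two vectors so the snippet runs on its own:

```python
import jax.numpy as jnp

a = jnp.array([1, 2])
b = jnp.array([3, 4])

# Embed the 2-vectors in 3-D by appending a zero z-component.
a3 = jnp.append(a, 0)
b3 = jnp.append(b, 0)

c2 = jnp.cross(a, b)    # scalar result of the 2-D cross product
c3 = jnp.cross(a3, b3)  # 3-vector (0, 0, a0*b1 - a1*b0)
print(c2, c3)
```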

A 3-dimensional cross product returns a length-3 vector:

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)

With multi-dimensional inputs, the cross product is computed along the last axis by default. Here’s a batched 3-dimensional cross product, operating on the rows of the inputs:

>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)
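The batching here is the same as mapping the unbatched cross product over the rows. As an illustrative comparison (not a description of how jnp.cross is implemented internally), jax.vmap gives the same result:

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [3, 4, 3]])
b = jnp.array([[2, 3, 2],
               [4, 5, 6]])

# vmap maps the single-vector (3,) x (3,) cross product over the
# leading axis of both inputs.
row_by_row = jax.vmap(jnp.cross)(a, b)
print(jnp.array_equal(row_by_row, jnp.cross(a, b)))  # True
```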

Specifying axis=0 makes this a batched 2-dimensional cross product, operating on the columns of the inputs:

>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)

Equivalently, we can independently specify the axis of the inputs a and b and the output c:

>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)
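Another way to see what axis=0 does: it is equivalent to transposing both inputs so the 2-vectors lie along the last axis and then using the default. A short check with the same inputs:

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [3, 4, 3]])
b = jnp.array([[2, 3, 2],
               [4, 5, 6]])

# axis=0 reads the 2-vectors from the columns; transposing first and using
# the default last axis gives the same batched result.
via_axis = jnp.cross(a, b, axis=0)
via_transpose = jnp.cross(a.T, b.T)
print(via_axis)                                  # [-2 -2 12]
print(jnp.array_equal(via_axis, via_transpose))  # True
```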