jax.numpy.linalg.cross

jax.numpy.linalg.cross(x1, x2, /, *, axis=-1)

Compute the cross-product of two 3D vectors.

JAX implementation of numpy.linalg.cross()

Parameters:
  • x1 (jax.typing.ArrayLike) – N-dimensional array, with x1.shape[axis] == 3.

  • x2 (jax.typing.ArrayLike) – N-dimensional array, with x2.shape[axis] == 3, and other axes broadcast-compatible with x1.

  • axis – axis along which to take the cross product (default: -1).

Returns:

array containing the result of the cross-product

See also

jax.numpy.cross(): more flexible cross-product API.
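
To illustrate what "more flexible" can mean here, the following is a minimal sketch contrasting the two functions on length-2 inputs. It assumes that jax.numpy.cross() accepts 2-element vectors and returns the scalar z-component, and that jax.numpy.linalg.cross() raises a ValueError when the size along axis is not 3; the names u, v, z and raised are illustrative only.

```python
import jax.numpy as jnp

u = jnp.array([1., 2.])
v = jnp.array([3., 4.])

# Assumption: jnp.cross accepts length-2 vectors and returns the scalar
# z-component u[0]*v[1] - u[1]*v[0].
z = jnp.cross(u, v)

# Assumption: jnp.linalg.cross rejects inputs whose size along `axis`
# is not 3 by raising ValueError.
try:
    jnp.linalg.cross(u, v)
    raised = False
except ValueError:
    raised = True
```

If these assumptions hold for the installed JAX version, z is the scalar 1*4 - 2*3 and raised is True.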

Example

Showing that \(\hat{x} \times \hat{y} = \hat{z}\):

>>> x = jnp.array([1., 0., 0.])
>>> y = jnp.array([0., 1., 0.])
>>> jnp.linalg.cross(x, y)
Array([0., 0., 1.], dtype=float32)

Cross product of \(\hat{x}\) with all three standard unit vectors, via broadcasting:

>>> xyz = jnp.eye(3)
>>> jnp.linalg.cross(x, xyz, axis=-1)
Array([[ 0.,  0.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.]], dtype=float32)
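
The axis parameter can also point at a non-trailing dimension. The sketch below stores the vectors column-wise, so the three components lie along axis 0; the arrays a, b and c are illustrative names, not part of the API.

```python
import jax.numpy as jnp

# Vectors stored as columns: shape (3, 2), components along axis 0.
a = jnp.array([[1., 0.],
               [0., 1.],
               [0., 0.]])  # columns: x-hat, y-hat
b = jnp.array([[0., 0.],
               [1., 0.],
               [0., 1.]])  # columns: y-hat, z-hat

# Cross product taken column by column; the result keeps shape (3, 2).
c = jnp.linalg.cross(a, b, axis=0)

# Same computation with the vectors moved to the default last axis.
c_rows = jnp.linalg.cross(a.T, b.T)
```

Here the columns of c are z-hat and x-hat, and c_rows is expected to equal c.T.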