# Source code for jax._src.scipy.spatial.transform

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import re
import typing

import scipy.spatial.transform

import jax
import jax.numpy as jnp
from jax._src.numpy.util import implements


@implements(scipy.spatial.transform.Rotation)
class Rotation(typing.NamedTuple):
  """Rotation in 3 dimensions."""

  # Scalar-last quaternion(s) ``[x, y, z, w]``: shape (4,) for a single
  # rotation (see ``single``), otherwise stacked along a leading axis.
  quat: jax.Array

  @classmethod
  def concatenate(cls, rotations: typing.Sequence):
    """Concatenate a sequence of `Rotation` objects.

    Single rotations (``quat`` of shape ``(4,)``) are promoted to a stack of
    one before joining, so they are stacked rather than flattened into one
    longer vector.
    """
    return cls(jnp.concatenate(
        [jnp.atleast_2d(rotation.quat) for rotation in rotations]))

  @classmethod
  def from_euler(cls, seq: str, angles: jax.Array, degrees: bool = False):
    """Initialize from Euler angles.

    Args:
      seq: 1 to 3 axis characters. Uppercase (e.g. ``'XYZ'``) selects
        intrinsic rotations, lowercase (e.g. ``'xyz'``) extrinsic ones.
      angles: rotation angle(s), one per axis in ``seq``.
      degrees: if True, ``angles`` are in degrees instead of radians.

    Raises:
      ValueError: if ``seq`` is empty or longer than 3, mixes upper and lower
        case, contains characters other than x/y/z, or repeats an axis in
        consecutive positions.
    """
    num_axes = len(seq)
    if num_axes < 1 or num_axes > 3:
      raise ValueError("Expected axis specification to be a non-empty "
                       "string of upto 3 characters, got {}".format(seq))

    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from ['x', 'y', "
                       "'z'] or ['X', 'Y', 'Z'], got {}".format(seq))

    if any(seq[i] == seq[i+1] for i in range(num_axes - 1)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))

    angles = jnp.atleast_1d(angles)
    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    return cls(_elementary_quat_compose(angles, axes, intrinsic, degrees))

  @classmethod
  def from_matrix(cls, matrix: jax.Array):
    """Initialize from rotation matrix."""
    return cls(_from_matrix(matrix))

  @classmethod
  def from_mrp(cls, mrp: jax.Array):
    """Initialize from Modified Rodrigues Parameters (MRPs)."""
    return cls(_from_mrp(mrp))

  @classmethod
  def from_quat(cls, quat: jax.Array):
    """Initialize from quaternions (normalized to unit length)."""
    return cls(_normalize_quaternion(quat))

  @classmethod
  def from_rotvec(cls, rotvec: jax.Array, degrees: bool = False):
    """Initialize from rotation vectors."""
    return cls(_from_rotvec(rotvec, degrees))

  @classmethod
  def identity(cls, num: int | None = None, dtype=float):
    """Get identity rotation(s).

    Args:
      num: number of identity rotations. ``None`` (the default) returns a
        single rotation whose ``quat`` has shape ``(4,)``. An integer returns
        a stack whose ``quat`` has shape ``(num, 4)``.
      dtype: dtype of the quaternion array.
    """
    quat = jnp.array([0., 0., 0., 1.], dtype=dtype)
    if num is not None:
      quat = jnp.tile(quat, (num, 1))
    return cls(quat)

  @classmethod
  def random(cls, random_key: jax.Array, num: int | None = None):
    """Generate uniformly distributed rotations."""
    # Need to implement scipy.stats.special_ortho_group for this to work...
    raise NotImplementedError

  def __getitem__(self, indexer):
    """Extract rotation(s) at given index(es) from object.

    Raises:
      TypeError: if this is a single rotation.
    """
    if self.single:
      raise TypeError("Single rotation is not subscriptable.")
    return Rotation(self.quat[indexer])

  def __len__(self):
    """Number of rotations contained in this object.

    Raises:
      TypeError: if this is a single rotation.
    """
    if self.single:
      raise TypeError('Single rotation has no len().')
    return self.quat.shape[0]

  def __mul__(self, other):
    """Compose this rotation with the other (``other`` is applied first)."""
    return Rotation.from_quat(_compose_quat(self.quat, other.quat))

  def apply(self, vectors: jax.Array, inverse: bool = False) -> jax.Array:
    """Apply this rotation (or its inverse, if ``inverse``) to vector(s)."""
    return _apply(self.as_matrix(), vectors, inverse)

  def as_euler(self, seq: str, degrees: bool = False):
    """Represent as Euler angles.

    Args:
      seq: exactly 3 axis characters; uppercase for intrinsic, lowercase for
        extrinsic rotations.
      degrees: if True, return angles in degrees instead of radians.

    Raises:
      ValueError: if ``seq`` is not 3 characters, mixes cases, contains
        characters other than x/y/z, or repeats an axis consecutively.
    """
    if len(seq) != 3:
      raise ValueError(f"Expected 3 axes, got {seq}.")

    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from "
                       "['x', 'y', 'z'] or ['X', 'Y', 'Z'], "
                       "got {}".format(seq))

    if any(seq[i] == seq[i+1] for i in range(2)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))

    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    return _compute_euler_from_quat(self.quat, axes, extrinsic, degrees)

  def as_matrix(self) -> jax.Array:
    """Represent as rotation matrix."""
    return _as_matrix(self.quat)

  def as_mrp(self) -> jax.Array:
    """Represent as Modified Rodrigues Parameters (MRPs)."""
    return _as_mrp(self.quat)

  def as_rotvec(self, degrees: bool = False) -> jax.Array:
    """Represent as rotation vectors."""
    return _as_rotvec(self.quat, degrees)

  def as_quat(self) -> jax.Array:
    """Represent as quaternions (the stored ``quat`` array itself)."""
    return self.quat

  def inv(self):
    """Invert this rotation."""
    return Rotation(_inv(self.quat))

  def magnitude(self) -> jax.Array:
    """Get the magnitude(s) of the rotation(s), in radians."""
    return _magnitude(self.quat)

  def mean(self, weights: jax.Array | None = None):
    """Get the (optionally weighted) mean of the rotations.

    The mean is the eigenvector belonging to the largest eigenvalue of
    ``sum_i w_i q_i q_i^T``. ``eigh`` returns eigenvalues in ascending
    order, so that eigenvector is the last column.

    Raises:
      ValueError: if ``weights`` is not 1-D or its length differs from the
        number of rotations.
      TypeError: if this is a single rotation (raised by ``len(self)``).
    """
    if weights is None:
      w = jnp.ones(self.quat.shape[0], dtype=self.quat.dtype)
    else:
      w = jnp.asarray(weights, dtype=self.quat.dtype)
    if w.ndim != 1:
      raise ValueError("Expected `weights` to be 1 dimensional, got "
                       "shape {}.".format(w.shape))
    if w.shape[0] != len(self):
      raise ValueError("Expected `weights` to have number of values "
                       "equal to number of rotations, got "
                       "{} values and {} rotations.".format(w.shape[0], len(self)))
    K = jnp.dot(w[jnp.newaxis, :] * self.quat.T, self.quat)
    _, v = jnp.linalg.eigh(K)
    return Rotation(v[:, -1])

  @property
  def single(self) -> bool:
    """Whether this instance represents a single rotation."""
    return self.quat.ndim == 1
@implements(scipy.spatial.transform.Slerp)
class Slerp(typing.NamedTuple):
  """Spherical Linear Interpolation of Rotations."""

  # Key-frame timestamps, 1-D with one entry per input rotation.
  times: jnp.ndarray
  # Differences between consecutive key-frame timestamps (one fewer entry).
  timedelta: jnp.ndarray
  # All input rotations except the last: the start rotation of each interval.
  rotations: Rotation
  # Per interval, the rotation vector taking its start rotation to its end.
  rotvecs: jnp.ndarray

  @classmethod
  def init(cls, times: jax.Array, rotations: Rotation):
    """Build an interpolator from key-frame ``times`` and ``rotations``.

    Raises:
      TypeError: if ``rotations`` is not a ``Rotation``.
      ValueError: if fewer than two rotations are given, if ``times`` is not
        1-D, or if the number of times differs from the number of rotations.
    """
    if not isinstance(rotations, Rotation):
      raise TypeError("`rotations` must be a `Rotation` instance.")
    if rotations.single or len(rotations) == 1:
      raise ValueError("`rotations` must be a sequence of at least 2 rotations.")

    key_times = jnp.asarray(times, dtype=rotations.quat.dtype)
    if key_times.ndim != 1:
      raise ValueError("Expected times to be specified in a 1 "
                       "dimensional array, got {} "
                       "dimensions.".format(key_times.ndim))
    num_rotations = len(rotations)
    if key_times.shape[0] != num_rotations:
      raise ValueError("Expected number of rotations to be equal to "
                       "number of timestamps given, got {} rotations "
                       "and {} timestamps.".format(num_rotations, key_times.shape[0]))

    # Strictly increasing times are NOT validated: checking would need a
    # concrete boolean, which raises a concretization error under tracing.
    gaps = jnp.diff(key_times)
    all_quats = rotations.as_quat()
    interval_starts = Rotation(all_quats[:-1])
    interval_ends = Rotation(all_quats[1:])
    steps = (interval_starts.inv() * interval_ends).as_rotvec()
    return cls(times=key_times, timedelta=gaps,
               rotations=interval_starts, rotvecs=steps)

  def __call__(self, times: jax.Array):
    """Interpolate rotations at ``times`` (a scalar or a 1-D array).

    Returns a single rotation for scalar input, otherwise a stack. Queries
    outside ``[times[0], times[-1]]`` are not rejected here.

    Raises:
      ValueError: if ``times`` has more than one dimension.
    """
    query = jnp.asarray(times, dtype=self.times.dtype)
    if query.ndim > 1:
      raise ValueError("`times` must be at most 1-dimensional.")
    is_scalar = query.ndim == 0
    query = jnp.atleast_1d(query)

    # searchsorted(...) - 1 picks the last key frame strictly before each
    # query; clamping at 0 maps queries equal to times[0] to the first interval.
    interval = jnp.maximum(jnp.searchsorted(self.times, query) - 1, 0)
    fraction = (query - self.times[interval]) / self.timedelta[interval]
    start = self.rotations[interval]
    partial_step = Rotation.from_rotvec(self.rotvecs[interval] * fraction[:, None])
    interpolated = start * partial_step
    return interpolated[0] if is_scalar else interpolated
@functools.partial(jnp.vectorize, signature='(m,m),(m),()->(m)')
def _apply(matrix: jax.Array, vector: jax.Array, inverse: bool) -> jax.Array:
  """Multiply ``vector`` by ``matrix``, or by ``matrix.T`` when ``inverse``."""
  # The transpose of a rotation matrix is the inverse rotation.
  return jnp.where(inverse, matrix.T, matrix) @ vector


@functools.partial(jnp.vectorize, signature='(m)->(n,n)')
def _as_matrix(quat: jax.Array) -> jax.Array:
  """Convert a scalar-last quaternion ``[x, y, z, w]`` to a 3x3 matrix.

  The result is a proper rotation matrix only for unit-norm input; no
  normalization is done here.
  """
  x = quat[0]
  y = quat[1]
  z = quat[2]
  w = quat[3]

  x2 = x * x
  y2 = y * y
  z2 = z * z
  w2 = w * w

  xy = x * y
  zw = z * w
  xz = x * z
  yw = y * w
  yz = y * z
  xw = x * w

  return jnp.array([[+ x2 - y2 - z2 + w2, 2 * (xy - zw), 2 * (xz + yw)],
                    [2 * (xy + zw), - x2 + y2 - z2 + w2, 2 * (yz - xw)],
                    [2 * (xz - yw), 2 * (yz + xw), - x2 - y2 + z2 + w2]])


@functools.partial(jnp.vectorize, signature='(m)->(n)')
def _as_mrp(quat: jax.Array) -> jax.Array:
  """Convert a scalar-last quaternion to Modified Rodrigues Parameters."""
  # Use the representative with w >= 0 so that the denominator is >= 1.
  sign = jnp.where(quat[3] < 0, -1., 1.)
  denominator = 1. + sign * quat[3]
  return sign * quat[:3] / denominator


@functools.partial(jnp.vectorize, signature='(m),()->(n)')
def _as_rotvec(quat: jax.Array, degrees: bool) -> jax.Array:
  """Convert a unit quaternion to a rotation vector (axis times angle)."""
  # Flip to w >= 0 so that 0 <= angle <= pi.
  quat = jnp.where(quat[3] < 0, -quat, quat)
  angle = 2. * jnp.arctan2(_vector_norm(quat[:3]), quat[3])
  angle2 = angle * angle
  # Taylor series of angle / sin(angle / 2), used near 0 where the exact
  # expression evaluates to 0 / 0.
  small_scale = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
  large_scale = angle / jnp.sin(angle / 2)
  scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
  scale = jnp.where(degrees, jnp.rad2deg(scale), scale)
  return scale * quat[:3]


@functools.partial(jnp.vectorize, signature='(n),(n)->(n)')
def _compose_quat(p: jax.Array, q: jax.Array) -> jax.Array:
  """Hamilton product ``p * q`` of scalar-last quaternions.

  As rotations, the result applies ``q`` first and then ``p``.
  """
  cross = jnp.cross(p[:3], q[:3])
  return jnp.array([p[3]*q[0] + q[3]*p[0] + cross[0],
                    p[3]*q[1] + q[3]*p[1] + cross[1],
                    p[3]*q[2] + q[3]*p[2] + cross[2],
                    p[3]*q[3] - p[0]*q[0] - p[1]*q[1] - p[2]*q[2]])


@functools.partial(jnp.vectorize, signature='(m),(l),(),()->(n)')
def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool,
                             degrees: bool) -> jax.Array:
  """Convert a quaternion to three Euler angles about ``axes``.

  Follows the same algorithm as SciPy's ``Rotation.as_euler``.

  Args:
    quat: scalar-last quaternion ``[x, y, z, w]``.
    axes: three axis indices (0=x, 1=y, 2=z), in sequence order.
    extrinsic: True for extrinsic (fixed-axes) sequences, False for intrinsic.
    degrees: if True, return degrees instead of radians.

  Returns:
    Three angles. The first and third lie in ``[-pi, pi]``. In the gimbal-lock
    cases (middle angle ~0 or ~pi) the third angle is set to 0.
  """
  # The algorithm is formulated for extrinsic sequences. An intrinsic
  # sequence is handled by reversing the axes and swapping angles 0 and 2.
  angle_first = jnp.where(extrinsic, 0, 2)
  angle_third = jnp.where(extrinsic, 2, 0)
  axes = jnp.where(extrinsic, axes, axes[::-1])
  i = axes[0]
  j = axes[1]
  k = axes[2]
  # Proper Euler sequences (e.g. 'zxz') use the remaining axis as k.
  symmetric = i == k
  k = jnp.where(symmetric, 3 - i - j, k)
  # +1 for an even permutation of (i, j, k), -1 for an odd one.
  sign = jnp.array((i - j) * (j - k) * (k - i) // 2, dtype=quat.dtype)
  eps = 1e-7

  # Permute the quaternion components (Tait-Bryan needs an extra rotation).
  a = jnp.where(symmetric, quat[3], quat[3] - quat[j])
  b = jnp.where(symmetric, quat[i], quat[i] + quat[k] * sign)
  c = jnp.where(symmetric, quat[j], quat[j] + quat[3])
  d = jnp.where(symmetric, quat[k] * sign, quat[k] * sign - quat[i])

  # Start from zeros: in the degenerate cases angles[2] must stay 0.
  angles = jnp.zeros(3, dtype=quat.dtype)
  angles = angles.at[1].set(2 * jnp.arctan2(jnp.hypot(c, d), jnp.hypot(a, b)))

  # case 1: middle angle ~0; case 2: middle angle ~pi; case 0: regular.
  case = jnp.where(jnp.abs(angles[1] - jnp.pi) <= eps, 2, 0)
  case = jnp.where(jnp.abs(angles[1]) <= eps, 1, case)

  half_sum = jnp.arctan2(b, a)
  half_diff = jnp.arctan2(d, c)

  # Degenerate cases: only angles[0] is determined.
  angles = angles.at[0].set(
      jnp.where(case == 1, 2 * half_sum,
                2 * half_diff * jnp.where(extrinsic, -1, 1)))
  # Regular case: overwrite first and third angles.
  angles = angles.at[angle_first].set(
      jnp.where(case == 0, half_sum - half_diff, angles[angle_first]))
  angles = angles.at[angle_third].set(
      jnp.where(case == 0, half_sum + half_diff, angles[angle_third]))

  # Tait-Bryan (asymmetric) sequences: undo the permutation adjustments.
  angles = angles.at[angle_third].set(
      jnp.where(symmetric, angles[angle_third], angles[angle_third] * sign))
  angles = angles.at[1].set(
      jnp.where(symmetric, angles[1], angles[1] - jnp.pi / 2))

  # Wrap into [-pi, pi] only when out of range. A modulo-based wrap would map
  # +pi to -pi and add rounding error to every in-range angle.
  angles = jnp.where(angles < -jnp.pi, angles + 2 * jnp.pi, angles)
  angles = jnp.where(angles > jnp.pi, angles - 2 * jnp.pi, angles)
  return jnp.where(degrees, jnp.rad2deg(angles), angles)


def _elementary_basis_index(axis: str) -> int:
  """Map ``'x'``/``'y'``/``'z'`` to 0/1/2.

  Raises:
    ValueError: for any other character.
  """
  if axis == 'x':
    return 0
  elif axis == 'y':
    return 1
  elif axis == 'z':
    return 2
  raise ValueError(f"Expected axis to be from ['x', 'y', 'z'], got {axis}")


@functools.partial(jnp.vectorize, signature='(m),(m),(),()->(n)')
def _elementary_quat_compose(angles: jax.Array, axes: jax.Array,
                             intrinsic: bool, degrees: bool) -> jax.Array:
  """Compose one elementary rotation per axis into a single quaternion.

  Intrinsic sequences right-multiply each new rotation; extrinsic sequences
  left-multiply it.
  """
  angles = jnp.where(degrees, jnp.deg2rad(angles), angles)
  result = _make_elementary_quat(axes[0], angles[0])
  for idx in range(1, len(axes)):
    quat = _make_elementary_quat(axes[idx], angles[idx])
    result = jnp.where(intrinsic, _compose_quat(result, quat),
                       _compose_quat(quat, result))
  return result


@functools.partial(jnp.vectorize, signature='(m),()->(n)')
def _from_rotvec(rotvec: jax.Array, degrees: bool) -> jax.Array:
  """Convert a rotation vector (axis times angle) to a unit quaternion."""
  rotvec = jnp.where(degrees, jnp.deg2rad(rotvec), rotvec)
  angle = _vector_norm(rotvec)
  angle2 = angle * angle
  # Taylor series of sin(angle / 2) / angle, used near 0 where the exact
  # expression evaluates to 0 / 0.
  small_scale = 0.5 - angle2 / 48 + angle2 * angle2 / 3840
  large_scale = jnp.sin(angle / 2) / angle
  scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
  return jnp.hstack([scale * rotvec, jnp.cos(angle / 2)])


@functools.partial(jnp.vectorize, signature='(m,m)->(n)')
def _from_matrix(matrix: jax.Array) -> jax.Array:
  """Convert a 3x3 rotation matrix to a unit scalar-last quaternion.

  Both candidate formulas are evaluated. The one anchored on the largest of
  (m00, m11, m22, trace) is kept, which avoids building the quaternion
  around a small component.
  """
  matrix_trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
  decision = jnp.array([matrix[0, 0], matrix[1, 1], matrix[2, 2], matrix_trace],
                       dtype=matrix.dtype)
  choice = jnp.argmax(decision)

  # Candidate when a diagonal entry m_ii dominates (choice in {0, 1, 2}).
  # When choice == 3 these values are computed but discarded below.
  i = choice
  j = (i + 1) % 3
  k = (j + 1) % 3
  quat_012 = jnp.zeros(4, dtype=matrix.dtype)
  quat_012 = quat_012.at[i].set(1 - decision[3] + 2 * matrix[i, i])
  quat_012 = quat_012.at[j].set(matrix[j, i] + matrix[i, j])
  quat_012 = quat_012.at[k].set(matrix[k, i] + matrix[i, k])
  quat_012 = quat_012.at[3].set(matrix[k, j] - matrix[j, k])

  # Candidate when the trace dominates (choice == 3): anchored on w.
  quat_3 = jnp.array([matrix[2, 1] - matrix[1, 2],
                      matrix[0, 2] - matrix[2, 0],
                      matrix[1, 0] - matrix[0, 1],
                      1 + decision[3]], dtype=matrix.dtype)

  quat = jnp.where(choice != 3, quat_012, quat_3)
  return _normalize_quaternion(quat)


@functools.partial(jnp.vectorize, signature='(m)->(n)')
def _from_mrp(mrp: jax.Array) -> jax.Array:
  """Convert Modified Rodrigues Parameters to a scalar-last quaternion."""
  mrp_squared_plus_1 = jnp.dot(mrp, mrp) + 1
  return jnp.hstack([2 * mrp[:3], (2 - mrp_squared_plus_1)]) / mrp_squared_plus_1


@functools.partial(jnp.vectorize, signature='(n)->(n)')
def _inv(quat: jax.Array) -> jax.Array:
  """Invert a rotation by negating the scalar part.

  This gives the negated conjugate, which represents the same rotation as the
  conjugate.
  """
  return quat.at[3].set(-quat[3])


@functools.partial(jnp.vectorize, signature='(n)->()')
def _magnitude(quat: jax.Array) -> jax.Array:
  """Rotation angle in radians; ``abs(w)`` keeps the result in [0, pi]."""
  return 2. * jnp.arctan2(_vector_norm(quat[:3]), jnp.abs(quat[3]))


@functools.partial(jnp.vectorize, signature='(),()->(n)')
def _make_elementary_quat(axis: int, angle: jax.Array) -> jax.Array:
  """Quaternion for a rotation by ``angle`` (radians) about basis ``axis``."""
  quat = jnp.zeros(4, dtype=angle.dtype)
  quat = quat.at[3].set(jnp.cos(angle / 2.))
  quat = quat.at[axis].set(jnp.sin(angle / 2.))
  return quat


@functools.partial(jnp.vectorize, signature='(n)->(n)')
def _normalize_quaternion(quat: jax.Array) -> jax.Array:
  """Scale ``quat`` to unit norm (a zero quaternion yields NaNs)."""
  return quat / _vector_norm(quat)


@functools.partial(jnp.vectorize, signature='(n)->()')
def _vector_norm(vector: jax.Array) -> jax.Array:
  """Euclidean norm of a 1-D vector."""
  return jnp.sqrt(jnp.dot(vector, vector))