Source code for tensorflow3d_transforms.rotation_conversions

"""This module contains functions to convert between rotation representations.

The transformation matrices returned from the functions in this file
assume that the points to which the transformation will be applied are
column vectors; that is, the R matrix is structured as

.. code-block:: python

    R = [
            [Rxx, Rxy, Rxz],
            [Ryx, Ryy, Ryz],
            [Rzx, Rzy, Rzz],
        ]  # (3, 3)

Furthermore, every function in this module that takes or returns quaternions
assumes the real part comes first, i.e. they are tensors of shape (..., 4).
"""

from typing import Optional

import tensorflow as tf


def quaternion_to_matrix(quaternions: tf.Tensor) -> tf.Tensor:
    """Turn quaternions (real part first) into 3x3 rotation matrices.

    The input does not need to be normalized: every off-diagonal term is
    scaled by ``2 / |q|^2``, so a non-unit quaternion yields the same matrix
    as its normalized counterpart.

    Example:

    .. code-block:: python

        quaternion = tf.constant([0.0, 0.0, 0.0, 4.0])
        output = tensorflow3d_transforms.quaternion_to_matrix(quaternions=quaternion)
        # <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
        # array([[-1.,  0.,  0.],
        #        [ 0., -1.,  0.],
        #        [ 0.,  0.,  1.]], dtype=float32)>

    Args:
        quaternions (tf.Tensor): A tensor of shape (..., 4) representing
            quaternions with real part first.

    Returns:
        tf.Tensor: A tensor of shape (..., 3, 3) representing rotation matrices.
    """
    w, x, y, z = tf.unstack(quaternions, axis=-1)
    # 2 / |q|^2 -- folds the normalization into the matrix formula.
    scale = 2.0 / tf.reduce_sum(quaternions * quaternions, axis=-1)

    # Row-major entries of the rotation matrix.
    row_major_entries = [
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    ]
    flat = tf.stack(row_major_entries, -1)

    # Unflatten the trailing 9 entries into a (3, 3) matrix per quaternion.
    target_shape = quaternions.shape[:-1] + (3, 3)
    return tf.reshape(flat, target_shape)
def matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
    """Convert rotations given as rotation matrices to quaternions.

    Four candidate quaternions are built, one per component (r, i, j, k)
    assumed to be the largest in magnitude, and the candidate whose pivot
    component is largest is returned for each matrix. This keeps the
    division well conditioned.

    Example:

    .. code-block:: python

        matrix = tf.constant(
            [
                [
                    [0.15885946, -0.56794965, -0.48926896],
                    [-1.0064808, -0.39120296, 1.6047943],
                    [0.05503756, 0.817741, 0.4543775],
                ]
            ]
        )
        matrix_to_quaternion(matrix)
        # <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
        # array([[-0.1688297 , -0.16717434,  0.9326495 ,  0.6493691 ]],
        # dtype=float32)>

    Args:
        matrix (tf.Tensor): A tensor of shape (..., 3, 3) representing
            rotation matrices.

    Returns:
        tf.Tensor: A tensor of shape (..., 4) representing quaternions with
        real part first.

    Raises:
        ValueError: If the shape of the input matrix is invalid, that is, it
            does not have the shape (..., 3, 3).
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = tf.unstack(
        tf.reshape(matrix, batch_dim + (9,)), axis=-1
    )

    # Twice the absolute value of each quaternion component (r, i, j, k).
    q_abs = _sqrt_positive_part(
        tf.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            axis=-1,
        )
    )

    # Row n holds the quaternion scaled by 2 * q_abs[n], i.e. the candidate
    # obtained when component n is used as the pivot.
    quat_by_rijk = tf.stack(
        [
            tf.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], axis=-1),
            tf.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], axis=-1),
            tf.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], axis=-1),
            tf.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], axis=-1),
        ],
        axis=-2,
    )

    # Floor each pivot element-wise so poorly conditioned candidates do not
    # divide by ~0. The previous code passed this floor as the *axis* of
    # tf.reduce_max, which took a maximum across the leading (batch) axis and
    # mixed unrelated batch elements. The candidate finally selected has the
    # largest pivot, so the floor only affects candidates that are discarded
    # (for a proper rotation the largest pivot is at least 1).
    floor = tf.constant(0.1, dtype=q_abs.dtype)
    quat_candidates = quat_by_rijk / (2.0 * tf.maximum(q_abs[..., None], floor))

    # Pick, for every matrix, the candidate with the largest pivot.
    max_indices = tf.argmax(q_abs, axis=-1)
    one_hot = tf.one_hot(max_indices, depth=4)
    selected = tf.boolean_mask(quat_candidates, one_hot > 0.5)
    return tf.reshape(selected, batch_dim + [4])
def _sqrt_positive_part(x: tf.Tensor) -> tf.Tensor:
    """Returns the square root of all positive elements of x and 0 for others.

    Args:
        x (tf.Tensor): A tensor

    Returns:
        tf.Tensor: A tensor with the same shape as x
    """
    positive_mask = x > 0
    # Feed sqrt a harmless value (1) wherever x <= 0. tf.where still
    # back-propagates through the branch it does not select, so evaluating
    # sqrt on zero or negative entries would turn the gradient into NaN even
    # though the forward value is masked out.
    safe_x = tf.where(positive_mask, x, tf.ones_like(x))
    return tf.where(positive_mask, tf.math.sqrt(safe_x), tf.zeros_like(x))


def _axis_angle_rotation(axis: str, angle: tf.Tensor) -> tf.Tensor:
    """Return the rotation matrices for one of the rotations about an axis of
    which Euler angles describe, for each value of the angle given.

    Args:
        axis (str): The axis about which the rotation is performed.
            Must be one of 'X', 'Y', 'Z'.
        angle (tf.Tensor): Any shape tensor of Euler angles in radians

    Returns:
        tf.Tensor: A tensor of shape (..., 3, 3) representing rotation matrices.

    Raises:
        ValueError: If the axis is not one of 'X', 'Y', 'Z'.
    """
    if axis not in ("X", "Y", "Z"):
        raise ValueError("letter must be either X, Y or Z.")

    cos = tf.math.cos(angle)
    sin = tf.math.sin(angle)
    one = tf.ones_like(angle)
    zero = tf.zeros_like(angle)

    # Row-major entries of the elementary rotation about the requested axis.
    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    else:  # axis == "Z", guaranteed by the validation above
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)

    # Stack the 9 entries last, then unflatten them into (3, 3).
    return tf.reshape(tf.stack(R_flat, axis=-1), angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles: tf.Tensor, convention: str) -> tf.Tensor:
    """Build rotation matrices from Euler angles.

    One elementary rotation is created per angle (about the axis named by the
    matching letter of ``convention``) and the three are multiplied in order.

    Example:

    .. code-block:: python

        euler_angles = tf.constant(
            [
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                ]
            ]
        )
        euler_angles_to_matrix(euler_angles=euler_angles, convention="XYZ")
        # <tf.Tensor: shape=(1, 3, 3, 3), dtype=float32, numpy=
        # array([[[[1., 0., 0.],
        #          [0., 1., 0.],
        #          [0., 0., 1.]],
        #
        #         [[1., 0., 0.],
        #          [0., 1., 0.],
        #          [0., 0., 1.]],
        #
        #         [[1., 0., 0.],
        #          [0., 1., 0.],
        #          [0., 0., 1.]]]], dtype=float32)>

    Args:
        euler_angles (tf.Tensor): A tensor of shape (..., 3) representing
            euler angles.
        convention (str): The euler angle convention. A string containing a
            combination of three uppercase letters from {"X", "Y", and "Z"}.

    Returns:
        tf.Tensor: A tensor of shape (..., 3, 3) representing rotation matrices.

    Raises:
        ValueError: If the last dimension of ``euler_angles`` is not 3.
        ValueError: If the convention string does not have length 3.
        ValueError: If the second character of the convention string is the
            same as the first or third.
        ValueError: If the convention string contains characters other than
            {"X", "Y", and "Z"}.
    """
    last_dim = euler_angles.shape[-1]
    if last_dim != 3:
        raise ValueError(
            f"Invalid euler angle shape {euler_angles.shape}, last dimension should"
            " be 3."
        )

    if len(convention) != 3:
        raise ValueError(
            f"Invalid euler angle convention {convention}, should be a string of"
            " length 3."
        )

    first_axis, middle_axis, last_axis = convention[0], convention[1], convention[2]
    if middle_axis in (first_axis, last_axis):
        raise ValueError(
            f"Invalid euler angle convention {convention}, second character should be"
            " different from first and third."
        )

    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    # One elementary rotation per (axis letter, angle) pair.
    per_axis_angles = tf.unstack(euler_angles, axis=-1)
    rot_first, rot_middle, rot_last = (
        _axis_angle_rotation(axis_letter, axis_angle)
        for axis_letter, axis_angle in zip(convention, per_axis_angles)
    )

    first_two = tf.linalg.matmul(rot_first, rot_middle)
    return tf.linalg.matmul(first_two, rot_last)
def _index_from_letter(letter: str) -> int:
    """Map an axis letter to its index (X -> 0, Y -> 1, Z -> 2).

    Args:
        letter (str): The letter corresponding to the axis.
            Must be one of 'X', 'Y', 'Z'.

    Returns:
        int: The index of the axis.

    Raises:
        ValueError: If the letter is not one of 'X', 'Y', 'Z'.
    """
    for index, axis_name in enumerate(("X", "Y", "Z")):
        if letter == axis_name:
            return index
    raise ValueError("letter must be either X, Y or Z.")


def _angle_from_tan(
    axis: str, other_axis: str, data: tf.Tensor, horizontal: bool, tait_bryan: bool
) -> tf.Tensor:
    """Extract the first or third Euler angle from the two members of the
    matrix which are positive constant times its sine and cosine.

    Args:
        axis (str): Axis label "X", "Y" or "Z" for the angle we are finding.
        other_axis (str): Axis label "X", "Y" or "Z" for the middle axis in
            the convention.
        data (tf.Tensor): The row or column of the rotation matrices that
            holds the relevant entries; indexed along its last axis.
        horizontal (bool): Whether we are looking for the angle for the third
            axis, which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan (bool): Whether the first and third axes in the convention
            differ.

    Returns:
        tf.Tensor: Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """
    index_pairs = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}
    first, second = index_pairs[axis]
    if horizontal:
        # Entries sit in a row rather than a column: swap their roles.
        first, second = second, first

    # True when (axis, other_axis) is a cyclic pair of X -> Y -> Z -> X.
    is_cyclic = (axis + other_axis) in ["XY", "YZ", "ZX"]

    entry_first = data[..., first]
    entry_second = data[..., second]
    if horizontal == is_cyclic:
        return tf.math.atan2(entry_first, entry_second)
    if tait_bryan:
        return tf.math.atan2(-entry_second, entry_first)
    return tf.math.atan2(entry_second, -entry_first)
def matrix_to_euler_angles(matrix: tf.Tensor, convention: str) -> tf.Tensor:
    """Recover Euler angles (in radians) from rotation matrices.

    Example:

    .. code-block:: python

        matrix = tf.constant(
            [
                [
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
                ]
            ]
        )
        matrix_to_euler_angles(matrix=matrix, convention="XYZ")
        # <tf.Tensor: shape=(1, 3, 3), dtype=float32, numpy=
        # array([[[-0.,  0., -0.],
        #         [-0.,  0., -0.],
        #         [-0.,  0., -0.]]], dtype=float32)>

    Args:
        matrix (tf.Tensor): A tensor of shape (..., 3, 3) representing
            rotation matrices.
        convention (str): The euler angle convention. A string containing a
            combination of three uppercase letters from {"X", "Y", and "Z"}.

    Returns:
        tf.Tensor: A tensor of shape (..., 3) representing euler angles.

    Raises:
        ValueError: If the convention string does not have length 3.
        ValueError: If the second character of the convention string is the
            same as the first or third.
        ValueError: If the last two dimensions of ``matrix`` are not (3, 3).
        ValueError: If the convention string contains characters other than
            {"X", "Y", and "Z"}.
    """
    if len(convention) != 3:
        raise ValueError(
            f"Invalid euler angle convention {convention}, should be a string of"
            " length 3."
        )

    first_axis, middle_axis, last_axis = convention[0], convention[1], convention[2]
    if middle_axis in (first_axis, last_axis):
        raise ValueError(
            f"Invalid euler angle convention {convention}, second character should be"
            " different from first and third."
        )

    if matrix.shape[-2:] != (3, 3):
        raise ValueError(
            f"Invalid matrix shape {matrix.shape}, last two dimensions should be 3, 3."
        )

    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    i0 = _index_from_letter(first_axis)
    i2 = _index_from_letter(last_axis)
    # Tait-Bryan conventions use three distinct axes; proper Euler conventions
    # repeat the first axis as the third.
    tait_bryan = i0 != i2

    if not tait_bryan:
        central_angle = tf.math.acos(matrix[..., i0, i0])
    elif i0 - i2 in [-1, 2]:
        central_angle = tf.math.asin(-1 * matrix[..., i0, i2])
    else:
        central_angle = tf.math.asin(matrix[..., i0, i2])

    # The first angle is read from column i2, the third from row i0.
    first_angle = _angle_from_tan(
        first_axis, middle_axis, matrix[..., i2], False, tait_bryan
    )
    third_angle = _angle_from_tan(
        last_axis, middle_axis, matrix[..., i0, :], True, tait_bryan
    )

    return tf.stack((first_angle, central_angle, third_angle), axis=-1)
def _copysign(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
    """Give every element of ``a`` the sign of the matching element of ``b``.

    This mirrors the standard floating-point ``copysign`` operation, but it
    is not careful about negative 0 and NaN.

    Args:
        a (tf.Tensor): Source tensor.
        b (tf.Tensor): Tensor whose signs will be used, of the same shape as a.

    Returns:
        tf.Tensor: Tensor of the same shape as a with the signs of b.

    Raises:
        ValueError: If the shapes of a and b do not match.
    """
    if a.shape != b.shape:
        raise ValueError(f"Shapes of a and b do not match: {a.shape} and {b.shape}.")

    a_is_negative = a < 0
    b_is_negative = b < 0
    # Flip only the entries whose sign disagrees with b.
    needs_flip = a_is_negative != b_is_negative
    return tf.where(needs_flip, -a, a)
def random_quaternions(
    n: int,
    dtype: Optional[tf.dtypes.DType] = tf.float32,
) -> tf.Tensor:
    """Sample random rotation quaternions, i.e. versors with nonnegative
    real part.

    Example:

    .. code-block:: python

        random_quaternions(2)
        # <tf.Tensor: shape=(2, 4), dtype=float32, numpy=...>

    Args:
        n (int): Number of quaternions to generate.
        dtype (Optional[tf.dtype], optional): Data type of the returned tensor,
            defaults to tf.float32.

    Returns:
        tf.Tensor: Tensor of shape (n, 4) representing quaternions.
    """
    samples = tf.random.normal((n, 4), dtype=dtype)
    squared_norms = tf.reduce_sum(samples * samples, axis=1)
    # Dividing by the norm carrying the sign of the real part normalizes the
    # sample and makes its real part nonnegative in one step.
    signed_norms = _copysign(tf.math.sqrt(squared_norms), samples[:, 0])
    return samples / signed_norms[:, None]
def random_rotations(
    n: int,
    dtype: Optional[tf.dtypes.DType] = tf.float32,
) -> tf.Tensor:
    """Sample random rotations and return them as 3x3 rotation matrices.

    The rotations are drawn as random quaternions and then converted.

    Example:

    .. code-block:: python

        random_rotations(2)
        # <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=...>

    Args:
        n (int): Number of rotation matrices to generate.
        dtype (Optional[tf.dtype], optional): Data type of the returned tensor,
            defaults to tf.float32.

    Returns:
        tf.Tensor: Tensor of shape (n, 3, 3) representing rotation matrices.
    """
    return quaternion_to_matrix(random_quaternions(n, dtype=dtype))
def standardize_quaternion(quaternions: tf.Tensor) -> tf.Tensor:
    """Bring quaternions to the standard form whose real part is non-negative.

    A quaternion with a negative real part is negated as a whole; all others
    are returned unchanged.

    Example:

    .. code-block:: python

        quaternions = tf.constant((-1.,-2.,-1.,-1.))
        standardize_quaternion(quaternions=quaternions)
        # <tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 1., 2., 1., 1.], dtype=float32)>

    Args:
        quaternions (tf.Tensor): Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        tf.Tensor: Standardized quaternions as tensor of shape (..., 4).
    """
    # Keep the last axis (size 1) so the mask broadcasts over all 4 components.
    real_is_negative = quaternions[..., 0:1] < 0
    return tf.where(real_is_negative, -quaternions, quaternions)
def quaternion_raw_multiply(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
    """Compute the Hamilton product ``a * b`` of two quaternions.

    No normalization or sign standardization is applied to the result.

    Example:

    .. code-block:: python

        a = tf.constant((1.,2.,3.,4.))
        b = tf.constant((5.,6.,7.,8.))
        quaternion_raw_multiply(a=a, b=b)
        # <tf.Tensor: shape=(4,), dtype=float32, numpy=array([-60., 12., 30., 24.], dtype=float32)>

    Args:
        a (tf.Tensor): First quaternion with real part first,
            as tensor of shape (..., 4).
        b (tf.Tensor): Second quaternion with real part first,
            as tensor of shape (..., 4).

    Returns:
        tf.Tensor: Product of a and b as tensor of shape (..., 4).
    """
    a_re, a_i, a_j, a_k = tf.unstack(a, axis=-1)
    b_re, b_i, b_j, b_k = tf.unstack(b, axis=-1)

    product = [
        a_re * b_re - a_i * b_i - a_j * b_j - a_k * b_k,  # real part
        a_re * b_i + a_i * b_re + a_j * b_k - a_k * b_j,  # i
        a_re * b_j - a_i * b_k + a_j * b_re + a_k * b_i,  # j
        a_re * b_k + a_i * b_j - a_j * b_i + a_k * b_re,  # k
    ]
    return tf.stack(product, axis=-1)
def quaternion_multiply(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
    """Compose two rotation quaternions.

    The raw quaternion product is computed and then standardized so that the
    real part of the result is nonnegative.

    Example:

    .. code-block:: python

        a = tf.constant((1.,2.,3.,4.))
        b = tf.constant((5.,6.,7.,8.))
        quaternion_multiply(a=a, b=b)
        # <tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 60., -12., -30., -24.], dtype=float32)>

    Args:
        a (tf.Tensor): First quaternion with real part first,
            as tensor of shape (..., 4).
        b (tf.Tensor): Second quaternion with real part first,
            as tensor of shape (..., 4).

    Returns:
        tf.Tensor: Product of a and b as tensor of shape (..., 4).
    """
    raw_product = quaternion_raw_multiply(a, b)
    return standardize_quaternion(raw_product)
def quaternion_invert(quaternion: tf.Tensor) -> tf.Tensor:
    """Return the inverse rotation of a unit quaternion.

    The result is the conjugate (imaginary parts negated), which equals the
    inverse only for versors (unit quaternions).

    Example:

    .. code-block:: python

        quaternion = tf.constant((1.,2.,3.,4.))
        quaternion_invert(quaternion=quaternion)
        # <tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 1., -2., -3., -4.], dtype=float32)>

    Args:
        quaternion (tf.Tensor): Quaternions as tensor of shape (..., 4),
            with real part first, which must be versors (unit quaternions).

    Returns:
        tf.Tensor: The inverse, a tensor of quaternions of shape (..., 4).
    """
    # Keep the real part, flip the sign of the three imaginary parts.
    conjugate_signs = tf.constant([1, -1, -1, -1])
    conjugate_signs = tf.cast(conjugate_signs, dtype=quaternion.dtype)
    return quaternion * conjugate_signs
def quaternion_apply(quaternion: tf.Tensor, point: tf.Tensor) -> tf.Tensor:
    """Apply the rotation given by a quaternion to a 3D point.

    The point is embedded as a pure quaternion ``(0, x, y, z)`` and
    conjugated: ``q * p * conj(q)``. The quaternion is not normalized here, so
    a non-unit quaternion also scales the point (see the example).

    Example:

    .. code-block:: python

        quaternion = tf.constant((1.,1.,1.,4.))
        point = tf.constant((1.,1.,1.))
        quaternion_apply(quaternion=quaternion, point=point)
        # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-11., 1., 31.], dtype=float32)>

    Args:
        quaternion (tf.Tensor): Quaternions as tensor of shape (..., 4),
            with real part first
        point (tf.Tensor): Points as tensor of shape (..., 3)

    Returns:
        tf.Tensor: Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: If the last dimension of point is not 3.
    """
    if point.shape[-1] != 3:
        raise ValueError("Points must be 3D")

    # Zero real part with the same leading shape and dtype as the point.
    # (Adding a tuple to tf.shape(...) is element-wise addition, not shape
    # concatenation, so the real part must not be built that way.)
    real_parts = tf.zeros_like(point[..., :1])
    point_as_quaternion = tf.concat([real_parts, point], axis=-1)

    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    # Drop the real part; the imaginary parts are the rotated point.
    return out[..., 1:]
def axis_angle_to_quaternion(axis_angle: tf.Tensor) -> tf.Tensor:
    """Convert rotations given as axis/angle to quaternions.

    Example:

    .. code-block:: python

        axis_angle = tf.constant((1.,1.,1.))
        axis_angle_to_quaternion(axis_angle=axis_angle)
        # <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.64785933, 0.43980235, 0.43980235, 0.43980235], dtype=float32)>

    Args:
        axis_angle (tf.Tensor): Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        tf.Tensor: Quaternions as tensor of shape (..., 4), with real part first.
    """
    angles = tf.norm(axis_angle, ord=2, axis=-1, keepdims=True)
    half_angles = angles * 0.5

    # sin(x / 2) / x is 0 / 0 at x == 0, which used to turn the identity
    # rotation into NaNs. Near zero, use the Taylor expansion
    #     sin(x / 2) / x ~= 1 / 2 - x^2 / 48
    # and keep the exact quotient everywhere else.
    eps = 1e-6
    small_angles = tf.abs(angles) < eps
    # Replace tiny denominators by 1 so the unselected branch stays finite.
    safe_angles = tf.where(small_angles, tf.ones_like(angles), angles)
    sin_half_angles_over_angles = tf.where(
        small_angles,
        0.5 - (angles * angles) / 48,
        tf.math.sin(half_angles) / safe_angles,
    )

    quaternions = tf.concat(
        [tf.math.cos(half_angles), axis_angle * sin_half_angles_over_angles], axis=-1
    )
    return quaternions
def axis_angle_to_matrix(axis_angle: tf.Tensor) -> tf.Tensor:
    """Turn axis/angle rotations into rotation matrices.

    The conversion goes through the quaternion representation.

    Example:

    .. code-block:: python

        axis_angle = tf.constant((1.,1.,1.))
        axis_angle_to_matrix(axis_angle=axis_angle)
        # <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
        # array([[ 0.22629571, -0.18300788,  0.9567122 ],
        #        [ 0.9567122 ,  0.22629571, -0.18300788],
        #        [-0.18300788,  0.9567122 ,  0.22629571]], dtype=float32)>

    Args:
        axis_angle (tf.Tensor): Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        tf.Tensor: Rotation matrices as tensor of shape (..., 3, 3).
    """
    as_quaternion = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(as_quaternion)
def quaternion_to_axis_angle(quaternions: tf.Tensor) -> tf.Tensor:
    """Convert rotations given as quaternions to axis/angle.

    Example:

    .. code-block:: python

        quaternions = tf.constant((1.,1.,1.,4.))
        quaternion_to_axis_angle(quaternions=quaternions)
        # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 2.752039, 2.752039, 11.008156], dtype=float32)>

    Args:
        quaternions (tf.Tensor): Quaternions as tensor of shape (..., 4),
            with real part first.

    Returns:
        tf.Tensor: Rotations given as a vector in axis angle form,
        as a tensor of shape (..., 3), where the magnitude is
        the angle turned anticlockwise in radians around the
        vector's direction.
    """
    norms = tf.norm(quaternions[..., 1:], ord=2, axis=-1, keepdims=True)
    half_angles = tf.math.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles

    # sin(x / 2) / x is 0 / 0 at x == 0, which used to turn the identity
    # quaternion into NaNs. Near zero, use the Taylor expansion
    #     sin(x / 2) / x ~= 1 / 2 - x^2 / 48
    # and keep the exact quotient everywhere else.
    eps = 1e-6
    small_angles = tf.abs(angles) < eps
    # Replace tiny denominators by 1 so the unselected branch stays finite.
    safe_angles = tf.where(small_angles, tf.ones_like(angles), angles)
    sin_half_angles_over_angles = tf.where(
        small_angles,
        0.5 - (angles * angles) / 48,
        tf.math.sin(half_angles) / safe_angles,
    )

    # imaginary part = axis * sin(angle / 2)  =>  axis * angle = imag / (sin(angle/2) / angle)
    return quaternions[..., 1:] / sin_half_angles_over_angles
def matrix_to_axis_angle(matrix: tf.Tensor) -> tf.Tensor:
    """Turn rotation matrices into axis/angle vectors.

    The conversion goes through the quaternion representation.

    Example:

    .. code-block:: python

        matrix = tf.constant(
            [
                [
                    [0.15885946, -0.56794965, -0.48926896],
                    [-1.0064808, -0.39120296, 1.6047943],
                    [0.05503756, 0.817741, 0.4543775],
                ]
            ]
        )
        matrix_to_axis_angle(matrix)
        # <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-0.5801526,  3.2366152,  2.2535346]], dtype=float32)>

    Args:
        matrix (tf.Tensor): Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        tf.Tensor: Rotations given as a vector in axis angle form,
        as a tensor of shape (..., 3), where the magnitude is
        the angle turned anticlockwise in radians around the
        vector's direction.
    """
    as_quaternion = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(as_quaternion)
def rotation_6d_to_matrix(d6: tf.Tensor) -> tf.Tensor:
    """Converts 6D rotation representation by Zhou et al. [1] to rotation
    matrix using Gram--Schmidt orthogonalization per Section B of [1].

    The first three numbers are normalized to give the first row, the last
    three are made orthogonal to it and normalized to give the second row, and
    the third row is their cross product. All reductions act on the last axis
    only, so every 6D vector in a batch is handled independently.

    Example:

    .. code-block:: python

        d6 = tf.constant((1.,1.,1.,1.,1.,1.))
        rotation_6d_to_matrix(d6)
        # <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
        # array([[0.57735026, 0.57735026, 0.57735026],
        #        [0.57735026, 0.57735026, 0.57735026],
        #        [0.        , 0.        , 0.        ]], dtype=float32)>

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035

    Args:
        d6 (tf.Tensor): 6D rotation representation as tensor of shape (..., 6).

    Returns:
        tf.Tensor: Rotation matrices as tensor of shape (..., 3, 3).
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = tf.nn.l2_normalize(a1, axis=-1)
    # Remove from a2 its component along b1. The dot product must be taken
    # per vector (axis=-1); without an axis it would sum over the whole
    # tensor, batch dimensions included.
    b2 = a2 - tf.reduce_sum(b1 * a2, axis=-1, keepdims=True) * b1
    # Likewise normalize each vector on its own rather than by the norm of
    # the entire tensor.
    b2 = tf.linalg.normalize(b2, axis=-1)[0]
    b3 = tf.linalg.cross(b1, b2)
    return tf.stack((b1, b2, b3), axis=-2)
def matrix_to_rotation_6d(matrix: tf.Tensor) -> tf.Tensor:
    """Encode rotation matrices in the 6D representation of Zhou et al. [1].

    The last row of each matrix is dropped and the remaining two rows are
    flattened into six numbers.

    Example:

    .. code-block:: python

        matrix = tf.constant([[2.0, 1.0, 1.0], [1.0, 2.0, 1.0], [1.0, 1.0, 2.0]])
        matrix_to_rotation_6d(matrix)
        # <tf.Tensor: shape=(6,), dtype=float32, numpy=array([2., 1., 1., 1., 2., 1.], dtype=float32)>

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035

    Args:
        matrix (tf.Tensor): Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        tf.Tensor: 6D rotation representation as tensor of shape (..., 6).
    """
    leading_shape = matrix.shape[:-2]
    first_two_rows = tf.identity(matrix[..., :2, :])
    return tf.reshape(first_two_rows, leading_shape + (6,))