Skip to content

Commit 4a079ce

Browse files
nathan-skydioaaron-skydio
authored andcommitted
[Symforce] Added Group and Lie Group Ops for Dataclasses
Implements and tests group + lie group ops for dataclasses. Topic: dataclass_lie_group_ops GitOrigin-RevId: 4e875af9caec7292fed2572aa219ba6bec957632
1 parent c5ee7aa commit 4a079ce

9 files changed

+203
-18
lines changed

symforce/ops/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .impl.scalar_lie_group_ops import ScalarLieGroupOps
2222
from .impl.sequence_lie_group_ops import SequenceLieGroupOps
2323
from .impl.array_lie_group_ops import ArrayLieGroupOps
24+
from .impl.dataclass_lie_group_ops import DataclassLieGroupOps
2425

2526
LieGroupOps.register(float, ScalarLieGroupOps)
2627
LieGroupOps.register(np.float32, ScalarLieGroupOps)
@@ -38,9 +39,8 @@
3839
LieGroupOps.register(np.ndarray, ArrayLieGroupOps)
3940

4041
from symforce import typing as T
41-
from .impl.dataclass_storage_ops import DataclassStorageOps
4242

43-
StorageOps.register(T.Dataclass, DataclassStorageOps)
43+
LieGroupOps.register(T.Dataclass, DataclassLieGroupOps)
4444

4545
# TODO(hayk): Are these okay here or where can we put them? In theory we could just have this
4646
# be automatic that if the given type has the methods that it gets registered automatically.

symforce/ops/impl/array_lie_group_ops.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ def tangent_dim(a: T.ArrayElement) -> int:
2727
return sum([LieGroupOps.tangent_dim(v) for v in a])
2828

2929
@staticmethod
30-
def from_tangent(a: T.ArrayElement, vec: T.List[T.Scalar], epsilon: T.Scalar) -> T.ArrayElement:
30+
def from_tangent(
31+
a: T.ArrayElement, vec: T.Sequence[T.Scalar], epsilon: T.Scalar
32+
) -> T.ArrayElement:
3133
assert len(vec) == ArrayLieGroupOps.tangent_dim(a)
3234
new_a = []
3335
inx = 0
@@ -54,7 +56,7 @@ def tangent_D_storage(a: T.ArrayElement) -> geo.Matrix:
5456
return geo.M.eye(StorageOps.storage_dim(a), LieGroupOps.tangent_dim(a))
5557

5658
@staticmethod
57-
def retract(a: T.ArrayElement, vec: T.List[T.Scalar], epsilon: T.Scalar) -> T.ArrayElement:
59+
def retract(a: T.ArrayElement, vec: T.Sequence[T.Scalar], epsilon: T.Scalar) -> T.ArrayElement:
5860
assert len(vec) == ArrayLieGroupOps.tangent_dim(a)
5961
new_a = []
6062
inx = 0
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
6+
import dataclasses
7+
8+
from symforce.ops import GroupOps
9+
from symforce.python_util import get_type
10+
from symforce import typing as T
11+
12+
from .dataclass_storage_ops import DataclassStorageOps
13+
14+
15+
class DataclassGroupOps(DataclassStorageOps):
16+
@staticmethod
17+
def identity(a: T.DataclassOrType) -> T.Dataclass:
18+
constructed_fields = {}
19+
if isinstance(a, type):
20+
type_hints_map = T.get_type_hints(a)
21+
for field in dataclasses.fields(a):
22+
constructed_fields[field.name] = GroupOps.identity(type_hints_map[field.name])
23+
return a(**constructed_fields)
24+
else:
25+
for field in dataclasses.fields(a):
26+
constructed_fields[field.name] = GroupOps.identity(getattr(a, field.name))
27+
return get_type(a)(**constructed_fields)
28+
29+
@staticmethod
30+
def compose(a: T.Dataclass, b: T.Dataclass) -> T.Dataclass:
31+
assert get_type(a) == get_type(b)
32+
constructed_fields = {}
33+
for field in dataclasses.fields(a):
34+
constructed_fields[field.name] = GroupOps.compose(
35+
getattr(a, field.name), getattr(b, field.name)
36+
)
37+
return get_type(a)(**constructed_fields)
38+
39+
@staticmethod
40+
def inverse(a: T.Dataclass) -> T.Dataclass:
41+
constructed_fields = {}
42+
for field in dataclasses.fields(a):
43+
constructed_fields[field.name] = GroupOps.inverse(getattr(a, field.name))
44+
return get_type(a)(**constructed_fields)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
from __future__ import annotations
6+
7+
import dataclasses
8+
9+
from symforce.ops import StorageOps
10+
from symforce.ops import LieGroupOps
11+
from symforce.python_util import get_type
12+
from symforce import typing as T
13+
14+
from .dataclass_group_ops import DataclassGroupOps
15+
16+
if T.TYPE_CHECKING:
17+
from symforce import geo
18+
19+
20+
class DataclassLieGroupOps(DataclassGroupOps):
21+
@staticmethod
22+
def tangent_dim(a: T.DataclassOrType) -> int:
23+
if isinstance(a, type):
24+
count = 0
25+
type_hints_map = T.get_type_hints(a)
26+
for field in dataclasses.fields(a):
27+
count += LieGroupOps.tangent_dim(type_hints_map[field.name])
28+
return count
29+
else:
30+
count = 0
31+
for field in dataclasses.fields(a):
32+
count += LieGroupOps.tangent_dim(getattr(a, field.name))
33+
return count
34+
35+
@staticmethod
36+
def from_tangent(
37+
a: T.DataclassOrType, vec: T.Sequence[T.Scalar], epsilon: T.Scalar
38+
) -> T.Dataclass:
39+
if isinstance(a, type):
40+
constructed_fields = {}
41+
offset = 0
42+
type_hints_map = T.get_type_hints(a)
43+
for field in dataclasses.fields(a):
44+
field_type = type_hints_map[field.name]
45+
tangent_dim = LieGroupOps.tangent_dim(field_type)
46+
constructed_fields[field.name] = LieGroupOps.from_tangent(
47+
field_type, vec[offset : offset + tangent_dim], epsilon
48+
)
49+
offset += tangent_dim
50+
return a(**constructed_fields)
51+
else:
52+
constructed_fields = {}
53+
offset = 0
54+
for field in dataclasses.fields(a):
55+
field_instance = getattr(a, field.name)
56+
tangent_dim = LieGroupOps.tangent_dim(field_instance)
57+
constructed_fields[field.name] = LieGroupOps.from_tangent(
58+
field_instance, vec[offset : offset + tangent_dim], epsilon
59+
)
60+
offset += tangent_dim
61+
return get_type(a)(**constructed_fields)
62+
63+
@staticmethod
64+
def to_tangent(a: T.Dataclass, epsilon: T.Scalar) -> T.List[T.Scalar]:
65+
tangent = []
66+
for field in dataclasses.fields(a):
67+
tangent.extend(LieGroupOps.to_tangent(getattr(a, field.name), epsilon))
68+
return tangent
69+
70+
@staticmethod
71+
def storage_D_tangent(a: T.Dataclass) -> geo.Matrix:
72+
from symforce import geo
73+
74+
mat = geo.Matrix(StorageOps.storage_dim(a), LieGroupOps.tangent_dim(a))
75+
s_inx = 0
76+
t_inx = 0
77+
for field in dataclasses.fields(a):
78+
field_instance = getattr(a, field.name)
79+
s_dim = StorageOps.storage_dim(field_instance)
80+
t_dim = LieGroupOps.tangent_dim(field_instance)
81+
mat[s_inx : s_inx + s_dim, t_inx : t_inx + t_dim] = LieGroupOps.storage_D_tangent(
82+
field_instance
83+
)
84+
s_inx += s_dim
85+
t_inx += t_dim
86+
return mat
87+
88+
@staticmethod
89+
def tangent_D_storage(a: T.Dataclass) -> geo.Matrix:
90+
from symforce import geo
91+
92+
mat = geo.Matrix(LieGroupOps.tangent_dim(a), StorageOps.storage_dim(a))
93+
s_inx = 0
94+
t_inx = 0
95+
for field in dataclasses.fields(a):
96+
field_instance = getattr(a, field.name)
97+
s_dim = StorageOps.storage_dim(field_instance)
98+
t_dim = LieGroupOps.tangent_dim(field_instance)
99+
mat[t_inx : t_inx + t_dim, s_inx : s_inx + s_dim] = LieGroupOps.tangent_D_storage(
100+
field_instance
101+
)
102+
s_inx += s_dim
103+
t_inx += t_dim
104+
return mat
105+
106+
@staticmethod
107+
def retract(a: T.Dataclass, vec: T.Sequence[T.Scalar], epsilon: T.Scalar) -> T.Dataclass:
108+
constructed_fields = {}
109+
offset = 0
110+
for field in dataclasses.fields(a):
111+
field_instance = getattr(a, field.name)
112+
tangent_dim = LieGroupOps.tangent_dim(field_instance)
113+
constructed_fields[field.name] = LieGroupOps.retract(
114+
field_instance, vec[offset : offset + tangent_dim], epsilon
115+
)
116+
offset += tangent_dim
117+
return get_type(a)(**constructed_fields)
118+
119+
@staticmethod
120+
def local_coordinates(a: T.Dataclass, b: T.Dataclass, epsilon: T.Scalar) -> T.List[T.Scalar]:
121+
assert get_type(a) == get_type(b)
122+
return [
123+
x
124+
for field in dataclasses.fields(a)
125+
for x in LieGroupOps.local_coordinates(
126+
getattr(a, field.name), getattr(b, field.name), epsilon
127+
)
128+
]

symforce/ops/impl/dataclass_storage_ops.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class DataclassStorageOps:
2323
"""
2424

2525
@staticmethod
26-
def storage_dim(a: T.ElementOrType) -> int:
26+
def storage_dim(a: T.DataclassOrType) -> int:
2727
if isinstance(a, type):
2828
count = 0
2929
type_hints_map = T.get_type_hints(a)
@@ -37,14 +37,14 @@ def storage_dim(a: T.ElementOrType) -> int:
3737
return count
3838

3939
@staticmethod
40-
def to_storage(a: T.Element) -> T.List[T.Scalar]:
40+
def to_storage(a: T.Dataclass) -> T.List[T.Scalar]:
4141
storage = []
4242
for field in dataclasses.fields(a):
4343
storage.extend(StorageOps.to_storage(getattr(a, field.name)))
4444
return storage
4545

4646
@staticmethod
47-
def from_storage(a: T.ElementOrType, elements: T.List[T.Scalar]) -> T.Element:
47+
def from_storage(a: T.DataclassOrType, elements: T.Sequence[T.Scalar]) -> T.Dataclass:
4848
if isinstance(a, type):
4949
constructed_fields = {}
5050
offset = 0
@@ -70,7 +70,7 @@ def from_storage(a: T.ElementOrType, elements: T.List[T.Scalar]) -> T.Element:
7070
return get_type(a)(**constructed_fields)
7171

7272
@staticmethod
73-
def symbolic(a: T.ElementOrType, name: T.Optional[str], **kwargs: T.Dict) -> T.Element:
73+
def symbolic(a: T.DataclassOrType, name: T.Optional[str], **kwargs: T.Dict) -> T.Dataclass:
7474
"""
7575
Return a symbolic instance of a Dataclass
7676

symforce/ops/impl/sequence_lie_group_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def tangent_dim(a: T.SequenceElement) -> int:
2323

2424
@staticmethod
2525
def from_tangent(
26-
a: T.SequenceElement, vec: T.List[T.Scalar], epsilon: T.Scalar
26+
a: T.SequenceElement, vec: T.Sequence[T.Scalar], epsilon: T.Scalar
2727
) -> T.SequenceElement:
2828
assert len(vec) == SequenceLieGroupOps.tangent_dim(a)
2929
new_a = []
@@ -70,7 +70,7 @@ def tangent_D_storage(a: T.SequenceElement) -> geo.Matrix:
7070

7171
@staticmethod
7272
def retract(
73-
a: T.SequenceElement, vec: T.List[T.Scalar], epsilon: T.Scalar
73+
a: T.SequenceElement, vec: T.Sequence[T.Scalar], epsilon: T.Scalar
7474
) -> T.SequenceElement:
7575
assert len(vec) == SequenceLieGroupOps.tangent_dim(a)
7676
new_a = []

symforce/ops/impl/sequence_storage_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def to_storage(a: T.SequenceElement) -> T.List[T.Scalar]:
1818
return [scalar for v in a for scalar in StorageOps.to_storage(v)]
1919

2020
@staticmethod
21-
def from_storage(a: T.SequenceElement, elements: T.List[T.Scalar]) -> T.SequenceElement:
21+
def from_storage(a: T.SequenceElement, elements: T.Sequence[T.Scalar]) -> T.SequenceElement:
2222
assert len(elements) == SequenceStorageOps.storage_dim(a)
2323
new_a = []
2424
inx = 0

symforce/typing.py

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def __subclasshook__(cls, subclass: Type) -> bool:
6767
return dataclasses.is_dataclass(subclass) and isinstance(subclass, type)
6868

6969

70+
DataclassOrType = Union[Dataclass, Type[Dataclass]]
71+
7072
# Abstract method helpers
7173
_ReturnType = TypeVar("_ReturnType")
7274

test/symforce_dataclass_ops_test.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from symforce.test_util import TestCase
1313
from symforce.ops import StorageOps
14+
from symforce.test_util.lie_group_ops_test_mixin import LieGroupOpsTestMixin
1415

1516

1617
@dataclass
@@ -34,11 +35,24 @@ class TestFixedSizeType:
3435
seq: TestSubType
3536

3637

37-
class SymforceDataclassOpsTest(TestCase):
38+
class SymforceDataclassOpsTest(LieGroupOpsTestMixin, TestCase):
3839
"""
3940
Tests ops.impl.dataclass_*_ops.py
4041
"""
4142

43+
@classmethod
44+
def element(cls) -> T.Dataclass:
45+
element = TestDynamicSizeType(
46+
rot=geo.Rot3.identity(),
47+
x=1.0,
48+
subtype=TestSubType(rot=geo.Rot3.from_yaw_pitch_roll(1.0, 2.0, 3.0)),
49+
seq=[
50+
[TestSubType(rot=geo.Rot3.from_yaw_pitch_roll(0.1, 0.2, 0.3)) for _ in range(3)]
51+
for _ in range(5)
52+
],
53+
)
54+
return element
55+
4256
def test_fixed_size_storage_ops(self) -> None:
4357
"""
4458
Tests:
@@ -78,12 +92,7 @@ def test_dynamic_size_storage_ops(self) -> None:
7892
Tests:
7993
DataclassStorageOps, with dynamic size type
8094
"""
81-
empty_instance = TestDynamicSizeType(
82-
rot=geo.Rot3.symbolic("rot"),
83-
x=sm.Symbol("x"),
84-
subtype=TestSubType(rot=geo.Rot3.symbolic("rot")),
85-
seq=[[TestSubType(rot=geo.Rot3.symbolic("rot")) for _ in range(3)] for _ in range(5)],
86-
)
95+
empty_instance = self.element()
8796

8897
expected_size = 1 + (
8998
2 + len(empty_instance.seq) * len(empty_instance.seq[0])

0 commit comments

Comments
 (0)