Skip to content

Commit

Permalink
[Symforce] Added "length" metadata reader to timestep sub problem
Browse files Browse the repository at this point in the history
I have found myself often wanting to specify an input to a subproblem as
a sequence of objects (scalars, geo.V3, geo.Rot3, etc.), so I think we
should support this. I also removed `sequence_field` because there was a
comment saying we should probably remove it, and I think we should
just stick to one method for specifying field metadata instead of two,
as I found the fact that there were two ways to do the same thing a tad
bit confusing (I'm open to discussing though).

Topic: symforce_timestep_subproblem_length
GitOrigin-RevId: 6e2dfe469f3a883d6c366e7647940a504e7a7413
  • Loading branch information
nathan-skydio authored and aaron-skydio committed Apr 1, 2022
1 parent 523b5a1 commit e4834f0
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions symforce/opt/timestep_sub_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from symforce import ops
from symforce.opt.sub_problem import SubProblem
from symforce import typing as T
from symforce.python_util import get_sequence_from_dataclass_sequence_field


class TimestepSubProblem(SubProblem):
Expand All @@ -32,20 +33,15 @@ def build_inputs(self) -> None:
"""
Build the inputs block of the subproblem, and store in self.inputs.
The default implementation works for Dataclasses where all sequences are of length
self.timesteps; to customize this, override this function. Each field in the subproblem
Inputs that's meant to be a sequence of length `self.timesteps` should be marked with
`"timestepped": True` in the field metadata, e.g.
Each field in the subproblem Inputs that's meant to be a sequence of length `self.timesteps`
should be marked with `"timestepped": True` in the field metadata. Other sequences of known
length should be marked with the `"length": <sequence length>` in the field metadata, where
`<sequence length>` is the length of the sequence. For example:
@dataclass
class Inputs:
my_field: T.Sequence[T.Scalar] = field(metadata={"timestepped": True})
or
@dataclass
class Inputs:
my_field: T.Sequence[T.Scalar] = TimestepSubProblem.sequence_field()
my_timestepped_field: T.Sequence[T.Scalar] = field(metadata={"timestepped": True})
my_sequence_field: T.Sequence[T.Scalar] = field(metadata={"length": 3})
Any remaining fields of unknown size will cause an exception.
"""
Expand All @@ -61,6 +57,11 @@ class Inputs:
ops.StorageOps.symbolic(field_type, f"{self.name}.{field.name}[{i}]")
for i in range(self.timesteps)
]
elif field.metadata.get("length", False):
sequence_instance = get_sequence_from_dataclass_sequence_field(field, field_type)
constructed_fields[field.name] = ops.StorageOps.symbolic(
sequence_instance, f"{self.name}.{field.name}"
)
else:
try:
constructed_fields[field.name] = ops.StorageOps.symbolic(
Expand All @@ -74,23 +75,3 @@ class Inputs:
) from ex

self.inputs = self.Inputs(**constructed_fields)

@staticmethod
def sequence_field(*args: T.Any, **kwargs: T.Any) -> dataclasses.Field:
"""
Replacement for dataclasses.field(metadata={"timestepped": True})
Seemed cleaner to me, but mypy doesn't recognize it and gets mad about creating defaults for
some fields and not others (it thinks this is creating a default, instead of a Field). So
probably best to just delete, and have the user write out
my_field: MyType = field(metadata={"timestepped": True})
In py3.9, we could probably do T.Annotated[MyType, ...] instead of dataclasses.field, which
would be super nice.
"""
if "metadata" in kwargs:
kwargs["metadata"]["timestepped"] = True
else:
kwargs["metadata"] = {"timestepped": True}

return dataclasses.field(*args, **kwargs)

0 comments on commit e4834f0

Please sign in to comment.