Skip to content

Commit 0c67bb8

Browse files
nathan-skydioaaron-skydio
authored andcommitted
[Symforce] Allow specific explicit template instantiations
Previously we allowed users to set a flag to indicate if they wanted generated C++ functions to be explicitly instantiated, and would explicitly instantiate a float and a double version of the function. However, compliling very large functions can be slow because both the float and double versions of the function must be compiled, even if only one is used. This PR allows the user to do explicit template instantiation on whatever types they like, meaning they can choose to only do explicit template instantiation on float or double if they like. Topic: symforce_explicit_template_instantiation_list GitOrigin-RevId: 27e5b54441a80928b85019df44fc605ecbdf75aa
1 parent 0c2f9f2 commit 0c67bb8

File tree

5 files changed

+15
-16
lines changed

5 files changed

+15
-16
lines changed

symforce/codegen/codegen.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def generate_function(
431431
template_data,
432432
)
433433

434-
if self.config.use_explicit_template_instantiation:
434+
if self.config.explicit_template_instantiation_types is not None:
435435
templates.add(
436436
Path(template_util.CPP_TEMPLATE_DIR) / "function" / "FUNCTION.cc.jinja",
437437
cpp_function_dir / f"{generated_file_name}.cc",

symforce/codegen/codegen_config.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ class CppConfig(CodegenConfig):
4848
which we'll initialize an output matrix to 0, so we
4949
don't have to generate a line to set each zero
5050
element to 0 individually
51-
use_explicit_template_instantiation: Explicity instantiate templated functions in a `.cc`
52-
file so that generated function can be compiled in its own translation unit. Useful
53-
for large functions which take a long time to compile.
51+
explicit_template_instantiation_types: Explicity instantiates templated functions in a `.cc`
52+
file for each given type. This allows the generated function to be compiled in its own
53+
translation unit. Useful for large functions which take a long time to compile.
5454
"""
5555

5656
doc_comment_line_prefix: str = " * "
@@ -59,7 +59,7 @@ class CppConfig(CodegenConfig):
5959
support_complex: bool = False
6060
force_no_inline: bool = False
6161
zero_initialization_sparsity_threshold: float = 0.5
62-
use_explicit_template_instantiation: bool = False
62+
explicit_template_instantiation_types: T.Optional[T.Sequence[str]] = None
6363

6464

6565
@dataclass

symforce/codegen/cpp_templates/function/FUNCTION.cc.jinja

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
namespace {{ spec.namespace }} {
1111

12+
{% for type in spec.config.explicit_template_instantiation_types %}
1213
{% set name = python_util.snakecase_to_camelcase(spec.name) %}
13-
template {{ util.get_return_type(spec, scalar_type="double") }} {{ name }}<double>(
14-
{{- util.input_args_declaration(spec, is_declaration=False, scalar_type="double") -}});
14+
template {{ util.get_return_type(spec, scalar_type=type) }} {{ name }}<{{ type }}>(
15+
{{- util.input_args_declaration(spec, is_declaration=False, scalar_type=type) -}});
1516

16-
template {{ util.get_return_type(spec, scalar_type="float") }} {{ name }}<float>(
17-
{{- util.input_args_declaration(spec, is_declaration=False, scalar_type="float") -}});
17+
{% endfor -%}
1818

1919
} // namespace {{ spec.namespace }}

symforce/codegen/cpp_templates/function/FUNCTION.h.jinja

+5-6
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,13 @@ __attribute__((noinline))
4646
{{ util.expr_code(spec) -}}
4747
} // NOLINT(readability/fn_size)
4848

49-
{% if spec.config.use_explicit_template_instantiation %}
49+
{% if spec.config.explicit_template_instantiation_types is not none %}
50+
{% for type in spec.config.explicit_template_instantiation_types %}
5051
{% set name = python_util.snakecase_to_camelcase(spec.name) %}
51-
extern template {{ util.get_return_type(spec, scalar_type="double") }} {{ name }}<double>(
52-
{{- util.input_args_declaration(spec, is_declaration=False, scalar_type="double") -}});
52+
extern template {{ util.get_return_type(spec, scalar_type=type) }} {{ name }}<{{ type }}>(
53+
{{- util.input_args_declaration(spec, is_declaration=False, scalar_type=type) -}});
5354

54-
extern template {{ util.get_return_type(spec, scalar_type="float") }} {{ name }}<float>(
55-
{{- util.input_args_declaration(spec, is_declaration=False, scalar_type="float") -}});
56-
55+
{% endfor -%}
5756
{% endif -%}
5857

5958
// NOLINTNEXTLINE(readability/fn_size)

test/symforce_codegen_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ def test_function_explicit_template_instantiation(self) -> None:
837837
cpp_func = codegen.Codegen(
838838
inputs,
839839
outputs,
840-
codegen.CppConfig(use_explicit_template_instantiation=True),
840+
codegen.CppConfig(explicit_template_instantiation_types=["double", "float"]),
841841
"codegen_explicit_template_instantiation_test",
842842
)
843843
shared_types = {

0 commit comments

Comments
 (0)