Skip to content

Commit

Permalink
Resolvers can now access the parent node as well
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed Aug 12, 2020
1 parent 1dfe7b9 commit 70d6f4b
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
4 changes: 3 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def _resolve_complex_interpolation(
resolve_func=partial(
self._resolve_simple_interpolation,
key=key,
parent=parent,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
Expand All @@ -336,6 +337,7 @@ def _resolve_complex_interpolation(
def _resolve_simple_interpolation(
self,
key: Any,
parent: Optional["Container"],
inter_type: str,
inter_key: Tuple[Any],
throw_on_missing: bool,
Expand Down Expand Up @@ -373,7 +375,7 @@ def _resolve_simple_interpolation(
resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
try:
value = resolver(root_node, inter_key, inputs_str)
value = resolver(root_node, parent, inter_key, inputs_str)
return ValueNode(
value=value,
parent=self,
Expand Down
33 changes: 24 additions & 9 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def register_resolver(
resolver: Resolver,
variables_as_strings: bool = True,
config_arg: Optional[str] = None,
parent_arg: Optional[str] = None,
) -> None:
"""
The `variables_as_strings` flag was introduced to preserve backward compatibility
Expand All @@ -344,8 +345,13 @@ def register_resolver(
If provided, `config_arg` should be the name of a keyword (typically keyword-only)
argument of `resolver` of type `BaseContainer`, that will be bound to the config
root when the resolver is called. This allows accessing arbitrary config elements
from within the resolver. See `env()` for an example.
root when the resolver is called. This allows performing arbitrary operations on
the config from within the resolver. See `env()` for an example.
Similarly, `parent_arg` can be used to bind the corresponding keyword argument
of `resolver` (of type `Optional[Container]`) to the parent of the key being
processed when the resolver is called. This can be useful for operations involving
other config options relative to the current key.
"""
assert callable(resolver), "resolver must be callable"
# noinspection PyProtectedMember
Expand All @@ -354,7 +360,10 @@ def register_resolver(
), "resolver {} is already registered".format(name)

def resolver_wrapper(
config: BaseContainer, key: Tuple[Any, ...], inputs_str: Tuple[str, ...]
config: BaseContainer,
parent: Optional[Container],
key: Tuple[Any, ...],
inputs_str: Tuple[str, ...],
) -> Any:
# The `variables_as_strings` warning is triggered when the resolver is
# called instead of when it is defined, so as to limit the amount of
Expand All @@ -377,11 +386,13 @@ def resolver_wrapper(
try:
val = cache[hashable_key]
except KeyError:
val = cache[hashable_key] = (
resolver(*inputs)
if config_arg is None
else resolver(*inputs, **{config_arg: config})
)
# Call resolver.
optional_args: Dict[str, Optional[Container]] = {}
if config_arg is not None:
optional_args[config_arg] = config
if parent_arg is not None:
optional_args[parent_arg] = parent
val = cache[hashable_key] = resolver(*inputs, **optional_args)
return val

# noinspection PyProtectedMember
Expand All @@ -390,7 +401,11 @@ def resolver_wrapper(
@staticmethod
def get_resolver(
name: str,
) -> Optional[Callable[[Container, Tuple[Any, ...], Tuple[str, ...]], Any]]:
) -> Optional[
Callable[
[Container, Optional[Container], Tuple[Any, ...], Tuple[str, ...]], Any
]
]:
# noinspection PyProtectedMember
return (
BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
Expand Down
28 changes: 28 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,34 @@ def test_register_resolver_1(restore_resolvers: Any) -> None:
assert c.k == 1000


def test_register_resolver_access_config(restore_resolvers: Any) -> None:
OmegaConf.register_resolver(
"len",
lambda value, *, root: len(OmegaConf.select(root, value)),
config_arg="root",
)
c = OmegaConf.create({"list": [1, 2, 3], "list_len": "${len:list}"})
assert c.list_len == 3


def test_register_resolver_access_parent(restore_resolvers: Any) -> None:
OmegaConf.register_resolver(
"get_sibling",
lambda sibling, *, parent: getattr(parent, sibling),
parent_arg="parent",
)
c = OmegaConf.create(
"""
root:
foo:
bar:
baz1: "${get_sibling:baz2}"
baz2: useful data
"""
)
assert c.root.foo.bar.baz1 == "useful data"


def test_resolver_cache_1(restore_resolvers: Any) -> None:
# resolvers are always converted to stateless idempotent functions
# subsequent calls to the same function with the same argument will always return the same value.
Expand Down

0 comments on commit 70d6f4b

Please sign in to comment.