Skip to content

Commit

Permalink
[resolver] Add options to give resolvers access to the config
Browse files Browse the repository at this point in the history
Fixes omry#266
  • Loading branch information
odelalleau committed Oct 27, 2020
1 parent 9f30916 commit 0d08356
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 6 deletions.
6 changes: 4 additions & 2 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def _select_impl(
)
if value is None:
return root, last_key, value
value = root._resolve_interpolation(
value = root.resolve_interpolation(
parent=root,
key=last_key,
value=value,
throw_on_missing=False,
Expand Down Expand Up @@ -361,6 +362,7 @@ def _resolve_complex_interpolation(
def resolve_simple_interpolation(
self,
key: Any,
parent: Optional["Container"],
inter_type: str,
inter_key: Tuple[Any, ...],
throw_on_missing: bool,
Expand Down Expand Up @@ -399,7 +401,7 @@ def resolve_simple_interpolation(
if resolver is not None:
root_node = self._get_root()
try:
value = resolver(root_node, inter_key, inputs_str)
value = resolver(root_node, parent, inter_key, inputs_str)
return ValueNode(
value=value,
parent=self,
Expand Down
1 change: 1 addition & 0 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def is_mandatory_missing(val: Any) -> bool:
return default_value

resolved = self.resolve_interpolation(
parent=self,
key=key,
value=value,
throw_on_missing=not has_default,
Expand Down
51 changes: 47 additions & 4 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ def register_resolver(
name: str,
resolver: Resolver,
args_as_strings: bool = True,
use_cache: Optional[bool] = True,
config_arg: Optional[str] = None,
parent_arg: Optional[str] = None,
use_cache: Optional[bool] = None,
) -> None:
"""
The `args_as_strings` flag was introduced to preserve backward compatibility
Expand All @@ -388,13 +390,48 @@ def register_resolver(
of its inputs), and triggers a warning
- `False` is the new behavior (the resolver can take non-string inputs), and
will become the default in the future
If provided, `config_arg` should be the name of a keyword (typically keyword-only)
argument of `resolver` of type `BaseContainer`, that will be bound to the config
root when the resolver is called. This allows performing arbitrary operations on
the config from within the resolver. See `env()` for an example.
Similarly, `parent_arg` can be used to bind the corresponding keyword argument
of `resolver` (of type `Optional[Container]`) to the parent of the key being
processed when the resolver is called. This can be useful for operations involving
other config options relative to the current key.
`use_cache` indicates whether the resolver's outputs should be cached. When not
provided, it defaults to `True` unless either `config_arg` or `parent_arg` is
used. In such situations it defaults to `False` and the user is warned to
explicitly set `use_cache=False` to make it clear that no caching is done
(currently caching is not supported when using `config_arg` or `parent_arg`).
"""
assert callable(resolver), "resolver must be callable"
# noinspection PyProtectedMember
assert (
name not in BaseContainer._resolvers
), "resolved {} is already registered".format(name)
def caching(

if use_cache is None:
if config_arg is not None or parent_arg is not None:
warnings.warn(
f"You are using either `config_arg` or `parent_arg` to register "
f"resolver `{name}`: caching is not supported in such a case, and "
f"you must explicitly set `use_cache=False` to disable this warning.",
stacklevel=2,
)
use_cache = False
else:
use_cache = True
elif use_cache and (config_arg is not None or parent_arg is not None):
raise NotImplementedError(
f"Caching is not supported when using either `config_arg` or "
f"`parent_arg`, please set `use_cache=False` when registering "
f"resolver `{name}`",
)

def resolver_wrapper(
config: BaseContainer,
parent: Optional[Container],
key: Tuple[Any, ...],
Expand Down Expand Up @@ -432,12 +469,18 @@ def caching(
pass

# Call resolver.
ret = resolver(*inputs)
optional_args: Dict[str, Optional[Container]] = {}
if config_arg is not None:
optional_args[config_arg] = config
if parent_arg is not None:
optional_args[parent_arg] = parent
ret = resolver(*inputs, **optional_args)
if use_cache:
cache[hashable_key] = ret
return ret

# noinspection PyProtectedMember
BaseContainer._resolvers[name] = caching
BaseContainer._resolvers[name] = resolver_wrapper

@staticmethod
def get_resolver(name: str) -> Optional[Callable[[Container, Any], Any]]:
Expand Down
83 changes: 83 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,89 @@ def test_register_resolver_1(restore_resolvers: Any) -> None:
assert c.k == 1000


def test_register_resolver_access_config(restore_resolvers: Any) -> None:
OmegaConf.register_resolver(
"len",
lambda value, *, root: len(OmegaConf.select(root, value)),
config_arg="root",
use_cache=False,
)
c = OmegaConf.create({"list": [1, 2, 3], "list_len": "${len:list}"})
assert c.list_len == 3


def test_register_resolver_access_parent(restore_resolvers: Any) -> None:
OmegaConf.register_resolver(
"get_sibling",
lambda sibling, *, parent: getattr(parent, sibling),
parent_arg="parent",
use_cache=False,
)
c = OmegaConf.create(
"""
root:
foo:
bar:
baz1: "${get_sibling:baz2}"
baz2: useful data
"""
)
assert c.root.foo.bar.baz1 == "useful data"


def test_register_resolver_access_parent_no_cache(restore_resolvers: Any) -> None:
OmegaConf.register_resolver(
"add_noise_to_sibling",
lambda sibling, *, parent: random.uniform(0, 1) + getattr(parent, sibling),
parent_arg="parent",
use_cache=False,
)
c = OmegaConf.create(
"""
root:
foo:
baz1: "${add_noise_to_sibling:baz2}"
baz2: 1
bar:
baz1: "${add_noise_to_sibling:baz2}"
baz2: 1
"""
)
assert c.root.foo.baz2 == c.root.bar.baz2 # make sure we test what we want to test
assert c.root.foo.baz1 != c.root.foo.baz1 # same node (regular "no cache" behavior)
assert c.root.foo.baz1 != c.root.bar.baz1 # same args but different parents


def test_register_resolver_cache_warnings(restore_resolvers: Any) -> None:
with pytest.warns(UserWarning):
OmegaConf.register_resolver(
"test_warning_parent", lambda *, parent: None, parent_arg="parent"
)

with pytest.warns(UserWarning):
OmegaConf.register_resolver(
"test_warning_config", lambda *, config: None, config_arg="config"
)


def test_register_resolver_cache_errors(restore_resolvers: Any) -> None:
with pytest.raises(NotImplementedError):
OmegaConf.register_resolver(
"test_error_parent",
lambda *, parent: None,
parent_arg="parent",
use_cache=True,
)

with pytest.raises(NotImplementedError):
OmegaConf.register_resolver(
"test_error_config",
lambda *, config: None,
config_arg="config",
use_cache=True,
)


def test_resolver_cache_1(restore_resolvers: Any) -> None:
# resolvers are always converted to stateless idempotent functions
# subsequent calls to the same function with the same argument will always return the same value.
Expand Down

0 comments on commit 0d08356

Please sign in to comment.