Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add path_aware_map function #2371

Merged
merged 1 commit into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api_reference/flax.traverse_util.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Dict utils

.. autofunction:: unflatten_dict

.. autofunction:: path_aware_map


Model parameter traversal
--------------------------
Expand Down
29 changes: 29 additions & 0 deletions flax/traverse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@
import abc
import copy
import dataclasses
from typing import Any, Callable, Dict, Tuple
import warnings

import jax
import flax
from flax.core.scope import VariableDict

from . import struct

# A path to a leaf in a nested dict: the tuple of string keys leading from
# the root down to that leaf (the key format used by ``flatten_dict``).
Path = Tuple[str, ...]

# the empty node is a struct.dataclass to be compatible with JAX.
@struct.dataclass
Expand Down Expand Up @@ -159,6 +162,32 @@ def unflatten_dict(xs, sep=None):
cursor[path[-1]] = value
return result

def path_aware_map(
    f: Callable[[Path, Any], Any], nested_dict: VariableDict) -> VariableDict:
  """Maps a function over every leaf of a nested dictionary, handing the
  function the path to each leaf along with its value.

  Example::

    >>> import jax.numpy as jnp
    >>> from flax import traverse_util
    ...
    >>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
    >>> f = lambda path, x: x + 5 if 'x' in path else -x
    >>> traverse_util.path_aware_map(f, params)
    {'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}

  Args:
    f: A callable taking ``(path, value)`` and returning the replacement
      value, where ``path`` is a tuple of strings.
    nested_dict: A nested dictionary structure.

  Returns:
    A new nested dictionary structure with the mapped values.
  """
  flat = flatten_dict(nested_dict, keep_empty_nodes=True)
  mapped = {}
  for path, value in flat.items():
    # Empty sub-dicts are carried through untouched; ``f`` only ever sees
    # real leaves.
    mapped[path] = value if value is empty_node else f(path, value)
  return unflatten_dict(mapped)

class Traversal(abc.ABC):
"""Base class for all traversals."""
Expand Down
54 changes: 54 additions & 0 deletions tests/traverse_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

import collections
from absl.testing import absltest
import numpy as np
import optax
import flax
from flax.core import freeze
from flax import traverse_util
import jax
import jax.numpy as jnp

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()
Expand Down Expand Up @@ -253,6 +256,57 @@ def filter_fn(name, _):
new_model = traversal.update(lambda x: x + x, model)
self.assertEqual(new_model, expected_model)

def test_path_value(self):
  """path_aware_map passes the key path so the mapper can branch on it."""
  nested = {'a': {'b': 10, 'c': 2}}
  result = traverse_util.path_aware_map(
      lambda path, value: value + 1 if 'b' in path else -value, nested)
  self.assertEqual({'a': {'b': 11, 'c': -2}}, result)

def test_path_aware_map_with_multi_transform(self):
  """Labels built with path_aware_map drive optax.multi_transform.

  Leaves named 'w' are labelled 'kernel' (updated with SGD); all other
  leaves are labelled 'bias' (frozen via set_to_zero). After one update,
  biases must be unchanged and weights must have moved.
  """
  params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
            'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
  gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients

  param_labels = traverse_util.path_aware_map(
      lambda path, x: 'kernel' if 'w' in path else 'bias', params)
  tx = optax.multi_transform(
      {'kernel': optax.sgd(1.0), 'bias': optax.set_to_zero()}, param_labels)
  state = tx.init(params)
  # The updated optimizer state is not needed by the assertions below.
  updates, _ = tx.update(gradients, state, params)
  new_params = optax.apply_updates(params, updates)

  # Biases frozen by set_to_zero: unchanged.
  self.assertTrue(np.allclose(new_params['linear_1']['b'], params['linear_1']['b']))
  self.assertTrue(np.allclose(new_params['linear_2']['b'], params['linear_2']['b']))
  # Weights updated by SGD: must differ from the originals.
  self.assertFalse(np.allclose(new_params['linear_1']['w'], params['linear_1']['w']))
  self.assertFalse(np.allclose(new_params['linear_2']['w'], params['linear_2']['w']))

def test_path_aware_map_with_masked(self):
  """Masks built with path_aware_map drive optax.masked.

  Only leaves named 'w' are masked in, so SGD(1.0) applies to weights
  (params move by -gradient) while biases receive the raw gradient as
  their update (optax.masked passes unmasked leaves through).
  """
  params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
            'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
  gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients

  params_mask = traverse_util.path_aware_map(
      lambda path, x: 'w' in path, params)
  tx = optax.masked(optax.sgd(1.0), params_mask)
  state = tx.init(params)
  # The updated optimizer state is not needed by the assertions below.
  updates, _ = tx.update(gradients, state, params)
  new_params = optax.apply_updates(params, updates)

  # Unmasked biases: update equals the raw gradient, so new_b == 0 + grad.
  self.assertTrue(np.allclose(new_params['linear_1']['b'], gradients['linear_1']['b']))
  self.assertTrue(np.allclose(new_params['linear_2']['b'], gradients['linear_2']['b']))
  # Masked weights: SGD(1.0) yields new_w == 0 - grad.
  self.assertTrue(np.allclose(new_params['linear_1']['w'], -gradients['linear_1']['w']))
  self.assertTrue(np.allclose(new_params['linear_2']['w'], -gradients['linear_2']['w']))

def test_path_aware_map_with_empty_nodes(self):
  """Empty sub-dicts survive path_aware_map untouched."""
  nested = {'a': {'b': 10, 'c': 2}, 'b': {}}
  result = traverse_util.path_aware_map(
      lambda path, value: value + 1 if 'b' in path else -value, nested)
  self.assertEqual({'a': {'b': 11, 'c': -2}, 'b': {}}, result)


# Entry point: run this test module under the absl test runner.
if __name__ == '__main__':
  absltest.main()