diff --git a/docs/conf.py b/docs/conf.py index 0a7a23f4b..33ceb7467 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -149,4 +149,24 @@ doctest_global_setup = """ import jax import jax.numpy as jnp + +import logging as slog +from absl import logging as alog + +# Avoid certain absl logging messages to break doctest +filtered_message = [ + 'SaveArgs.aggregate is deprecated', + '', +] + +class _CustomLogFilter(slog.Formatter): + def format(self, record): + message = super(_CustomLogFilter, self).format(record) + for m in filtered_message: + if m in message: + return '' + return message + +alog.use_absl_handler() +alog.get_absl_handler().setFormatter(_CustomLogFilter()) """ diff --git a/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst b/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst index e33d364a1..91f3cbb0f 100644 --- a/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst +++ b/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst @@ -93,7 +93,7 @@ For example: # At the top level mgr_options = orbax.checkpoint.CheckpointManagerOptions( - create=True, max_to_keep=3, keep_period=2, step_prefix='test_') + create=True, max_to_keep=3, keep_period=2, step_prefix='test') ckpt_mgr = orbax.checkpoint.CheckpointManager( CKPT_DIR, orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options) diff --git a/flax/metrics/tensorboard.py b/flax/metrics/tensorboard.py index b3d72b3bc..96bb25e83 100644 --- a/flax/metrics/tensorboard.py +++ b/flax/metrics/tensorboard.py @@ -75,7 +75,7 @@ def _as_default(summary_writer: tf.summary.SummaryWriter, auto_flush: bool): old_flush = summary_writer.flush new_flush = old_flush if auto_flush else lambda: None summary_writer.flush = new_flush - context_manager.__exit__() + context_manager.__exit__(None, None, None) summary_writer.flush = old_flush