Skip to content

Commit

Permalink
Merge pull request #3753 from IvyZX:orbax-quickfix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615279397
  • Loading branch information
Flax Authors committed Mar 13, 2024
2 parents 6f2b08e + 3b23dc8 commit ce8a3c7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
20 changes: 20 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,24 @@
doctest_global_setup = """
import jax
import jax.numpy as jnp
import logging as slog
from absl import logging as alog
# Avoid certain absl logging messages to break doctest
filtered_message = [
'SaveArgs.aggregate is deprecated',
'',
]
class _CustomLogFilter(slog.Formatter):
def format(self, record):
message = super(_CustomLogFilter, self).format(record)
for m in filtered_message:
if m in message:
return ''
return message
alog.use_absl_handler()
alog.get_absl_handler().setFormatter(_CustomLogFilter())
"""
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ For example:

# At the top level
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
create=True, max_to_keep=3, keep_period=2, step_prefix='test_')
create=True, max_to_keep=3, keep_period=2, step_prefix='test')
ckpt_mgr = orbax.checkpoint.CheckpointManager(
CKPT_DIR,
orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)
Expand Down
2 changes: 1 addition & 1 deletion flax/metrics/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _as_default(summary_writer: tf.summary.SummaryWriter, auto_flush: bool):
old_flush = summary_writer.flush
new_flush = old_flush if auto_flush else lambda: None
summary_writer.flush = new_flush
context_manager.__exit__()
context_manager.__exit__(None, None, None)
summary_writer.flush = old_flush


Expand Down

0 comments on commit ce8a3c7

Please sign in to comment.