Flax: avoid key reuse in tests
Detected via running tests with configuration `jax_enable_key_reuse_checks=true`, available since jax-ml/jax#19795

PiperOrigin-RevId: 612604299
Jake VanderPlas authored and Flax Authors committed Mar 4, 2024
1 parent f195a8b commit a307d47
Showing 4 changed files with 35 additions and 32 deletions.
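For context, the key-reuse pattern this commit removes looks like the minimal sketch below. This is a hedged illustration, not code from the diff: it assumes a jax version recent enough to carry the `jax_enable_key_reuse_checks` flag named in the commit message, and the shapes are arbitrary.

import jax
from jax import random

# Opt in to the key-reuse check (the configuration named in the commit message).
jax.config.update('jax_enable_key_reuse_checks', True)

key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
# key_value = random.uniform(key1, (9, 5))  # reuses key1: flagged once the check is on

# Fix used throughout the diff: split enough keys so each draw consumes its own,
key_value = random.uniform(key2, (9, 5))
# or rebuild a fresh key from a fixed seed per init call, as the attention tests do:
key = lambda: random.key(43279)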
31 changes: 17 additions & 14 deletions tests/linen/linen_attention_test.py
@@ -272,22 +272,23 @@ def get_receptive_field_1d(pos):
def test_multihead_kv_args(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
- key_value = random.uniform(key1, (9, 5))
+ key_value = random.uniform(key2, (9, 5))
module = nn.MultiHeadDotProductAttention(
num_heads=8,
qkv_features=16,
kernel_init=initializers.ones,
bias_init=initializers.zeros,
deterministic=False,
)
+ key = lambda: random.key(43279)
y0, v0 = module.init_with_output(
- key2, query, inputs_k=key_value, inputs_v=key_value
+ key(), query, inputs_k=key_value, inputs_v=key_value
)
- y1, v1 = module.init_with_output(key2, query, inputs_k=key_value)
+ y1, v1 = module.init_with_output(key(), query, inputs_k=key_value)
with self.assertWarnsRegex(
DeprecationWarning, 'The inputs_kv arg will be deprecated soon.'
):
- y2, v2 = module.init_with_output(key2, query, inputs_kv=key_value)
+ y2, v2 = module.init_with_output(key(), query, inputs_kv=key_value)
self.assertTrue((y0 == y1).all() and (y1 == y2).all())
self.assertTrue(
jax.tree_util.tree_all(
@@ -300,20 +301,20 @@ def test_multihead_kv_args(self):
with self.assertRaisesRegex(
ValueError, '`inputs_k` cannot be None if `inputs_v` is not None.'
):
- y3, v3 = module.init_with_output(key2, query, inputs_v=key_value)
+ y3, v3 = module.init_with_output(key(), query, inputs_v=key_value)
with self.assertRaisesRegex(
ValueError,
'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.',
):
y3, v3 = module.init_with_output(
- key2, query, inputs_kv=key_value, inputs_v=key_value
+ key(), query, inputs_kv=key_value, inputs_v=key_value
)
with self.assertRaisesRegex(
ValueError,
'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.',
):
y3, v3 = module.init_with_output(
- key2, query, key_value, key_value, inputs_kv=key_value
+ key(), query, key_value, key_value, inputs_kv=key_value
)

def test_multihead_mask_warning(self):
@@ -423,7 +424,7 @@ def test_autoregressive_decode_with_x64(self):
def test_attention_alias_equivalence(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
- key_value = random.uniform(key1, (9, 5))
+ key_value = random.uniform(key2, (9, 5))
attention_kwargs = dict(
num_heads=8,
qkv_features=16,
@@ -433,8 +434,9 @@ def test_attention_alias_equivalence(self):
)
module1 = nn.MultiHeadDotProductAttention(**attention_kwargs)
module2 = nn.MultiHeadAttention(**attention_kwargs)
- out1, v1 = module1.init_with_output(key2, query, key_value)
- out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
+ key = lambda: random.key(43279)
+ out1, v1 = module1.init_with_output(key(), query, key_value)
+ out2, v2 = module2.init_with_output(key(), query, key_value, key_value)
self.assertTrue((out1 == out2).all())
self.assertTrue(
jax.tree_util.tree_all(
@@ -445,7 +447,7 @@ def test_attention_alias_submodule(self):
def test_attention_alias_submodule(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
- key_value = random.uniform(key1, (9, 5))
+ key_value = random.uniform(key2, (9, 5))
attention_kwargs = dict(
num_heads=8,
qkv_features=16,
@@ -470,10 +472,11 @@ class Foo2(nn.Module):
def __call__(self, query, key, value):
return nn.MultiHeadAttention(**self.attention_kwargs)(query, key, value)

+ key = lambda: random.key(5478392)
module1 = Foo1(attention_kwargs)
module2 = Foo2(attention_kwargs)
- out1, v1 = module1.init_with_output(key2, query, key_value)
- out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
+ out1, v1 = module1.init_with_output(key(), query, key_value)
+ out2, v2 = module2.init_with_output(key(), query, key_value, key_value)

# test different output and variables if layer names are different
self.assertTrue((out1 != out2).all())
@@ -507,7 +510,7 @@ def __call__(self, query, key, value):
)(query, key, value)

module2 = Foo2(attention_kwargs)
- out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
+ out2, v2 = module2.init_with_output(key(), query, key_value, key_value)
self.assertTrue((out1 == out2).all())
self.assertTrue(
jax.tree_util.tree_all(
6 changes: 3 additions & 3 deletions tests/linen/linen_combinators_test.py
@@ -160,10 +160,10 @@ def test_dict_output(self):
]
)

- key1, key2 = random.split(random.key(0), 2)
+ key1, key2, key3 = random.split(random.key(0), 3)
query = random.uniform(key1, (3, 5))
- key_value = random.uniform(key1, (9, 5))
- params_1 = sequential.init(key2, query, key_value)
+ key_value = random.uniform(key2, (9, 5))
+ params_1 = sequential.init(key3, query, key_value)
outputs = sequential.apply(params_1, query, key_value)
np.testing.assert_equal(len(outputs), 2)
out_query, out_key_value = outputs['query'], outputs['key_value']
12 changes: 6 additions & 6 deletions tests/linen/linen_test.py
@@ -997,7 +997,7 @@ class RecurrentTest(parameterized.TestCase):
def test_lstm(self):
lstm = nn.LSTMCell(features=4)
rng = random.key(0)
- key1, key2 = random.split(rng)
+ rng, key1, key2 = random.split(rng, 3)
x = random.normal(key1, (2, 3))
c0, h0 = lstm.initialize_carry(rng, x.shape)
self.assertEqual(c0.shape, (2, 4))
@@ -1053,7 +1053,7 @@ def test_lstm(self):
def test_gated_units(self, module_cls, expected_param_shapes):
module = module_cls(features=4)
rng = random.key(0)
- key1, key2 = random.split(rng)
+ rng, key1, key2 = random.split(rng, 3)
x = random.normal(key1, (2, 3))
carry0 = module.initialize_carry(rng, x.shape)
self.assertEqual(carry0.shape, (2, 4))
@@ -1084,7 +1084,7 @@ def test_gated_units(self, module_cls, expected_param_shapes):
def test_complex_input_gated_units(self, module_cls):
module_instance = module_cls(features=4)
rng = random.key(0)
- key1, key2 = random.split(rng)
+ rng, key1, key2 = random.split(rng, 3)
x = random.normal(key1, (2, 3), dtype=jnp.complex64)
carry0 = module_instance.initialize_carry(rng, x.shape)
self.assertEqual(carry0.shape, (2, 4))
@@ -1095,7 +1095,7 @@ def test_complex_input_gated_units(self, module_cls):
def test_convlstm(self):
lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
rng = random.key(0)
- key1, key2 = random.split(rng)
+ rng, key1, key2 = random.split(rng, 3)
x = random.normal(key1, (2, 4, 4, 3))
c0, h0 = lstm.initialize_carry(rng, x.shape)
self.assertEqual(c0.shape, (2, 4, 4, 6))
@@ -1117,7 +1117,7 @@ def test_optimized_lstm_cell_matches_regular(self):
# Create regular LSTMCell.
lstm = nn.LSTMCell(features=4)
rng = random.key(0)
- key1, key2 = random.split(rng)
+ rng, key1, key2 = random.split(rng, 3)
x = random.normal(key1, (2, 3))
c0, h0 = lstm.initialize_carry(rng, x.shape)
self.assertEqual(c0.shape, (2, 4))
@@ -1127,7 +1127,7 @@ def test_optimized_lstm_cell_matches_regular(self):
# Create OptimizedLSTMCell.
lstm_opt = nn.OptimizedLSTMCell(features=4)
rng = random.key(0)
- key1, key2 = random.split(rng)
+ rng, key1, key2 = random.split(rng, 3)
x = random.normal(key1, (2, 3))
c0, h0 = lstm_opt.initialize_carry(rng, x.shape)
self.assertEqual(c0.shape, (2, 4))
18 changes: 9 additions & 9 deletions tests/linen/linen_transforms_test.py
@@ -221,9 +221,9 @@ def __call__(self, inputs, train: bool):
self.assertEqual(y.shape, (1, 3))

def test_vmap(self):
- key1, key2 = random.split(random.key(3), 2)
+ key1, key2, key3 = random.split(random.key(3), 3)
x = random.uniform(key1, (4, 4))
- x2 = random.uniform(key1, (5, 4, 4))
+ x2 = random.uniform(key2, (5, 4, 4))

def vmap(cls):
return nn.vmap(
@@ -235,7 +235,7 @@ def vmap(cls):

normal_model = TransformedMLP(features=[3, 4, 5])
vmap_model = TransformedMLP(features=[3, 4, 5], transform=vmap)
- init_variables = normal_model.init(key2, x)
+ init_variables = normal_model.init(key3, x)
# simulate vmap in python for comparison:
y1 = jnp.vstack([
normal_model.apply(init_variables, x2[i])[None, ...]
@@ -245,9 +245,9 @@ def vmap(cls):
np.testing.assert_allclose(y1, y2, atol=1e-7)

def test_vmap_decorated(self):
- key1, key2 = random.split(random.key(3), 2)
+ key1, key2, key3 = random.split(random.key(3), 3)
x = random.uniform(key1, (4, 4))
- x2 = random.uniform(key1, (5, 4, 4))
+ x2 = random.uniform(key2, (5, 4, 4))

def vmap(fn):
return nn.vmap(
@@ -259,7 +259,7 @@ def vmap(fn):

normal_model = decorated_MLP()(features=[3, 4, 5])
vmap_model = decorated_MLP(vmap)(features=[3, 4, 5])
- init_variables = normal_model.init(key2, x)
+ init_variables = normal_model.init(key3, x)
# simulate vmap in python for comparison:
y1 = jnp.vstack([
normal_model.apply(init_variables, x2[i])[None, ...]
@@ -269,9 +269,9 @@ def vmap(fn):
np.testing.assert_allclose(y1, y2, atol=1e-7)

def test_vmap_batchnorm(self):
- key1, key2 = random.split(random.key(3), 2)
+ key1, key2, key3 = random.split(random.key(3), 3)
x = random.uniform(key1, (4, 4))
- x2 = random.uniform(key1, (5, 4, 4))
+ x2 = random.uniform(key2, (5, 4, 4))

def vmap(cls):
return nn.vmap(
@@ -293,7 +293,7 @@ def __call__(self, x):

normal_model = MlpBn()
vmap_model = vmap(MlpBn)(axis_name='batch')
- init_variables = normal_model.init(key2, x)
+ init_variables = normal_model.init(key3, x)
y1 = normal_model.apply(
init_variables, x2.reshape((-1, 4)), mutable=['batch_stats']
)[0]
