From 9ad954006c13a1a47d188d86340f344c4d7bc618 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 21 Jun 2025 19:50:56 +0200
Subject: [PATCH 1/5] block_diag dot rewrite

---
 pytensor/tensor/rewriting/math.py   | 73 +++++++++++++++++++++++++++--
 tests/tensor/rewriting/test_math.py | 73 +++++++++++++++++++++++++++++
 2 files changed, 142 insertions(+), 4 deletions(-)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index aef363655e..a0d5a9dc7b 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -29,9 +29,11 @@
     cast,
     constant,
     get_underlying_scalar_constant_value,
+    join,
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros_like,
 )
@@ -99,6 +101,7 @@
 )
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -167,6 +170,72 @@ def local_0_dot_x(fgraph, node):
     return [constant_zero]
 
 
+@register_canonicalize
+@register_specialize
+@register_stabilize
+@node_rewriter([Dot])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
+
+    BlockDiag results in the creation of a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity
+    of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
+    a single dot on the larger matrix.
+    """
+    x, y = node.inputs
+    op = node.op
+
+    def check_for_block_diag(x):
+        return x.owner and (
+            isinstance(x.owner.op, BlockDiagonal)
+            or isinstance(x.owner.op, Blockwise)
+            and isinstance(x.owner.op.core_op, BlockDiagonal)
+        )
+
+    if not (check_for_block_diag(x) or check_for_block_diag(y)):
+        return None
+
+    # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
+    # non-block-diagonal input, and concatenate the results
+    if check_for_block_diag(x) and not check_for_block_diag(y):
+        components = x.owner.inputs
+        y_splits = split(
+            y,
+            splits_size=[component.shape[-1] for component in components],
+            n_splits=len(components),
+        )
+        new_components = [
+            op(component, y_split) for component, y_split in zip(components, y_splits)
+        ]
+        new_output = join(0, *new_components)
+    elif not check_for_block_diag(x) and check_for_block_diag(y):
+        components = y.owner.inputs
+        new_components = [op(x, component) for component in components]
+        new_output = join(0, *new_components)
+
+    # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
+    # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
+    elif any(shape is None for shape in (*x.type.shape, *y.type.shape)):
+        return None
+    elif x.ndim == y.ndim and all(
+        x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape)
+    ):
+        x_components = x.owner.inputs
+        y_components = y.owner.inputs
+
+        if len(x_components) != len(y_components):
+            return None
+
+        new_output = BlockDiagonal(len(x_components))(
+            *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)]
+        )
+    else:
+        return None
+
+    copy_stack_trace(node.outputs[0], new_output)
+    return [new_output]
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2496,7 +2565,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
     name="add_canonizer_group",
 )
 
-
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
@@ -3619,7 +3687,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
 
-
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3633,7 +3700,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)
 
-
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3674,7 +3740,6 @@ def local_useless_conj(fgraph, node):
 register_specialize(local_polygamma_to_tri_gamma)
 
-
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index c4999fcd33..3be12da3e5 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -113,6 +113,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4654,3 +4655,75 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+def test_local_block_diag_dot_to_dot_block_diag():
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4))
+    c = tensor("c", shape=(4, 4))
+    d = tensor("d", shape=(10,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    fn = pytensor.function([a, b, c, d], out)
+    assert not any(
+        isinstance(node, BlockDiagonal) for node in fn.maker.fgraph.toposort()
+    )
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"),
+    )
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    np.testing.assert_allclose(
+        fn(a_val, b_val, c_val, d_val),
+        fn_expected(a_val, b_val, c_val, d_val),
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(0, size))
+    b_size = int(rng.uniform(0, size - a_size))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )

From ffb71d30c2e8516f811cdd1e88fb48f01e06bc21 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 21 Jun 2025 22:02:28 +0200
Subject: [PATCH 2/5] Handle right-multiplication case

---
 pytensor/tensor/rewriting/math.py   | 29 +++++++++++------------------
 tests/tensor/rewriting/test_math.py | 11 ++++++++---
 2 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index a0d5a9dc7b..a35363a170 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -176,7 +176,7 @@ def local_0_dot_x(fgraph, node):
 @node_rewriter([Dot])
 def local_block_diag_dot_to_dot_block_diag(fgraph, node):
     r"""
-    Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
 
     BlockDiag results in the creation of a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity
     of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
     a single dot on the larger matrix.
@@ -210,25 +210,18 @@ def check_for_block_diag(x):
         new_output = join(0, *new_components)
     elif not check_for_block_diag(x) and check_for_block_diag(y):
         components = y.owner.inputs
-        new_components = [op(x, component) for component in components]
-        new_output = join(0, *new_components)
-
-    # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
-    # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
-    elif any(shape is None for shape in (*x.type.shape, *y.type.shape)):
-        return None
-    elif x.ndim == y.ndim and all(
-        x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape)
-    ):
-        x_components = x.owner.inputs
-        y_components = y.owner.inputs
+        x_splits = split(
+            x,
+            splits_size=[component.shape[0] for component in components],
+            n_splits=len(components),
+            axis=1,
+        )
 
-        if len(x_components) != len(y_components):
-            return None
+        new_components = [
+            op(x_split, component) for component, x_split in zip(components, x_splits)
+        ]
+        new_output = join(1, *new_components)
 
-        new_output = BlockDiagonal(len(x_components))(
-            *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)]
-        )
+
     else:
         return None

diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 3be12da3e5..b1451825ab 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4657,17 +4657,22 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
     )
 
 
-def test_local_block_diag_dot_to_dot_block_diag():
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     """
     Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
     """
     a = tensor("a", shape=(4, 2))
     b = tensor("b", shape=(2, 4))
     c = tensor("c", shape=(4, 4))
-    d = tensor("d", shape=(10,))
+    d = tensor("d", shape=(10, 10))
 
     x = pt.linalg.block_diag(a, b, c)
-    out = x @ d
+
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
 
     fn = pytensor.function([a, b, c, d], out)
     assert not any(

From c5137d7214aff6d35cd2fbcee7a30e9e2459b1d7 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 21 Jun 2025 22:17:33 +0200
Subject: [PATCH 3/5] The robot was right!

---
 tests/tensor/rewriting/test_math.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index b1451825ab..32bcb5c471 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4676,7 +4676,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
 
     fn = pytensor.function([a, b, c, d], out)
     assert not any(
-        isinstance(node, BlockDiagonal) for node in fn.maker.fgraph.toposort()
+        isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
     )
 
     fn_expected = pytensor.function(

From 3b66eba6cd44675ad900561c6879191ceec66d7c Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sun, 22 Jun 2025 12:33:12 +0200
Subject: [PATCH 4/5] Respond to feedback

---
 pytensor/tensor/rewriting/math.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index a35363a170..b12d75ae35 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -170,10 +170,8 @@ def local_0_dot_x(fgraph, node):
     return [constant_zero]
 
 
-@register_canonicalize
-@register_specialize
 @register_stabilize
-@node_rewriter([Dot])
+@node_rewriter([Blockwise])
 def local_block_diag_dot_to_dot_block_diag(fgraph, node):
     r"""
     Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
 
     BlockDiag results in the creation of a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity
     of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
     a single dot on the larger matrix.
     """
-    x, y = node.inputs
-    op = node.op
+    if not isinstance(node.op.core_op, BlockDiagonal):
+        return
 
     def check_for_block_diag(x):
         return x.owner and (
@@ -192,6 +190,15 @@
             and isinstance(x.owner.op.core_op, BlockDiagonal)
         )
 
+    # Check that the BlockDiagonal is an input to a Dot node:
+    clients = list(get_clients_at_depth(fgraph, node, depth=1))
+    if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot):
+        return
+
+    [dot_node] = clients
+    op = dot_node.op
+    x, y = dot_node.inputs
+
     if not (check_for_block_diag(x) or check_for_block_diag(y)):
         return None
 
@@ -208,6 +215,7 @@ def check_for_block_diag(x):
             op(component, y_split) for component, y_split in zip(components, y_splits)
         ]
         new_output = join(0, *new_components)
+
     elif not check_for_block_diag(x) and check_for_block_diag(y):
         components = y.owner.inputs
         x_splits = split(
             x,
             splits_size=[component.shape[0] for component in components],
             n_splits=len(components),
             axis=1,
         )
@@ -222,11 +230,14 @@ def check_for_block_diag(x):
         ]
         new_output = join(1, *new_components)
 
+    # Case 2: Both inputs are BlockDiagonal. Do nothing
     else:
+        # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
+        # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
         return None
 
     copy_stack_trace(node.outputs[0], new_output)
-    return [new_output]
+    return {dot_node.outputs[0]: new_output}

From 09bddf1112246841c302d6ae1d821d1e79914431 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sun, 22 Jun 2025 13:00:24 +0200
Subject: [PATCH 5/5] Use `rewrite_mode` defined in `test_math.py` for testing

---
 tests/tensor/rewriting/test_math.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 32bcb5c471..137b91fb34 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4674,7 +4674,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     else:
         out = d @ x
 
-    fn = pytensor.function([a, b, c, d], out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
     assert not any(
         isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
     )
 
     fn_expected = pytensor.function(
         [a, b, c, d],
         out,
-        mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"),
+        mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
     )
 
     rng = np.random.default_rng()
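
To see the net effect of the series, here is a minimal sketch of the rewrite in action (illustrative names and shapes; it assumes the rewrite is registered in the default rewrite pipeline, as in the final patch, and uses scipy only to build a dense reference result):

    import numpy as np
    import scipy.linalg

    import pytensor
    import pytensor.tensor as pt
    from pytensor import config
    from pytensor.tensor.slinalg import BlockDiagonal

    A = pt.tensor("A", shape=(3, 3))
    B = pt.tensor("B", shape=(2, 2))
    v = pt.tensor("v", shape=(5,))

    # dot(block_diag(A, B), v) should be rewritten to
    # join(0, dot(A, v[:3]), dot(B, v[3:])), so the compiled graph
    # contains no BlockDiagonal (plain or Blockwise-wrapped) nodes.
    out = pt.linalg.block_diag(A, B) @ v
    fn = pytensor.function([A, B, v], out)

    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), BlockDiagonal)
        for node in fn.maker.fgraph.toposort()
    )

    rng = np.random.default_rng(0)
    A_val = rng.normal(size=(3, 3)).astype(config.floatX)
    B_val = rng.normal(size=(2, 2)).astype(config.floatX)
    v_val = rng.normal(size=(5,)).astype(config.floatX)

    # Same result as materializing the dense block-diagonal matrix first,
    # without ever allocating the (5, 5) matrix with its zero blocks.
    np.testing.assert_allclose(
        fn(A_val, B_val, v_val),
        scipy.linalg.block_diag(A_val, B_val) @ v_val,
        rtol=1e-5 if config.floatX == "float32" else 1e-12,
    )

The getattr check mirrors check_for_block_diag above: block_diag can produce either a plain BlockDiagonal node or a Blockwise-wrapped one, depending on how the graph is constructed.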