
Commit 13111c5

kunal-vaishnavi authored and guschmue committed
Fix attention fusion in conformer encoder (#23711)
### Description

This PR updates the attention fusion for conformer-encoder models. It is a follow-up to [this PR](#23528).

### Motivation and Context

Subsequent modeling code updates have changed (and will continue to change) the graph fusions. However, the three ending attention mask nodes (`Cast --> Unsqueeze --> Equal`) will remain. Thus, the attention fusion should work regardless of any future modeling code changes when handling the attention mask.
1 parent d0a3a0c commit 13111c5

File tree

1 file changed: +3 −3 lines changed


onnxruntime/python/tools/transformers/fusion_conformer_attention.py

+3-3
```diff
@@ -79,11 +79,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         where_qk = qk_nodes[2]
         mask_nodes = self.model.match_parent_path(
             where_qk,
-            ["Equal", "Unsqueeze", "Cast", "Expand"],
-            [0, 0, 0, 0],
+            ["Equal", "Unsqueeze", "Cast"],
+            [0, 0, 0],
         )
         if mask_nodes is not None:
-            attn_mask = mask_nodes[-2].output[0]
+            attn_mask = mask_nodes[-1].output[0]

         add_qk, matmul_qk = qk_nodes[-2], qk_nodes[-1]
```
0 commit comments
