# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
import numpy as np
from fusion_base import Fusion
from fusion_options import AttentionMaskFormat
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
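# Usage sketch (illustrative, not invoked by this module): these fusions are
# normally driven through the transformers optimizer rather than directly, e.g.
#
#   from onnxruntime.transformers import optimizer
#   opt_model = optimizer.optimize_model("model.onnx", model_type="bert", num_heads=12, hidden_size=768)
#   opt_model.save_model_to_file("model_opt.onnx")
#
# The model path and the num_heads/hidden_size values above are placeholders.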
class AttentionMask:
"""
    Process the attention mask for Attention fusion: cast 2D masks to int32
    and/or convert them to a 1D mask index, depending on the mask format.
"""
def __init__(self, model: OnnxModel):
self.model = model
# A lookup table with mask input as key, and mask index output as value
self.mask_indice = {}
# A lookup table with mask input as key, and cast (to int32) output as value
self.mask_casted = {}
self.utils = FusionUtils(model)
self.mask_format = AttentionMaskFormat.MaskIndexEnd
self.opset_version = model.get_opset_version()
def set_mask_format(self, mask_format: AttentionMaskFormat):
self.mask_format = mask_format
def set_mask_indice(self, mask, mask_index):
if mask in self.mask_indice:
assert mask_index == self.mask_indice[mask]
self.mask_indice[mask] = mask_index
def get_first_mask(self):
assert len(self.mask_indice) > 0
return next(iter(self.mask_indice))
def process_mask(self, mask_2d: str) -> str | None:
if self.mask_format == AttentionMaskFormat.NoMask:
return None
if mask_2d in self.mask_indice:
return self.mask_indice[mask_2d]
# Add cast to convert int64 to int32
if self.model.find_graph_input(mask_2d):
casted, input_name = self.utils.cast_graph_input_to_int32(mask_2d)
else:
input_name, _cast_node = self.utils.cast_input_to_int32(mask_2d)
casted = True
if casted:
self.mask_casted[mask_2d] = input_name
        # The Attention operator supports an int32 attention mask (2D) since ONNX Runtime 1.4.0
if self.mask_format == AttentionMaskFormat.AttentionMask:
self.mask_indice[mask_2d] = input_name
return input_name
# Add a mask processing node to convert attention mask to mask index (1D)
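        # For example, a right-padded 2D mask [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]
        # reduces (sum over axis 1) to the 1D mask index [3, 2], i.e. the valid
        # sequence length of each batch entry.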
output_name = self.model.create_node_name("mask_index")
if self.opset_version < 13:
mask_index_node = helper.make_node(
"ReduceSum",
inputs=[input_name],
outputs=[output_name],
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
)
mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
else:
            # ReduceSum-13: axes moved from an attribute to an input
axes_name = "ort_const_1_reduce_sum_axes"
if self.model.get_initializer(axes_name) is None:
self.model.add_initializer(
helper.make_tensor(
name=axes_name,
data_type=TensorProto.INT64,
dims=[1],
vals=[1],
raw=False,
)
)
mask_index_node = helper.make_node(
"ReduceSum",
inputs=[input_name, axes_name],
outputs=[output_name],
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
)
mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])
self.model.add_node(mask_index_node)
self.mask_indice[mask_2d] = output_name
return output_name
class FusionAttention(Fusion):
"""
Fuse Attention subgraph into one Attention node.
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask | None = None,
use_multi_head_attention: bool = False,
disable_multi_head_attention_bias: bool = False,
search_op_types: list[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006
):
attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
super().__init__(model, attention_op_name, search_op_types)
self.hidden_size = hidden_size
self.num_heads = num_heads
self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
self.use_multi_head_attention = use_multi_head_attention
self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
self.mask_filter_value = None
# Flags to show warning only once
self.num_heads_warning = True
self.hidden_size_warning = True
self.shape_infer = None
        self.shape_infer_done = False
def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> tuple[int, int]:
"""
Detect num_heads and hidden_size from Concat node in the following subgraph:
SkipLayerNormalization or EmbedLayerNormalization
/ |
MatMul Shape
| |
Add Gather(indices=0)
| |
| Unsqueeze
| |
| Concat (*, -1, 12, 64)
| /
Reshape
|
Transpose
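
        For example, with Concat inputs (*, -1, 12, 64) as above, num_heads = 12
        and hidden_size = 12 * 64 = 768.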
"""
if len(concat.input) == 4:
num_heads = self.model.get_constant_value(concat.input[2])
head_size = self.model.get_constant_value(concat.input[3])
if (
isinstance(num_heads, np.ndarray)
and num_heads.size == 1
and isinstance(head_size, np.ndarray)
and head_size.size == 1
):
return num_heads[0], num_heads[0] * head_size[0]
return self.num_heads, self.hidden_size
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
"""Detect num_heads and hidden_size from a reshape node.
Args:
reshape_q (NodeProto): reshape node for Q
Returns:
Tuple[int, int]: num_heads and hidden_size
"""
        # We assume that reshape fusion has been done, so the shape input is a constant like [0, 0, num_heads, head_size]
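        # For example, a shape value of [0, 0, 12, 64] (0 means "keep the input
        # dimension" in ONNX Reshape) gives num_heads = 12 and hidden_size = 12 * 64 = 768.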
q_shape_value = self.model.get_constant_value(reshape_q.input[1])
if q_shape_value is None:
concat = self.model.get_parent(reshape_q, 1)
if concat is not None and concat.op_type == "Concat":
return self.get_num_heads_and_hidden_size_from_concat(concat)
logger.debug("%s is not initializer.", reshape_q.input[1])
return self.num_heads, self.hidden_size # Fall back to user specified value
if (
(not isinstance(q_shape_value, np.ndarray))
or len(q_shape_value) != 4
or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0)
):
logger.debug("q_shape_value=%s. Expected value are like [0, 0, num_heads, head_size].", q_shape_value)
return self.num_heads, self.hidden_size # Fall back to user specified value
num_heads = q_shape_value[2]
head_size = q_shape_value[3]
hidden_size = num_heads * head_size
if self.num_heads > 0 and num_heads != self.num_heads:
if self.num_heads_warning:
logger.warning(
"--num_heads is %d. Detected value is %d. Using detected value.", self.num_heads, num_heads
)
self.num_heads_warning = False # Do not show the warning more than once
if self.hidden_size > 0 and hidden_size != self.hidden_size:
if self.hidden_size_warning:
logger.warning(
"--hidden_size is %d. Detected value is %d. Using detected value.", self.hidden_size, hidden_size
)
self.hidden_size_warning = False # Do not show the warning more than once
return num_heads, hidden_size
def get_add_qk_str(self, add_qk: NodeProto):
if not self.shape_infer_done:
self.shape_infer = self.model.infer_runtime_shape(update=True)
self.shape_infer_done = True
if self.shape_infer is None:
return None
input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])
if input_0_shape is None or input_1_shape is None:
logger.debug("one of the inputs of %s is None", add_qk)
return None
if input_0_shape != input_1_shape:
logger.debug("the shape of two inputs of %s is not same", add_qk)
return None
return add_qk.input[1]
def reshape_add_qk(self, add_qk: str):
# Convert 4D mask from (B,1,S,T) to (B,N,S,T)
# B = batch size, N = num heads, S = source sequence length, T = target sequence length
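        # Sketch: with num_heads = 4, an input "bias" of shape (B,1,S,T) becomes
        # Concat(["bias"] * 4, axis=1) -> "bias_mask" of shape (B,4,S,T).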
mask_output_name = add_qk + "_mask"
# Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
if len(concat_node) == 1:
return mask_output_name
assert len(concat_node) == 0
concat_node_name = self.model.create_node_name("Concat")
concat_add_qk_fp32 = helper.make_node(
"Concat",
inputs=[add_qk for _ in range(self.num_heads)],
outputs=[mask_output_name],
name=concat_node_name,
axis=1,
)
# Add new node to graph
self.nodes_to_add.append(concat_add_qk_fp32)
self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
return mask_output_name
def concat_kv(self, past_k: str, past_v: str) -> str:
"""Concatenate past_k and past_v inputs to create past_kv input.
Args:
past_k (str): name of past K value
past_v (str): name of past V value
Returns:
kv_output_name (str): name of past KV value
"""
# Unsqueeze K and V nodes from (B,N,P,H) to (1,B,N,P,H)
# B = batch size, N = num heads, P = past sequence length, H = head size
unsqueeze_k_name = self.model.create_node_name("Unsqueeze")
unsqueeze_v_name = self.model.create_node_name("Unsqueeze")
k_5d_name = (past_k + "_5d").replace(".", "_")
v_5d_name = (past_v + "_5d").replace(".", "_")
k_5d = helper.make_node(
"Unsqueeze",
inputs=[past_k],
outputs=[k_5d_name],
name=unsqueeze_k_name,
axes=[0],
)
v_5d = helper.make_node(
"Unsqueeze",
inputs=[past_v],
outputs=[v_5d_name],
name=unsqueeze_v_name,
axes=[0],
)
# Add unsqueeze nodes to graph
self.nodes_to_add.append(k_5d)
self.nodes_to_add.append(v_5d)
self.node_name_to_graph_name[unsqueeze_k_name] = self.this_graph_name
self.node_name_to_graph_name[unsqueeze_v_name] = self.this_graph_name
# Concat K and V to get one node of size (2,B,N,P,H)
concat_node_name = self.model.create_node_name("Concat")
kv_output_name = past_v.replace(".value", ".kv").replace(".", "_").replace("_value", "_kv")
concat_kv = helper.make_node(
"Concat",
inputs=[k_5d_name, v_5d_name],
outputs=[kv_output_name],
name=concat_node_name,
axis=0,
)
# Add concat node to graph
self.nodes_to_add.append(concat_kv)
self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
return kv_output_name
def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
"""Split kv_node containing present KV values into separate present K and present V values.
Args:
present_k_name (str): name of output to store present K value in
present_v_name (str): name of output to store present V value in
kv_node (str): name of present KV values
"""
# Split kv_node into present_k and present_v nodes
# Create initializers for indexing kv_node, whose shape is (2,B,N,P,H)
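        # Gather(axis=0) with index 0 then yields present K of shape (B,N,P,H),
        # and index 1 yields present V of the same shape.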
k_index, v_index = "index_0", "index_1"
k_dim = self.model.get_initializer(k_index)
v_dim = self.model.get_initializer(v_index)
if k_dim is None:
k_dim = numpy_helper.from_array(np.array(0, dtype="int64"), name=k_index)
self.model.add_initializer(k_dim, self.this_graph_name)
if v_dim is None:
v_dim = numpy_helper.from_array(np.array(1, dtype="int64"), name=v_index)
self.model.add_initializer(v_dim, self.this_graph_name)
# Create nodes to index kv_node
gather_k_name = self.model.create_node_name("Gather")
gather_v_name = self.model.create_node_name("Gather")
present_k = helper.make_node(
"Gather",
inputs=[kv_node, k_index],
outputs=[present_k_name],
name=gather_k_name,
axis=0,
)
present_v = helper.make_node(
"Gather",
inputs=[kv_node, v_index],
outputs=[present_v_name],
name=gather_v_name,
axis=0,
)
# Add gather nodes to graph
self.nodes_to_add.append(present_k)
self.nodes_to_add.append(present_v)
self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
self.node_name_to_graph_name[gather_v_name] = self.this_graph_name
def create_combined_qkv_bias(
self,
q_add: NodeProto,
k_add: NodeProto | None,
v_add: NodeProto | None,
name_prefix: str,
    ) -> str:
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
qb = NumpyHelper.to_array(q_bias)
kb = np.zeros_like(qb)
vb = np.zeros_like(qb)
if k_add is not None:
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
kb = NumpyHelper.to_array(k_bias)
if v_add is not None:
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
vb = NumpyHelper.to_array(v_bias)
qkv_bias = np.stack((qb, kb, vb), axis=0)
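        # Stacking gives shape (3, hidden); flattened row-major this is the
        # [q_bias | k_bias | v_bias] layout expected for the combined bias input.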
qkv_bias_dim = 3 * np.prod(qb.shape)
bias_name = name_prefix + "_qkv_bias"
self.add_initializer(
name=bias_name,
data_type=q_bias.data_type,
dims=[qkv_bias_dim],
vals=qkv_bias,
)
return bias_name
def create_packed_qkv_matmul_node(
self,
q_matmul: NodeProto,
k_matmul: NodeProto,
v_matmul: NodeProto,
q_add: NodeProto,
k_add: NodeProto | None,
v_add: NodeProto | None,
) -> tuple[NodeProto, NodeProto, NodeProto]:
"""Create packed QKV MatMul node before MultiHeadAttention node.
This is for the scenario where an Attention node should be created but cannot be created
because past_key and past_value are separate inputs and not one concatenated input.
Args:
            q_matmul (NodeProto): MatMul node from Q path - (batch_size, sequence_length, hidden_size)
            k_matmul (NodeProto): MatMul node from K path - (batch_size, sequence_length, hidden_size)
            v_matmul (NodeProto): MatMul node from V path - (batch_size, sequence_length, hidden_size)
            q_add (NodeProto): Add node from Q path
            k_add (NodeProto | None): Add node from K path, if any
            v_add (NodeProto | None): Add node from V path, if any
Returns:
q_output (NodeProto): Slice node for Q
k_output (NodeProto): Slice node for K
v_output (NodeProto): Slice node for V
"""
matmul_node_name = self.model.create_node_name("MatMul")
# Check that input for Q, K, V is the same
assert q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
        # Create packed QKV weight
q_weight = self.model.get_initializer(q_matmul.input[1])
k_weight = self.model.get_initializer(k_matmul.input[1])
v_weight = self.model.get_initializer(v_matmul.input[1])
qw = NumpyHelper.to_array(q_weight)
kw = NumpyHelper.to_array(k_weight)
vw = NumpyHelper.to_array(v_weight)
assert qw.shape == kw.shape and kw.shape == vw.shape
d = qw.shape[0]
qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
qkv_weight_name = matmul_node_name + "_qkv_weight"
self.add_initializer(
name=qkv_weight_name,
data_type=q_weight.data_type,
dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
vals=qkv_weight,
)
        # Create packed QKV MatMul with output (B, S, 3*D)
# Output is of the form:
#
# [[[Q Q ... Q Q K K ... K K V V ... V V]]]
# [Q Q ... Q Q K K ... K K V V ... V V]
# .
# .
# .
# [[Q Q ... Q Q K K ... K K V V ... V V]
# [Q Q ... Q Q K K ... K K V V ... V V]]]
qkv_matmul_output = matmul_node_name + "_qkv_out"
qkv_matmul = helper.make_node(
"MatMul",
inputs=[q_matmul.input[0], qkv_weight_name],
outputs=[qkv_matmul_output],
name=matmul_node_name,
)
self.node_name_to_graph_name[matmul_node_name] = self.this_graph_name
qkv_nodes = [qkv_matmul]
# Create Slice nodes to access Q, K, V
q_slice_name = matmul_node_name + "_q_start_index"
self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
k_slice_name = matmul_node_name + "_k_start_index"
self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
v_slice_name = matmul_node_name + "_v_start_index"
self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)
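        # In the packed output, Q occupies columns [0, d), K occupies [d, 2*d), and
        # V occupies [2*d, 3*d); each Slice below picks one of these ranges along
        # the last axis.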
q_slice_output = matmul_node_name + "_q_out"
q_slice = helper.make_node(
"Slice",
inputs=[qkv_matmul_output, q_slice_name, k_slice_name, qkv_last_axis_name],
outputs=[q_slice_output],
name=self.model.create_node_name("Slice"),
)
self.node_name_to_graph_name[q_slice.name] = self.this_graph_name
k_slice_output = matmul_node_name + "_k_out"
k_slice = helper.make_node(
"Slice",
inputs=[qkv_matmul_output, k_slice_name, v_slice_name, qkv_last_axis_name],
outputs=[k_slice_output],
name=self.model.create_node_name("Slice"),
)
self.node_name_to_graph_name[k_slice.name] = self.this_graph_name
v_slice_output = matmul_node_name + "_v_out"
v_slice = helper.make_node(
"Slice",
inputs=[qkv_matmul_output, v_slice_name, end_of_qkv_name, qkv_last_axis_name],
outputs=[v_slice_output],
name=self.model.create_node_name("Slice"),
)
self.node_name_to_graph_name[v_slice.name] = self.this_graph_name
q_output = q_slice
k_output = k_slice
v_output = v_slice
qkv_nodes.extend([q_slice, k_slice, v_slice])
if self.disable_multi_head_attention_bias:
if q_add is not None:
initializer_input = 1 if self.model.get_initializer(q_add.input[1]) else 0
if np.any(NumpyHelper.to_array(self.model.get_initializer(q_add.input[initializer_input]))):
q_add.input[1 - initializer_input] = q_slice_output
q_output = q_add
qkv_nodes.append(q_add)
self.node_name_to_graph_name[q_add.name] = self.this_graph_name
if k_add is not None:
initializer_input = 1 if self.model.get_initializer(k_add.input[1]) else 0
if np.any(NumpyHelper.to_array(self.model.get_initializer(k_add.input[initializer_input]))):
k_add.input[1 - initializer_input] = k_slice_output
k_output = k_add
qkv_nodes.append(k_add)
self.node_name_to_graph_name[k_add.name] = self.this_graph_name
if v_add is not None:
initializer_input = 1 if self.model.get_initializer(v_add.input[1]) else 0
if np.any(NumpyHelper.to_array(self.model.get_initializer(v_add.input[initializer_input]))):
v_add.input[1 - initializer_input] = v_slice_output
v_output = v_add
qkv_nodes.append(v_add)
self.node_name_to_graph_name[v_add.name] = self.this_graph_name
# Add nodes to graph
self.nodes_to_add.extend(qkv_nodes)
return q_output, k_output, v_output
    # This function is used in child classes for BART or Conformer models.
def create_multihead_attention_node(
self,
q_matmul: NodeProto,
k_matmul: NodeProto | str | None,
v_matmul: NodeProto | str | None,
q_add: NodeProto,
k_add: NodeProto | None,
v_add: NodeProto | None,
num_heads: int,
hidden_size: int,
output: str,
key_padding_mask: str = "",
add_qk: str = "",
unidirectional: bool = False,
past_k: str = "",
past_v: str = "",
present_k: str = "",
present_v: str = "",
packed_qkv: bool = False,
) -> NodeProto | None:
"""Create a MultiHeadAttention node.
Args:
            q_matmul (NodeProto): MatMul node from Q path - (batch_size, sequence_length, hidden_size)
            k_matmul (NodeProto | str | None): MatMul node from K path, or name of a graph input - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
            v_matmul (NodeProto | str | None): MatMul node from V path, or name of a graph input - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
            q_add (NodeProto): Add node from Q path
            k_add (NodeProto | None): Add node from K path, if any
            v_add (NodeProto | None): Add node from V path, if any
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
output (str): output name of MHA
key_padding_mask (str): name of key padding mask
add_qk (str): name of add after Q x K'
unidirectional (bool): whether to apply causal attention mask automatically or not
past_k (str): name of past K value - (batch_size, num_heads, past_sequence_length, head_size)
past_v (str): name of past V value - (batch_size, num_heads, past_sequence_length, head_size)
present_k (str): name of present K value - (batch_size, num_heads, sequence_length, head_size)
present_v (str): name of present V value - (batch_size, num_heads, sequence_length, head_size)
packed_qkv (bool): whether to combine MatMuls from Q, K, V paths
Note: This is for the scenario where an Attention node should be created but cannot be created
because past_key and past_value are separate inputs and not one concatenated input.
Returns:
Union[NodeProto, None]: the node created or None if failed.
"""
# B = batch size, N = num heads, P = past seq len, H = head size
assert num_heads > 0
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
return None
graph_input_names = {node.name for node in self.model.graph().input}
mha_node_name = self.model.create_node_name("Attention")
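        # Inputs are assembled below in the order: query, key, value, bias,
        # key_padding_mask, add_qk (attention bias), past_key, past_value.
        # Trailing optional inputs are omitted when absent.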
# Add initial Q/K/V inputs for MHA
mha_inputs = []
if packed_qkv:
q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node(
q_matmul,
k_matmul,
v_matmul,
q_add,
k_add,
v_add,
)
mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]])
elif isinstance(k_matmul, NodeProto) and isinstance(v_matmul, NodeProto):
if self.disable_multi_head_attention_bias:
mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]])
else:
mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]])
elif (
isinstance(k_matmul, str)
and isinstance(v_matmul, str)
and k_matmul in graph_input_names
and v_matmul in graph_input_names
):
if self.disable_multi_head_attention_bias:
mha_inputs.extend([q_add.output[0], k_matmul, v_matmul])
else:
mha_inputs.extend([q_matmul.output[0], k_matmul, v_matmul])
else:
return None
# Add bias to inputs for MHA
# Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume
# bias has been added to key and value when they are in BNSH format, so only bias for query is used.
        # We need to add checks if this assumption turns out not to hold.
if not self.disable_multi_head_attention_bias:
bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name)
mha_inputs.append(bias_name)
else:
mha_inputs.append("")
# Add optional inputs for MHA
if past_k and past_v:
mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
elif key_padding_mask or add_qk:
mha_inputs.extend([key_padding_mask, add_qk])
# Add outputs for MHA
mha_outputs = [output]
if present_k and present_v:
mha_outputs.extend([present_k, present_v])
mha_node = helper.make_node(
"MultiHeadAttention",
inputs=mha_inputs,
outputs=mha_outputs,
name=mha_node_name,
)
mha_node.domain = "com.microsoft"
mha_node.attribute.append(helper.make_attribute("num_heads", num_heads))
if unidirectional:
mha_node.attribute.append(helper.make_attribute("unidirectional", int(unidirectional)))
self.increase_counter("MultiHeadAttention")
return mha_node
def create_attention_node(
self,
mask_index: str | None,
q_matmul: NodeProto,
k_matmul: NodeProto,
v_matmul: NodeProto,
q_add: NodeProto,
k_add: NodeProto,
v_add: NodeProto,
num_heads: int,
hidden_size: int,
first_input: str,
output: str,
add_qk_str: str = "",
causal: bool = False,
past_k: str = "",
past_v: str = "",
present_k: str = "",
present_v: str = "",
scale: float | None = None,
) -> NodeProto | None:
"""Create an Attention node.
Args:
mask_index (str | None): mask input
            q_matmul (NodeProto): MatMul node in the fully connected layer for Q
            k_matmul (NodeProto): MatMul node in the fully connected layer for K
            v_matmul (NodeProto): MatMul node in the fully connected layer for V
            q_add (NodeProto): Add bias node in the fully connected layer for Q
            k_add (NodeProto): Add bias node in the fully connected layer for K
            v_add (NodeProto): Add bias node in the fully connected layer for V
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
first_input (str): first input name
output (str): output name
add_qk_str (str): name of Add node after Q x K'
            causal (bool): whether to apply a unidirectional (causal) attention mask
past_k (str): name of input for past K value
past_v (str): name of input for past V value
present_k (str): name of output to store present K value
present_v (str): name of output to store present V value
            scale (float | None): scaling factor applied before softmax
Returns:
Union[NodeProto, None]: the node created or None if failed.
"""
assert num_heads > 0
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
return None
has_bias = True
if q_add is None and k_add is None and v_add is None:
has_bias = False
q_weight = self.model.get_initializer(q_matmul.input[1])
k_weight = self.model.get_initializer(k_matmul.input[1])
v_weight = self.model.get_initializer(v_matmul.input[1])
q_bias, k_bias, v_bias = None, None, None
if has_bias:
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
if not (k_weight and v_weight and q_bias and k_bias):
return None
if q_weight is None:
print(
f"{q_matmul.input[1]} is not an initializer. "
"Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
)
return None
qw = NumpyHelper.to_array(q_weight)
kw = NumpyHelper.to_array(k_weight)
vw = NumpyHelper.to_array(v_weight)
        # Assert that Q and K weights have the same shape, as expected
assert qw.shape == kw.shape
qw_in_size = qw.shape[0]
kw_in_size = kw.shape[0]
vw_in_size = vw.shape[0]
assert qw_in_size == kw_in_size == vw_in_size
if hidden_size > 0 and hidden_size != qw_in_size:
logger.warning(
"Input hidden size (%d) is not same as weight matrix dimension of q,k,v (%d). "
"Please provide a correct input hidden size or pass in 0",
hidden_size,
qw_in_size,
)
is_qkv_diff_dims = False
if qw.shape != vw.shape:
is_qkv_diff_dims = True
# All the matrices can have the same shape or q, k matrices can have the same shape with v being different
# For 2d weights, the shapes would be [in_size, out_size].
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
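        # For example, a 3D weight of shape [768, 12, 64] has out_size = 12 * 64 = 768.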
qw_out_size = np.prod(qw.shape[1:])
kw_out_size = np.prod(kw.shape[1:])
vw_out_size = np.prod(vw.shape[1:])
qkv_weight_dim = 0
if is_qkv_diff_dims:
qkv_weight = np.concatenate((qw, kw, vw), axis=1)
qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
else:
qkv_weight = np.stack((qw, kw, vw), axis=1)
qkv_weight_dim = 3 * qw_out_size
qkv_bias_dim = 0
qkv_bias: np.ndarray | None = None
if has_bias:
qb = NumpyHelper.to_array(q_bias)
kb = NumpyHelper.to_array(k_bias)
vb = NumpyHelper.to_array(v_bias)
q_bias_shape = np.prod(qb.shape)
k_bias_shape = np.prod(kb.shape)
v_bias_shape = np.prod(vb.shape)
assert q_bias_shape == k_bias_shape == qw_out_size
assert v_bias_shape == vw_out_size
if is_qkv_diff_dims:
qkv_bias = np.concatenate((qb, kb, vb), axis=0)
qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
else:
qkv_bias = np.stack((qb, kb, vb), axis=0)
qkv_bias_dim = 3 * q_bias_shape
attention_node_name = self.model.create_node_name("Attention")
if not self.use_multi_head_attention:
self.add_initializer(
name=attention_node_name + "_qkv_weight",
data_type=q_weight.data_type,
dims=[qw_in_size, int(qkv_weight_dim)],
vals=qkv_weight,
)
if has_bias:
self.add_initializer(
name=attention_node_name + "_qkv_bias",
data_type=q_bias.data_type,
dims=[int(qkv_bias_dim)],
vals=qkv_bias,
)
        # For the MultiHeadAttention operator, use separate inputs for query, key and value, and no weights.
if self.use_multi_head_attention:
if add_qk_str:
logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
return None
attention_inputs = [
q_matmul.output[0],
k_matmul.output[0],
v_matmul.output[0],
attention_node_name + "_qkv_bias",
]
if mask_index is not None:
attention_inputs.append(mask_index)
attention_node = helper.make_node(
"MultiHeadAttention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
self.increase_counter("MultiHeadAttention")
else:
attention_inputs = [
first_input,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias" if has_bias else "",
]
if mask_index is not None:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")
past_exists = past_k and past_v
if past_exists:
past_kv = self.concat_kv(past_k, past_v)
attention_inputs.append(past_kv)
if add_qk_str:
                # Append the additional Add input (input name: attention_bias); pad the optional past slot first if it is absent.
if not past_exists:
attention_inputs.append("")
attention_inputs.append(add_qk_str)
attention_outputs = [output]
if present_k and present_v:
present_kv = present_k.replace(".key", "").replace("_key", "").replace(".", "_")
attention_outputs.append(present_kv)
self.split_kv(present_k, present_v, present_kv)
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=attention_outputs,
name=attention_node_name,
)
self.increase_counter("Attention")
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
if causal:
attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])
if scale is not None:
attention_node.attribute.extend([helper.make_attribute("scale", scale)])
if is_qkv_diff_dims:
attention_node.attribute.extend(
[helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
)
if self.mask_filter_value is not None:
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
return attention_node
def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Sometimes we cannot fuse SkipLayerNormalization since the Add before LayerNorm has an output that is used by nodes outside of the SkipLayerNorm pattern.
        # Conceptually we treat the Add before LayerNorm as a SkipLayerNorm node, since they share the same pattern.
normalize_node = node
start_node = normalize_node
if normalize_node.op_type == "LayerNormalization":
add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
if add_before_layernorm is not None:
start_node = add_before_layernorm
else:
return
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
qkv_nodes = self.model.match_parent_path(
start_node,
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
[None, None, 0, 0, 0],
)
einsum_node = None
if qkv_nodes is not None:
(_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
else:
# Match Albert
qkv_nodes = self.model.match_parent_path(
start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
)
if qkv_nodes is not None:
(_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
else:
return
other_inputs = []
for _i, node_input in enumerate(start_node.input):
if node_input not in output_name_to_node:
continue
if node_input == qkv_nodes[0].output[0]:
continue
other_inputs.append(node_input)
if len(other_inputs) != 1:
return
root_input = other_inputs[0]
# Match flaubert Mask
# |
# Mul --> LayerNormalization --> Attention --> MatMul --> Add
# | |
# | |
# +---------------------------------------------------------
mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
if mul_before_layernorm is not None:
mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
if mul_children is not None and len(mul_children) == 2:
layernorm_node = mul_children[1]
if layernorm_node.op_type == "LayerNormalization":
root_input = layernorm_node.output[0]
else:
return
elif mul_children is not None and len(mul_children) == 5:
root_input = mul_before_layernorm.output[0]
else:
return
elif normalize_node.op_type == "LayerNormalization":
children = input_name_to_nodes[root_input]
for child in children:
if child.op_type == "LayerNormalization":
root_input = child.output[0]
        # When the Add before the LayerNormalization produces an output
        # that is consumed by nodes other than the LayerNormalization itself,
        # the fused SkipLayerNormalization will have several outputs.
        # In this case we need to pick the one used in Attention.
# For example, this is the case for ViT
# SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
# | |
# | |
# +---------------------------------------------------------------------+
parent_node = output_name_to_node[root_input]
if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
root_input = parent_node.output[0]
children = input_name_to_nodes[root_input]
children_types = [child.op_type for child in children]
if children_types.count("MatMul") != 3:
return
v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return
(_, _, add_v, matmul_v) = v_nodes
is_distill = False
is_distill_add = False
is_no_mask_attention = False
is_sdpa = False
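        # Each candidate path corresponds to a known export pattern: path1/path2 are the
        # common Add(mask) with Div or Mul scaling, path3/path4 are DistilBERT-style
        # Where-based masking, path5 is attention without a mask, and "sdpa" matches
        # exports of scaled dot-product attention.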
qk_paths = {
"path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
"path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
"path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
"path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
"path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
"sdpa": (["Softmax", "Add", "MatMul", "Mul", "Sqrt"], [0, 0, None, 0, 1]),
}
qk_nodes = None
for k, v in qk_paths.items():
qk_nodes = self.model.match_parent_path(matmul_qkv, v[0], v[1])
if qk_nodes is None:
continue
if k == "path3":
is_distill = True
elif k == "path4":
is_distill_add = True
elif k == "path5":
is_no_mask_attention = True
elif k == "sdpa":
is_sdpa = True
break
if qk_nodes is None:
logger.debug("fuse_attention: failed to match qk path")
return
add_qk = None
matmul_qk = None
where_qk = None
after_q = None
if is_distill:
(_, where_qk, matmul_qk, _) = qk_nodes
elif is_distill_add:
(_, add_qk, where_qk, matmul_qk) = qk_nodes
elif is_no_mask_attention:
(_, _, matmul_qk) = qk_nodes
elif is_sdpa:
(_, add_qk, matmul_qk, after_q, _) = qk_nodes
else:
(_, add_qk, _, matmul_qk) = qk_nodes
after_q = after_q or matmul_qk
q_nodes = self.model.match_parent_path(after_q, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
if q_nodes is None:
q_nodes = self.model.match_parent_path(
after_q,
["Div", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, None],
)
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return
reshape_q = q_nodes[-3]
add_q = q_nodes[-2]
matmul_q = q_nodes[-1]
after_k = matmul_qk
if is_sdpa:
mul_k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Sqrt"], [1, None])
if mul_k_nodes is None:
logger.debug("fuse_attention: failed to match mul sqrt q path")
return
(after_k, _) = mul_k_nodes
k_nodes = self.model.match_parent_path(
after_k, ["Transpose", "Reshape", "Add", "MatMul"], [0 if is_sdpa else 1, 0, 0, None]
)
if k_nodes is None:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 0, 0, None],
)
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return
add_k = k_nodes[-2]
matmul_k = k_nodes[-1]
        # Note that Cast might be removed by ONNX Runtime, so we match two patterns here.
mask_nodes = None
add_qk_str = ""
if is_distill:
_, mask_nodes, _ = self.model.match_parent_paths(
where_qk,
[
(["Expand", "Reshape", "Equal"], [0, 0, 0]),
(["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
(["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
],
output_name_to_node,
)
elif is_distill_add:
_, mask_nodes, _ = self.model.match_parent_paths(
where_qk,
[
(["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
(["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
],
output_name_to_node,
)
if add_qk is not None:
add_qk_str = self.get_add_qk_str(add_qk)
if add_qk_str is None:
logger.debug("fuse_attention: failed to verify shape inference of %s", add_qk)
return
elif is_no_mask_attention:
pass
else:
_, mask_nodes, _ = self.model.match_parent_paths(
add_qk,
[
(["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]),
(["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
# The following two patterns are for SDPA.
(["Where", "Cast", "Sub", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0]),
(["Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0, 0]),
],
output_name_to_node,
)
if not is_no_mask_attention and mask_nodes is None:
logger.debug("fuse_attention: failed to match mask path")
return
if not is_no_mask_attention and len(mask_nodes) > 1:
_, mul_val = self.model.get_constant_input(mask_nodes[0])
            # The mask value shall be a float scalar (usually the lowest float value).
if (
(mul_val is None)
or not (isinstance(mul_val, np.ndarray) and mul_val.size == 1)
or (float(mul_val) >= 0)
):
return
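            # -10000 is the default mask_filter_value of the fused Attention op,
            # so only a non-default value needs to be recorded as an attribute.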
if float(mul_val) != -10000:
self.mask_filter_value = float(mul_val)
if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None
attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
if q_num_heads <= 0 or q_hidden_size <= 0:
logger.warning(
"Failed to detect num_heads and hidden_size for Attention fusion. "
"Please specify those parameters in argument."
)
return
        # The number of heads is the same for all paths, hence we pass q_num_heads to create the attention node.
        # The input hidden size is used as needed; hidden sizes for Q and K are extracted appropriately.
new_node = self.create_attention_node(
mask_index=mask_index,
q_matmul=matmul_q,
k_matmul=matmul_k,
v_matmul=matmul_v,
q_add=add_q,
k_add=add_k,
v_add=add_v,
num_heads=q_num_heads,
hidden_size=q_hidden_size,
first_input=root_input,
output=attention_last_node.output[0],
add_qk_str=add_qk_str,
)
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
if einsum_node is not None:
unique_index = einsum_node.input[0]
new_edge = "edge_modified_" + unique_index
shape_tensor = self.add_initializer(
name="shape_modified_tensor" + unique_index,
data_type=TensorProto.INT64,
dims=[4],
vals=[0, 0, q_num_heads, int(q_hidden_size / q_num_heads)],
raw=False,
)
self.model.add_node(
helper.make_node(
"Reshape",
[attention_last_node.output[0], shape_tensor.name],
[new_edge],
"reshape_modified_" + unique_index,
),
self.this_graph_name,
)
einsum_node.input[0] = new_edge
self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
self.nodes_to_remove.extend(qk_nodes)
# For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])
# Use prune graph to remove mask nodes since they are shared by all attention nodes.
self.prune_graph = True