# Path: group-wbl/.venv/lib/python3.13/site-packages/onnxruntime/transformers/past_helper.py
# 150 lines, 6.6 KiB, Python (viewer links: Raw / Normal View / History)
# Last modified: 2026-01-09 09:48:03 +08:00
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import torch
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class PastKeyValuesHelper:
    """Helper functions to process past key values for encoder-decoder model.

    Two layouts of key/value states are converted between:

    * grouped by layer: one 4-item entry per layer,
      ``(key_self_i, value_self_i, key_cross_i, value_cross_i)``;
    * grouped by attention type: a flat sequence holding every self-attention
      key/value pair (layer by layer) followed by every cross-attention pair.
    """

    @staticmethod
    def get_past_names(num_layers, present: bool = False):
        """Return key/value state names in the flat, grouped-by-attention order.

        Args:
            num_layers: Number of layers to generate names for.
            present: Use the ``present_`` prefix when True, ``past_`` otherwise.

        Returns:
            list[str]: ``{prefix}_key_self_{i}, {prefix}_value_self_{i}`` for every
            layer ``i``, followed by the matching ``{prefix}_*_cross_{i}`` names.
        """
        prefix = "present" if present else "past"
        names = []
        # All self-attention names come first, then all cross-attention names.
        for attention in ("self", "cross"):
            for layer in range(num_layers):
                names.extend([f"{prefix}_key_{attention}_{layer}", f"{prefix}_value_{attention}_{layer}"])
        return names

    @staticmethod
    def group_by_self_or_cross(present_key_values):
        """Split present state from grouped by layer to grouped by self/cross attention.

        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
        After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
               (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)

        Produces the same result as ``group_by_self_and_cross(present_key_values)``,
        to which it delegates.

        Returns:
            tuple[list, list]: ``(present_self, present_cross)``.

        Raises:
            AssertionError: if a layer entry does not have exactly four items.
        """
        return PastKeyValuesHelper.group_by_self_and_cross(present_key_values)

    @staticmethod
    def group_by_layer(past, num_layers):
        """Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
               (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...

        Args:
            past: Flat sequence of ``4 * num_layers`` states.
            num_layers: Number of layers.

        Returns:
            tuple[list, ...]: one 4-item list per layer.

        Raises:
            AssertionError: if ``len(past) != 4 * num_layers``.
        """
        assert len(past) == 4 * num_layers, f"Expected {4 * num_layers} items. Got {len(past)}"
        # Self-attention pairs fill the first half of ``past``; cross-attention pairs the second half.
        cross_start = 2 * num_layers
        return tuple(
            [past[2 * i], past[2 * i + 1], past[cross_start + 2 * i], past[cross_start + 2 * i + 1]]
            for i in range(num_layers)
        )

    @staticmethod
    def back_group_by_layer(past_key_values: tuple[torch.Tensor, ...]):
        """Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
               (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...

        Unlike ``group_by_layer``, the number of layers is inferred from the input
        length, and each layer entry is a tuple rather than a list.

        Args:
            past_key_values: Flat states grouped by self and cross attention
                (all self-attention pairs first, then all cross-attention pairs).

        Returns:
            past_tuples: key and values grouped by layer, one 4-tuple per layer.

        Raises:
            AssertionError: if the number of items is not a multiple of four; the
                self/cross halves could not be aligned layer by layer otherwise.
        """
        num_items = len(past_key_values)
        assert num_items % 4 == 0, f"Expected a multiple of four items. Got {num_items}"
        half_idx = num_items // 2  # index where the cross-attention half begins
        return tuple(
            (
                past_key_values[2 * i],
                past_key_values[2 * i + 1],
                past_key_values[half_idx + 2 * i],
                past_key_values[half_idx + 2 * i + 1],
            )
            for i in range(num_items // 4)
        )

    @staticmethod
    def group_by_self_and_cross(present_key_values: tuple[tuple[torch.Tensor, ...], ...], concat: bool = False):
        """Categorize present_key_values into self and cross attention.

        Split present state from grouped by layer to grouped by self/cross attention.
        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
        After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
               (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)

        Args:
            present_key_values: From past_key_values of a model (group by layer),
                one 4-item entry per layer.
            concat: If True, return self-attention states followed by cross-attention
                states as a single list.

        Returns:
            ``present_self + present_cross`` (a list) when ``concat`` is True, otherwise
            the tuple ``(present_self, present_cross)`` of lists.

        Raises:
            AssertionError: if a layer entry does not have exactly four items.
        """
        present_self: list[torch.Tensor] = []
        present_cross: list[torch.Tensor] = []
        for present_layer_i in present_key_values:
            assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
            present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
            present_self.extend([present_key_self, present_value_self])
            present_cross.extend([present_key_cross, present_value_cross])
        if concat:
            return present_self + present_cross
        return present_self, present_cross

    @staticmethod
    def get_input_names(past_key_values: tuple, encoder=True):
        """Process input names of model wrapper.

        Args:
            past_key_values: Consider `self` and `cross` past_key_values. With
                ``encoder`` True, the layer count is ``len(past_key_values) // 4``
                (flat layout); otherwise it is ``len(past_key_values)`` (one entry per layer).
            encoder: Also selects the name prefix: ``present_`` when True, ``past_`` otherwise.

        Returns:
            names (List[string]): input names, in the order produced by ``get_past_names``.
        """
        num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
        # The naming scheme is identical to get_past_names, with ``present`` mirroring ``encoder``.
        return PastKeyValuesHelper.get_past_names(num_layers, present=encoder)