# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.utils.checkpoint
import transformers.cache_utils as transformers_cache_utils
import transformers.modeling_flash_attention_utils as transformers_flash_attention_utils
import transformers.models as transformers_models

from oumi.models.layers.zigzag import (
    is_zigzag_ring_flash_attn_available,
    zigzag_ring_flash_attn_func,
)

# Derived from:
# jzhang38/EasyContext/easy_context/zigzag_ring_attn/prepare_inputs.py
def new_flash_attn_forward(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    use_sliding_windows=False,
):
    """Flash attention forward pass routed through zigzag ring flash attention."""
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        causal = self.is_causal and query_length != 1

    # Zigzag ring attention only supports fully causal attention, with no
    # padding mask and no sliding window.
    assert attention_mask is None
    assert causal is True
    assert use_sliding_windows is False
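    # Illustrative note (editorial, not from the original source): the
    # query/key/value states passed in here are expected to be this rank's
    # local shard of the sequence. Under zigzag sharding, a sequence split
    # into 2 * world_size chunks assigns chunks (i, 2 * world_size - 1 - i)
    # to rank i so causal-attention work is balanced across ranks, while
    # zigzag_ring_flash_attn_func circulates key/value blocks around the ring.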
    attn_output = zigzag_ring_flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale,
        causal=causal,
    )

    return attn_output


def new_decoder_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[transformers_cache_utils.Cache] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[
        tuple[torch.Tensor, torch.Tensor]
    ] = None,  # will become mandatory in v4.46
    **kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Decoder layer forward pass for use with zigzag ring attention."""
    assert isinstance(
        self.self_attn,
        transformers_models.llama.modeling_llama.LlamaFlashAttention2,
    ) or isinstance(
        self.self_attn,
        transformers_models.mistral.modeling_mistral.MistralFlashAttention2,
    ), (
        "Please toggle on the Flash Attention 2 implementation "
        "when using the zigzag ring attention monkey patch."
    )

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    assert isinstance(hidden_states, torch.Tensor)

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs  # type: ignore


def apply_zigzag_ring_attn_monkey_patch_llama():
    """Apply the zigzag ring attention monkey patch to llama."""
    if not is_zigzag_ring_flash_attn_available():
        raise RuntimeError(
            "Ring attention is not available. "
            "Install Flash Attention: `pip install flash-attn --no-build-isolation`."
        )
    # See also https://github.com/zhuzilin/ring-flash-attention/blob/6f1c6f7a6b40ec87cf1a80c3858c3c74b6c415c0/ring_flash_attn/adapters/hf_adapter.py#L167
    transformers_flash_attention_utils._flash_attention_forward = new_flash_attn_forward
    # (transformers_models.llama.modeling_llama.LlamaFlashAttention2.
    #  _flash_attention_forward = new_flash_attn_forward)
    transformers_models.llama.modeling_llama.LlamaDecoderLayer.forward = (
        new_decoder_forward
    )
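

# Usage sketch (editorial illustration, not part of the patched API): apply the
# patch early, before the model is built, and load the model with the Flash
# Attention 2 implementation selected. The model id below is hypothetical.
#
#     from transformers import AutoModelForCausalLM
#
#     apply_zigzag_ring_attn_monkey_patch_llama()
#     model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-3.1-8B",
#         attn_implementation="flash_attention_2",
#         torch_dtype=torch.bfloat16,
#     )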