# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.utils.checkpoint
import transformers.cache_utils as transformers_cache_utils
import transformers.modeling_flash_attention_utils as transformers_flash_attention_utils
import transformers.models as transformers_models

from oumi.models.layers.zigzag import (
    is_zigzag_ring_flash_attn_available,
    zigzag_ring_flash_attn_func,
)

# Derived from:
# jzhang38/EasyContext/easy_context/zigzag_ring_attn/prepare_inputs.py
def extract_local(value, rank, world_size, device, dim=1):
    """Extract the local value from the global value."""
    value_chunks = value.chunk(2 * world_size, dim=dim)
    local_value = torch.cat(
        [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
    )
    return local_value.to(device)
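
# A minimal sketch of the zigzag layout (illustrative values, not part of the
# module): with world_size=2 the sequence is cut into 2 * world_size = 4 chunks
# [c0, c1, c2, c3]; rank 0 receives [c0, c3] and rank 1 receives [c1, c2], so
# each rank holds one early and one late chunk and the causal-attention
# workload is balanced across ranks. Concretely:
#
#     value = torch.arange(8).reshape(1, 8)  # one sequence of length 8
#     extract_local(value, rank=0, world_size=2, device="cpu")
#     # -> tensor([[0, 1, 6, 7]])
#     extract_local(value, rank=1, world_size=2, device="cpu")
#     # -> tensor([[2, 3, 4, 5]])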
def prepare_zigzag_ring_attn_inputs(
    input_ids, position_ids, target_ids, rank, world_size, device
):
    """Prepare the inputs for zigzag ring attention."""
    local_input_ids = extract_local(input_ids, rank, world_size, device)
    local_position_ids = extract_local(position_ids, rank, world_size, device)
    if target_ids is not None:
        local_target_ids = extract_local(target_ids, rank, world_size, device)
    else:
        local_target_ids = None
    return {
        "local_input_ids": local_input_ids,
        "local_position_ids": local_position_ids,
        "local_target_ids": local_target_ids,
    }
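
# A hedged usage sketch: the `torch.distributed` calls below are standard, but
# the surrounding training-loop names (`model`, `input_ids`, etc.) are
# hypothetical. The global sequence length is assumed to be divisible by
# 2 * world_size so that all chunks are equal in size.
#
#     import torch.distributed as dist
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     sharded = prepare_zigzag_ring_attn_inputs(
#         input_ids, position_ids, target_ids, rank, world_size, input_ids.device
#     )
#     outputs = model(
#         input_ids=sharded["local_input_ids"],
#         position_ids=sharded["local_position_ids"],
#         labels=sharded["local_target_ids"],
#     )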
def new_flash_attn_forward(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    use_sliding_windows=False,
):
    """New flash attention forward."""
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        causal = self.is_causal and query_length != 1

    # Zigzag ring attention supports only causal attention over unpadded
    # sequences, without sliding windows; fail loudly on unsupported inputs.
    assert attention_mask is None
    assert causal is True
    assert use_sliding_windows is False

    attn_output = zigzag_ring_flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale,
        causal=causal,
    )
    return attn_output
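
# Shape sketch (following the flash-attn calling convention; a hedged note,
# not enforced here): query_states / key_states / value_states are
# (batch, local_seq_len, num_heads, head_dim), where local_seq_len is the
# per-rank shard produced by extract_local(); the returned attn_output has
# the same shape as query_states.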
def new_decoder_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[transformers_cache_utils.Cache] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[
        tuple[torch.Tensor, torch.Tensor]
    ] = None,  # will become mandatory in v4.46
    **kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """New decoder forward."""
    assert isinstance(
        self.self_attn,
        # Was originally LlamaFlashAttention2, but this was deleted in
        # transformers 4.48.0.
        transformers_models.llama.modeling_llama.LlamaAttention,
    ) or isinstance(
        self.self_attn,
        # Ditto, but with MistralFlashAttention2.
        transformers_models.mistral.modeling_mistral.MistralAttention,
    ), (
        "Please toggle on the Flash Attention 2 implementation "
        "when using the zigzag ring attention monkey patch."
    )

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    assert isinstance(hidden_states, torch.Tensor)

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs  # type: ignore
def apply_zigzag_ring_attn_monkey_patch_llama():
    """Apply the zigzag ring attention monkey patch to llama."""
    if not is_zigzag_ring_flash_attn_available():
        raise RuntimeError(
            "Ring attention is not available. "
            "Install Flash Attention: `pip install flash-attn --no-build-isolation`."
        )
    # See also https://github.com/zhuzilin/ring-flash-attention/blob/6f1c6f7a6b40ec87cf1a80c3858c3c74b6c415c0/ring_flash_attn/adapters/hf_adapter.py#L167
    transformers_flash_attention_utils._flash_attention_forward = new_flash_attn_forward
    # Before LlamaFlashAttention2 was deleted in transformers 4.48.0, this
    # instead patched the method on that class:
    # (transformers_models.llama.modeling_llama.LlamaFlashAttention2
    #  ._flash_attention_forward = new_flash_attn_forward)
    transformers_models.llama.modeling_llama.LlamaDecoderLayer.forward = (
        new_decoder_forward
    )
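
# A hedged usage sketch: apply the patch before building the model so every
# LlamaDecoderLayer picks up the patched forward. The checkpoint name is
# illustrative; any Llama-architecture model loaded with the Flash Attention 2
# implementation should work under the asserts above.
#
#     import transformers
#
#     apply_zigzag_ring_attn_monkey_patch_llama()
#     model = transformers.AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-3.1-8B",
#         attn_implementation="flash_attention_2",
#     )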