# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.util

import torch

if importlib.util.find_spec("flash_attn") is None:
    _FLASH_ATTN_V2_INSTALLED = False
else:
    try:
        from flash_attn.flash_attn_interface import (  # pyright: ignore[reportMissingImports]
            _flash_attn_backward,
            _flash_attn_forward,
        )

        _FLASH_ATTN_V2_INSTALLED = True
    except ImportError as e:
        _FLASH_ATTN_V2_INSTALLED = False
        raise ImportError(
            "Failed to import Flash Attention `_flash_attn_forward` and "
            "`_flash_attn_backward` functions. Consider re-installing Flash Attention: "
            "`pip install flash-attn --no-build-isolation`."
        ) from e

from oumi.models.layers.zigzag_utils import (
    RingComm,
    get_default_args,
    update_out_and_lse,
)
def is_zigzag_ring_flash_attn_available() -> bool:
    """Indicates whether zigzag ring attention is available."""
    return _FLASH_ATTN_V2_INSTALLED
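
# Availability-gating sketch (illustrative comment, not part of the module):
# callers are expected to check this flag before wiring up the zigzag path,
# e.g.:
#
#     if is_zigzag_ring_flash_attn_available():
#         attn_fn = zigzag_ring_flash_attn_func
#     else:
#         raise RuntimeError(
#             "flash-attn v2 is required: pip install flash-attn --no-build-isolation"
#         )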
def zigzag_ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    """Zigzag ring flash attention forward."""
    assert causal, "zigzag ring is meaningless for causal=False"
    comm = RingComm(process_group)
    world_size: int = int(comm.world_size)
    assert world_size > 0, "Empty world!"

    block_seq_len = q.shape[1] // 2
    q1 = q[:, block_seq_len:]

    out = None
    lse = None
    next_k, next_v = None, None  # type: ignore

    def forward(q, k, v, causal):
        """Zigzag ring flash attention forward."""
        params = get_default_args(_flash_attn_forward).copy()
        params.update(
            {
                "q": q,
                "k": k,
                "v": v,
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "return_softmax": True and dropout_p > 0,
            }
        )
        # block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
        block_out, block_lse, _, _ = _flash_attn_forward(**params)
        return block_out, block_lse

    for step in range(world_size):
        if step + 1 != world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if step == 0:
            block_out, block_lse = forward(q, k, v, causal=True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            k0 = k[:, :block_seq_len]
            v0 = v[:, :block_seq_len]
            block_out, block_lse = forward(q, k0, v0, causal=False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            block_out, block_lse = forward(q1, k, v, causal=False)
            out, lse = update_out_and_lse(
                out,
                lse,
                block_out,
                block_lse,
                slice_=(slice(None), slice(block_seq_len, None)),
            )

        if step + 1 != world_size:
            comm.wait()
            k = next_k
            v = next_v

    assert out is not None, f"world_size: {world_size}"
    assert lse is not None, f"world_size: {world_size}"
    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse
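
# Layout note (inferred from the block arithmetic above, not stated in this
# file): zigzag ring attention assumes the global sequence was split into
# 2 * world_size chunks, with rank r holding chunks r and
# 2 * world_size - 1 - r concatenated along dim 1. For world_size = 4:
#
#     rank 0: chunks (0, 7)    rank 1: chunks (1, 6)
#     rank 2: chunks (2, 5)    rank 3: chunks (3, 4)
#
# Pairing an early chunk with a late one balances causal-attention work across
# ranks; it is why step 0 is fully causal, steps <= rank only need the first
# half `k0`/`v0` of the incoming block, and later steps attend with only the
# local second half `q1`.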
def zigzag_ring_flash_attn_backward(
    process_group,
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    """Zigzag ring flash attention backward."""
    assert causal, "zigzag ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)
    world_size: int = int(kv_comm.world_size)
    assert world_size > 0, "Empty world!"
    d_kv_comm = RingComm(process_group)
    assert world_size == int(d_kv_comm.world_size), "Inconsistent world sizes!"

    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2

    # Repeatedly allocating buffers may be slow...
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        """Zigzag ring flash attention backward."""
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        params = get_default_args(_flash_attn_backward).copy()
        params.update(
            {
                "dout": dout,
                "q": q,
                "k": k,
                "v": v,
                "out": out,
                "softmax_lse": softmax_lse,
                "dq": dq_buffer[:, :seqlen_q],
                "dk": dk_buffer[:, :seqlen_kv],
                "dv": dv_buffer[:, :seqlen_kv],
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "deterministic": deterministic,
            }
        )
        _flash_attn_backward(**params)

    for step in range(world_size):
        if step + 1 != world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            backward(dout, q, k, v, out, softmax_lse, causal=True)
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
        else:
            assert step > 0, f"step: {step}, world_size: {world_size}"
            assert dq is not None, f"step: {step}, world_size: {world_size}"
            assert dk is not None, f"step: {step}, world_size: {world_size}"
            assert dv is not None, f"step: {step}, world_size: {world_size}"
            assert next_dk is not None, f"step: {step}, world_size: {world_size}"
            assert next_dv is not None, f"step: {step}, world_size: {world_size}"
            if step <= kv_comm.rank:
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
            else:
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # Always use the first half in dq_buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]

            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            if step <= kv_comm.rank:
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer

        if step + 1 != world_size:
            kv_comm.wait()
            assert next_k is not None, f"step: {step}, world_size: {world_size}"
            assert next_v is not None, f"step: {step}, world_size: {world_size}"
            k = next_k
            v = next_v

        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    d_kv_comm.wait()

    assert dq is not None, f"step: {step}, world_size: {world_size}"
    assert next_dk is not None, f"step: {step}, world_size: {world_size}"
    assert next_dv is not None, f"step: {step}, world_size: {world_size}"
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
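
# Reading aid (commentary, not from the source): `kv_comm` rotates k/v one
# ring step ahead of the compute, while `d_kv_comm` rotates the accumulated
# dk/dv gradients in the same direction, overlapping communication with the
# per-block backward kernels. After `world_size` steps the gradients have
# traveled the full ring, so the `next_dk`/`next_dv` returned above are the
# gradients for this rank's own k/v shard.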
class ZigZagRingFlashAttnFunc(torch.autograd.Function):
    """Zigzag ring flash attention."""
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
    ):
        """Zigzag ring flash attention forward."""
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        assert alibi_slopes is None
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = zigzag_ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=False,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        return out if not return_softmax else (out, softmax_lse, None)
    @staticmethod
    def backward(ctx, dout, *args):
        """Zigzag ring flash attention backward."""
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = zigzag_ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
        )
        return dq, dk, dv, None, None, None, None, None, None, None, None
def zigzag_ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention."""
    return ZigZagRingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )
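
# Shape note (inferred from the indexing above): `qkv` packs q, k, and v along
# dim 2, i.e. (batch, seqlen, 3, nheads, headdim); the kvpacked variant below
# similarly expects `kv` of shape (batch, seqlen, 2, nheads, headdim).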
def zigzag_ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention."""
    return ZigZagRingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )
def zigzag_ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention."""
    return ZigZagRingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )
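
# End-to-end usage sketch (hypothetical setup; the process group and the
# zigzag resharding of inputs are the caller's responsibility):
#
#     import torch.distributed as dist
#
#     dist.init_process_group("nccl")
#     group = dist.new_group(ranks=list(range(dist.get_world_size())))
#     # q, k, v: (batch, local_seqlen, nheads, headdim) shards laid out in
#     # zigzag order, where local_seqlen = global_seqlen // world_size and
#     # local_seqlen is even (the kernels split each shard in half).
#     out = zigzag_ring_flash_attn_func(q, k, v, causal=True, group=group)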