# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from functools import cache
from typing import Any, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F

__all__ = ["update_out_and_lse", "RingComm", "get_default_args"]

# Derived from
# zhuzilin/ring-flash-attention/ring_flash_attn/utils.py

@cache
def get_default_args(func) -> dict[str, Any]:
    """Get the default arguments of a function."""
    spec = inspect.getfullargspec(func)
    defaults = spec.defaults if spec.defaults is not None else ()
    padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults
    args: dict[str, Any] = dict(zip(spec.args, padded_defaults))
    if "softcap" in args:
        args["softcap"] = 0.0
    return args

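
# A minimal usage sketch for `get_default_args`: read the full default-argument
# dict of a kernel whose optional arguments vary across versions (the result is
# cached per function), override the few values you control, and pass the rest
# through unchanged. `toy_attn_kernel` is a hypothetical stand-in for
# illustration, not a real flash-attention entry point.
def _example_get_default_args() -> torch.Tensor:
    def toy_attn_kernel(q, k, v, dropout_p=0.0, causal=False, softcap=None):
        return q

    kwargs = get_default_args(toy_attn_kernel)
    # Positional arguments without defaults are padded with None in the dict;
    # drop them before calling with explicit tensors.
    for name in ("q", "k", "v"):
        kwargs.pop(name)
    kwargs["causal"] = True  # `softcap` has already been forced to 0.0
    q = torch.randn(1, 8, 16)
    return toy_attn_kernel(q, q, q, **kwargs)
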
@torch.jit.script
def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    # For additional context and discussion, please refer to:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    lse = lse - F.logsigmoid(lse - block_lse)

    return out, lse

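
# The sigmoid/logsigmoid rewrite above is algebraically the same merge as
#     new_lse = logaddexp(lse, block_lse)
#     new_out = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
# since sigmoid(block_lse - lse) = exp(block_lse) / (exp(lse) + exp(block_lse)).
# A numerical self-check against that reference form (shapes are illustrative):
def _check_update_out_and_lse() -> None:
    out = torch.randn(2, 16, 4, 8)      # (batch, seqlen, heads, head_dim)
    block_out = torch.randn(2, 16, 4, 8)
    lse = torch.randn(2, 16, 4, 1)
    block_lse = torch.randn(2, 4, 16)   # (batch, heads, seqlen), pre-transpose
    new_out, new_lse = _update_out_and_lse(out, lse, block_out, block_lse)

    b_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    ref_lse = torch.logaddexp(lse, b_lse)
    ref_out = torch.exp(lse - ref_lse) * out + torch.exp(b_lse - ref_lse) * block_out
    assert torch.allclose(new_lse, ref_lse, atol=1e-5)
    assert torch.allclose(new_out, ref_out, atol=1e-5)
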
def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Update the output and log-sum-exp of the attention."""
    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        out = block_out.to(torch.float32)
        lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    elif lse is None:
        raise ValueError("`lse` can be None only if `out` is None")
    elif slice_ is not None:
        slice_out, slice_lse = out[slice_], lse[slice_]
        slice_out, slice_lse = _update_out_and_lse(
            slice_out, slice_lse, block_out, block_lse
        )
        out[slice_], lse[slice_] = slice_out, slice_lse
    else:
        out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
    return out, lse  # type: ignore

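
# The intended calling pattern: start with `out = lse = None` and fold in one
# (block_out, block_lse) pair per ring step. The first call initializes the
# accumulators; later calls merge into them. `run_block_attention` is a
# hypothetical per-block kernel, named here only to show the loop shape.
def _example_accumulate(blocks):
    out, lse = None, None
    for block in blocks:
        block_out, block_lse = run_block_attention(block)  # hypothetical kernel
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
    return out, lse
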
@torch.jit.script
def flatten_varlen_lse(lse, cu_seqlens):
    """Flatten the log-sum-exp of the attention."""
    new_lse = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        new_lse.append(lse[i, :, : end - start])
    return torch.cat(new_lse, dim=1)


@torch.jit.script
def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
    """Unflatten the log-sum-exp of the attention."""
    num_seq = len(cu_seqlens) - 1
    num_head = lse.shape[-2]
    new_lse = torch.empty(
        (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device
    )
    for i in range(num_seq):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        new_lse[i, : end - start] = lse[start:end]
    return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()

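
# A shape sketch for the two varlen helpers, with `cu_seqlens` holding
# cumulative sequence offsets (here two sequences of lengths 3 and 5). The
# layouts follow the indexing above: `flatten_varlen_lse` takes a padded
# (num_seq, num_head, max_seqlen) tensor and concatenates the valid columns,
# while `unflatten_varlen_lse` takes one row per token and re-pads per sequence.
def _example_varlen_lse() -> None:
    cu_seqlens = torch.tensor([0, 3, 8])
    num_seq, num_head, max_seqlen = 2, 4, 5
    padded = torch.randn(num_seq, num_head, max_seqlen)
    flat = flatten_varlen_lse(padded, cu_seqlens)
    assert flat.shape == (num_head, 8)  # 3 + 5 valid positions

    per_token = torch.randn(8, num_head, 1)  # one row per token
    unflat = unflatten_varlen_lse(per_token, cu_seqlens, max_seqlen)
    assert unflat.shape == (num_seq, num_head, max_seqlen)
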
class RingComm:
    """Ring communication."""

    def __init__(self, process_group: dist.ProcessGroup):
        """Initialize the ring communication."""
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = None

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size

        if process_group is not None:
            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(
        self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Send and receive a tensor."""
        if recv_tensor is None:
            res = torch.empty_like(to_send)
        else:
            res = recv_tensor

        send_op = dist.P2POp(
            dist.isend, to_send, self.send_rank, group=self._process_group
        )
        recv_op = dist.P2POp(
            dist.irecv, res, self.recv_rank, group=self._process_group
        )
        self._ops.append(send_op)
        self._ops.append(recv_op)
        return res

    def commit(self):
        """Commit the operations."""
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        """Wait for the operations to complete."""
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        self._reqs = None
        self._ops = []

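
# A minimal sketch of one ring step, assuming torch.distributed is already
# initialized (e.g. under torchrun) and every rank holds a same-shaped
# `local_kv` tensor: post the paired isend/irecv, launch them as one batch
# with `commit()`, compute on the local block while the transfer is in
# flight, then `wait()` before touching the received buffer.
def _example_ring_step(
    process_group: dist.ProcessGroup, local_kv: torch.Tensor
) -> torch.Tensor:
    comm = RingComm(process_group)
    next_kv = comm.send_recv(local_kv)
    comm.commit()
    # ... attention on the local block would overlap with the transfer here ...
    comm.wait()
    return next_kv  # the block received from `comm.recv_rank`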