Source code for oumi.models.layers.zigzag

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.util

import torch

if importlib.util.find_spec("flash_attn") is None:
    _FLASH_ATTN_V2_INSTALLED = False
else:
    try:
        from flash_attn.flash_attn_interface import (  # pyright: ignore[reportMissingImports]
            _flash_attn_backward,
            _flash_attn_forward,
        )

        _FLASH_ATTN_V2_INSTALLED = True
    except ImportError as e:
        _FLASH_ATTN_V2_INSTALLED = False
        raise ImportError(
            "Failed to import Flash Attention `_flash_attn_forward` and "
            "`_flash_attn_backward` functions. Consider re-installing Flash Attention: "
            "`pip install flash-attn --no-build-isolation`."
        ) from e


from oumi.models.layers.zigzag_utils import (
    RingComm,
    get_default_args,
    update_out_and_lse,
)


def is_zigzag_ring_flash_attn_available() -> bool:
    """Indicates whether zigzag ring attention is available."""
    return _FLASH_ATTN_V2_INSTALLED

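# Hypothetical caller-side sketch, not part of the original module: pick the
# zigzag ring kernel when flash-attn v2 is installed, otherwise fall back to
# PyTorch SDPA. The fallback choice (and reconciling the differing call
# signatures) is an assumption left to the caller.
def _select_attention_fn():
    if is_zigzag_ring_flash_attn_available():
        return zigzag_ring_flash_attn_func
    return torch.nn.functional.scaled_dot_product_attention
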
# Derived from
# jzhang38/EasyContext/easy_context/zigzag_ring_attn/monkey_patch.py
# zhuzilin/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py
def zigzag_ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    """Zigzag ring flash attention forward."""
    assert causal, "zigzag ring is meaningless for causal=False"
    comm = RingComm(process_group)
    world_size: int = int(comm.world_size)
    assert world_size > 0, "Empty world!"

    block_seq_len = q.shape[1] // 2
    q1 = q[:, block_seq_len:]

    out = None
    lse = None
    next_k, next_v = None, None  # type: ignore

    def forward(q, k, v, causal):
        """Zigzag ring flash attention forward."""
        params = get_default_args(_flash_attn_forward).copy()
        params.update(
            {
                "q": q,
                "k": k,
                "v": v,
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "return_softmax": True and dropout_p > 0,
            }
        )
        # block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
        block_out, block_lse, _, _ = _flash_attn_forward(**params)
        return block_out, block_lse

    for step in range(world_size):
        if step + 1 != world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if step == 0:
            block_out, block_lse = forward(q, k, v, causal=True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            k0 = k[:, :block_seq_len]
            v0 = v[:, :block_seq_len]
            block_out, block_lse = forward(q, k0, v0, causal=False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            block_out, block_lse = forward(q1, k, v, causal=False)
            out, lse = update_out_and_lse(
                out,
                lse,
                block_out,
                block_lse,
                slice_=(slice(None), slice(block_seq_len, None)),
            )

        if step + 1 != world_size:
            comm.wait()
            k = next_k
            v = next_v

    assert out is not None, f"world_size: {world_size}"
    assert lse is not None, f"world_size: {world_size}"
    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse

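# The loop above folds each block's partial attention result into the running
# (out, lse) pair via update_out_and_lse from zigzag_utils. A minimal sketch of
# that log-sum-exp recombination (illustrative only, not this module's
# implementation; assumes lse_* broadcasts against out_*, e.g. shape (..., 1)):
def _merge_partials(out_a, lse_a, out_b, lse_b):
    # new_lse = log(exp(lse_a) + exp(lse_b)), computed stably.
    new_lse = torch.maximum(lse_a, lse_b) + torch.log1p(
        torch.exp(-torch.abs(lse_a - lse_b))
    )
    # Rescale each partial numerator by its share of the combined denominator.
    out = torch.exp(lse_a - new_lse) * out_a + torch.exp(lse_b - new_lse) * out_b
    return out, new_lse
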
def zigzag_ring_flash_attn_backward(
    process_group,
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    """Zigzag ring flash attention backward."""
    assert causal, "zigzag ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)
    world_size: int = int(kv_comm.world_size)
    assert world_size > 0, "Empty world!"
    d_kv_comm = RingComm(process_group)
    assert world_size == int(d_kv_comm.world_size), "Inconsistent world sizes!"

    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2

    # Repeatedly allocating buffers may be slow...
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        """Zigzag ring flash attention backward."""
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        params = get_default_args(_flash_attn_backward).copy()
        params.update(
            {
                "dout": dout,
                "q": q,
                "k": k,
                "v": v,
                "out": out,
                "softmax_lse": softmax_lse,
                "dq": dq_buffer[:, :seqlen_q],
                "dk": dk_buffer[:, :seqlen_kv],
                "dv": dv_buffer[:, :seqlen_kv],
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "deterministic": deterministic,
            }
        )
        _flash_attn_backward(**params)

    for step in range(world_size):
        if step + 1 != world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            backward(dout, q, k, v, out, softmax_lse, causal=True)
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
        else:
            assert step > 0, f"step: {step}, world_size: {world_size}"
            assert dq is not None, f"step: {step}, world_size: {world_size}"
            assert dk is not None, f"step: {step}, world_size: {world_size}"
            assert dv is not None, f"step: {step}, world_size: {world_size}"
            assert next_dk is not None, f"step: {step}, world_size: {world_size}"
            assert next_dv is not None, f"step: {step}, world_size: {world_size}"

            if step <= kv_comm.rank:
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
            else:
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # Always use the first half in dq_buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]

            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv
            if step <= kv_comm.rank:
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer

        if step + 1 != world_size:
            kv_comm.wait()
            assert next_k is not None, f"step: {step}, world_size: {world_size}"
            assert next_v is not None, f"step: {step}, world_size: {world_size}"
            k = next_k
            v = next_v

        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    d_kv_comm.wait()

    assert dq is not None, f"step: {step}, world_size: {world_size}"
    assert next_dk is not None, f"step: {step}, world_size: {world_size}"
    assert next_dv is not None, f"step: {step}, world_size: {world_size}"
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)

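# Both the forward and backward above assume each rank's q/k/v already hold
# that rank's zigzag shard of the full sequence: with N ranks the sequence is
# split into 2N equal chunks and rank r keeps chunks r and 2N - 1 - r, which is
# why the local first and second halves are handled separately. An illustrative
# sharding helper (an assumption about the expected layout, not part of this
# module):
def _zigzag_shard(
    x: torch.Tensor, rank: int, world_size: int, dim: int = 1
) -> torch.Tensor:
    # Split the full sequence into 2 * world_size equal chunks and keep the
    # pair assigned to this rank, concatenated along the sequence dimension.
    chunks = x.chunk(2 * world_size, dim=dim)
    return torch.cat((chunks[rank], chunks[2 * world_size - 1 - rank]), dim=dim)
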
class ZigZagRingFlashAttnFunc(torch.autograd.Function):
    """Zigzag ring flash attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
    ):
        """Zigzag ring flash attention forward."""
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        assert alibi_slopes is None
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = zigzag_ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=False,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        return out if not return_softmax else (out, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        """Zigzag ring flash attention backward."""
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = zigzag_ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
        )
        return dq, dk, dv, None, None, None, None, None, None, None, None

def zigzag_ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention."""
    return ZigZagRingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )

def zigzag_ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention."""
    return ZigZagRingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )

def zigzag_ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention."""
    return ZigZagRingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )

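# Hypothetical usage sketch, not part of the original module: how one rank of
# an already-initialized torch.distributed job might call the wrapper above.
# The shapes, dtype, and use of the default process group (group=None) are
# assumptions; q/k/v follow flash-attn's (batch, seqlen, nheads, headdim)
# layout, and each rank is expected to hold only its zigzag shard of the
# sequence (see the sharding sketch above).
def _example_zigzag_attention_call() -> torch.Tensor:
    import torch.distributed as dist

    device = torch.device("cuda", dist.get_rank())
    # Local zigzag shard: (batch, local_seqlen, nheads, headdim).
    q = torch.randn(1, 1024, 8, 64, dtype=torch.bfloat16, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # group=None uses the default process group, so the ring spans all ranks.
    return zigzag_ring_flash_attn_func(q, k, v, causal=True, group=None)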