# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.util

import torch

if importlib.util.find_spec("flash_attn") is None:
    _FLASH_ATTN_V2_INSTALLED = False
else:
    try:
        from flash_attn.flash_attn_interface import (  # pyright: ignore[reportMissingImports]
            _flash_attn_backward,
            _flash_attn_forward,
        )

        _FLASH_ATTN_V2_INSTALLED = True
    except ImportError as e:
        _FLASH_ATTN_V2_INSTALLED = False
        raise ImportError(
            "Failed to import Flash Attention `_flash_attn_forward` and "
            "`_flash_attn_backward` functions. Consider re-installing Flash Attention: "
            "`pip install flash-attn --no-build-isolation`."
        ) from e

from oumi.models.layers.zigzag_utils import (
    RingComm,
    get_default_args,
    update_out_and_lse,
)

def is_zigzag_ring_flash_attn_available() -> bool:
"""Indicates whether zigzag ring attention is available."""
return _FLASH_ATTN_V2_INSTALLED
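

# Example (illustrative) of guarding callers on this availability check; `sp_group` is
# assumed to be an existing torch.distributed process group and is not defined here:
#
#     if is_zigzag_ring_flash_attn_available():
#         out = zigzag_ring_flash_attn_func(q, k, v, causal=True, group=sp_group)
#     else:
#         ...  # fall back to a non-ring attention implementation
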
# Derived from
# jzhang38/EasyContext/easy_context/zigzag_ring_attn/monkey_patch.py
# zhuzilin/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py
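#
# Zigzag layout (for reference): with `world_size` ranks, the global sequence is split
# into `2 * world_size` contiguous chunks, and rank `r` holds chunk `r` followed by
# chunk `2 * world_size - 1 - r` along the sequence dimension. This pairing balances the
# causal-attention workload across ranks while K/V blocks are rotated around the ring;
# the branches in the loops below exploit where the received K/V block sits in the
# global order relative to the local queries.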
def zigzag_ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    """Zigzag ring flash attention forward."""
    assert causal, "zigzag ring is meaningless for causal=False"
    comm = RingComm(process_group)
    world_size: int = int(comm.world_size)
    assert world_size > 0, "Empty world!"

    block_seq_len = q.shape[1] // 2
    q1 = q[:, block_seq_len:]

    out = None
    lse = None
    next_k, next_v = None, None  # type: ignore
    def forward(q, k, v, causal):
        """Run a single-block flash attention forward pass on the given q/k/v."""
        params = get_default_args(_flash_attn_forward).copy()
        params.update(
            {
                "q": q,
                "k": k,
                "v": v,
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "return_softmax": dropout_p > 0,
            }
        )
        # block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
        block_out, block_lse, _, _ = _flash_attn_forward(**params)
        return block_out, block_lse
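
    # Ring loop: at every step this rank attends to the K/V block it currently holds,
    # then receives the next block from its neighbor. Given the zigzag layout:
    #   * step == 0: local K/V; plain causal attention over the local chunks.
    #   * 0 < step <= rank: the received K/V originates from an earlier rank, so only
    #     its first half precedes all local queries; attend to k0/v0 with no causal mask.
    #   * step > rank: the received K/V originates from a later rank and falls between
    #     the two local chunks, so only the second half of the local queries (q1) may
    #     attend to it, again with no causal mask.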
    for step in range(world_size):
        if step + 1 != world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if step == 0:
            block_out, block_lse = forward(q, k, v, causal=True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            k0 = k[:, :block_seq_len]
            v0 = v[:, :block_seq_len]
            block_out, block_lse = forward(q, k0, v0, causal=False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            block_out, block_lse = forward(q1, k, v, causal=False)
            out, lse = update_out_and_lse(
                out,
                lse,
                block_out,
                block_lse,
                slice_=(slice(None), slice(block_seq_len, None)),
            )

        if step + 1 != world_size:
            comm.wait()
            k = next_k
            v = next_v

    assert out is not None, f"world_size: {world_size}"
    assert lse is not None, f"world_size: {world_size}"
    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse

def zigzag_ring_flash_attn_backward(
    process_group,
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    """Zigzag ring flash attention backward."""
    assert causal, "zigzag ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)
    world_size: int = int(kv_comm.world_size)
    assert world_size > 0, "Empty world!"
    d_kv_comm = RingComm(process_group)
    assert world_size == int(d_kv_comm.world_size), "Inconsistent world sizes!"

    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2
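    # dout1/q1/out1/softmax_lse1 above are the second halves of the local tensors; they
    # are only needed on steps where this rank holds K/V received from a later rank
    # (step > rank), in which case only the second half of the local queries attends.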
    # Repeatedly allocating these buffers could be slow, so allocate them once and
    # reuse slices of them at every step.
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
    def backward(dout, q, k, v, out, softmax_lse, causal):
        """Run a single-block flash attention backward pass into the shared buffers."""
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        params = get_default_args(_flash_attn_backward).copy()
        params.update(
            {
                "dout": dout,
                "q": q,
                "k": k,
                "v": v,
                "out": out,
                "softmax_lse": softmax_lse,
                "dq": dq_buffer[:, :seqlen_q],
                "dk": dk_buffer[:, :seqlen_kv],
                "dv": dv_buffer[:, :seqlen_kv],
                "dropout_p": dropout_p,
                "softmax_scale": softmax_scale,
                "causal": causal,
                "window_size": window_size,
                "alibi_slopes": alibi_slopes,
                "deterministic": deterministic,
            }
        )
        _flash_attn_backward(**params)
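
    # Two ring communicators run in parallel: kv_comm rotates k/v exactly as in the
    # forward pass, while d_kv_comm rotates the accumulated dk/dv so that each partial
    # gradient eventually reaches the rank that owns the corresponding K/V shard. The
    # per-step branches mirror the forward pass; after the final rotation, next_dk and
    # next_dv hold this rank's completed K/V gradients, which is why they are returned.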
    for step in range(world_size):
        if step + 1 != world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            backward(dout, q, k, v, out, softmax_lse, causal=True)
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
        else:
            assert step > 0, f"step: {step}, world_size: {world_size}"
            assert dq is not None, f"step: {step}, world_size: {world_size}"
            assert dk is not None, f"step: {step}, world_size: {world_size}"
            assert dv is not None, f"step: {step}, world_size: {world_size}"
            assert next_dk is not None, f"step: {step}, world_size: {world_size}"
            assert next_dv is not None, f"step: {step}, world_size: {world_size}"
            if step <= kv_comm.rank:
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
            else:
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # Always use the first half of dq_buffer: the block backward only wrote
                # gradients for the second-half (q1) queries, so they land in the first
                # block_seq_len positions of the buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]

            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            if step <= kv_comm.rank:
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer

        if step + 1 != world_size:
            kv_comm.wait()
            assert next_k is not None, f"step: {step}, world_size: {world_size}"
            assert next_v is not None, f"step: {step}, world_size: {world_size}"
            k = next_k
            v = next_v

        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    d_kv_comm.wait()

    assert dq is not None, f"step: {step}, world_size: {world_size}"
    assert next_dk is not None, f"step: {step}, world_size: {world_size}"
    assert next_dv is not None, f"step: {step}, world_size: {world_size}"
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)

class ZigZagRingFlashAttnFunc(torch.autograd.Function):
"""Zigzag ring flash attention."""
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
    ):
        """Zigzag ring flash attention forward."""
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        assert alibi_slopes is None
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = zigzag_ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=False,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        return out if not return_softmax else (out, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        """Zigzag ring flash attention backward."""
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = zigzag_ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
        )
        return dq, dk, dv, None, None, None, None, None, None, None, None

def zigzag_ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention with a packed QKV tensor (q, k, v stacked along dim 2)."""
    return ZigZagRingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )

def zigzag_ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention with a packed KV tensor (k, v stacked along dim 2)."""
    return ZigZagRingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )

def zigzag_ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    """Zigzag ring flash attention with separate q, k, and v tensors."""
    return ZigZagRingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )
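

# Illustrative usage sketch (assumptions: torch.distributed is already initialized,
# q/k/v are already sharded into the zigzag layout, and `sp_group` names the
# sequence-parallel process group; none of these objects are defined in this module):
#
#     import torch.distributed as dist
#
#     sp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
#     # q, k, v: [batch, local_seq_len, num_heads, head_dim] on this rank's device.
#     out = zigzag_ring_flash_attn_func(q, k, v, causal=True, group=sp_group)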