oumi.models.layers
Submodules
oumi.models.layers.ring_attention module
- oumi.models.layers.ring_attention.apply_zigzag_ring_attn_monkey_patch_llama()
Apply the zigzag ring attention monkey patch to Llama.
- oumi.models.layers.ring_attention.extract_local(value, rank, world_size, device, dim=1)
Extract the local value for the given rank from the global value, splitting along dimension dim.
- oumi.models.layers.ring_attention.new_decoder_forward(self, hidden_states: Tensor, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_value: Cache | None = None, output_attentions: bool | None = False, use_cache: bool | None = False, cache_position: LongTensor | None = None, position_embeddings: tuple[Tensor, Tensor] | None = None, **kwargs) -> tuple[FloatTensor, tuple[FloatTensor, FloatTensor] | None]
New decoder forward.
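The "local value" that each rank extracts can be illustrated with the zigzag layout commonly used by ring attention implementations: the sequence is cut into 2 × world_size chunks, and rank r keeps chunk r together with its mirror chunk counted from the end, which balances causal-attention work across ranks. The sketch below assumes that layout; it is not taken from the oumi source, `zigzag_shard` is a hypothetical helper, and it operates on plain lists rather than tensors (so there is no `device` or `dim` argument).

```python
def zigzag_shard(seq, rank, world_size):
    """Return the zigzag shard of `seq` for `rank` (hypothetical helper)."""
    num_chunks = 2 * world_size
    if len(seq) % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * world_size")
    chunk_len = len(seq) // num_chunks
    chunks = [seq[i * chunk_len:(i + 1) * chunk_len] for i in range(num_chunks)]
    # Pair chunk `rank` with its mirror chunk from the end of the sequence.
    return chunks[rank] + chunks[num_chunks - 1 - rank]


tokens = list(range(8))
shards = [zigzag_shard(tokens, r, world_size=2) for r in range(2)]
print(shards)  # [[0, 1, 6, 7], [2, 3, 4, 5]]
```

Under this assumed layout, every rank holds one early chunk and one late chunk, so no rank is left with only the cheap (early) or only the expensive (late) positions of a causal mask.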
oumi.models.layers.zigzag module
- class oumi.models.layers.zigzag.ZigZagRingFlashAttnFunc(*args, **kwargs)
Bases: Function
Zigzag ring flash attention.
- oumi.models.layers.zigzag.is_zigzag_ring_flash_attn_available() -> bool
Indicates whether zigzag ring attention is available.
- oumi.models.layers.zigzag.zigzag_ring_flash_attn_backward(process_group, dout: Tensor, q: Tensor, k: Tensor, v: Tensor, out: Tensor, softmax_lse: Tensor, softmax_scale, dropout_p=0, causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=False)
Backward pass of zigzag ring flash attention.
- oumi.models.layers.zigzag.zigzag_ring_flash_attn_forward(process_group, q: Tensor, k: Tensor, v: Tensor, softmax_scale, dropout_p=0, causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=False)
Forward pass of zigzag ring flash attention.
- oumi.models.layers.zigzag.zigzag_ring_flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False, return_attn_probs=False, group=None)
Zigzag ring flash attention.
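The `softmax_lse` argument in the backward signature hints at how ring attention schemes generally combine results: each key/value block yields a partial output and the log-sum-exp of its scores, and partial results over disjoint blocks can be merged exactly. The following is a minimal scalar sketch of that merge for a single query; `attend` and `merge` are hypothetical names for illustration and are not part of oumi, whose actual kernels operate on tensors.

```python
import math


def attend(scores, values):
    """Softmax-weighted average of `values` plus the log-sum-exp of `scores`."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / total
    lse = m + math.log(total)
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results computed over disjoint key blocks."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each partial output is reweighted by its share of the total softmax mass.
    out = out_a * math.exp(lse_a - lse) + out_b * math.exp(lse_b - lse)
    return out, lse


scores = [0.5, -1.0, 2.0, 0.1]
values = [1.0, 2.0, 3.0, 4.0]
full_out, full_lse = attend(scores, values)
part_a = attend(scores[:2], values[:2])  # first key block
part_b = attend(scores[2:], values[2:])  # second key block
merged_out, merged_lse = merge(*part_a, *part_b)
```

Because the merge is exact, a rank can process one key/value block at a time as blocks arrive around the ring and still obtain the same result as attention over the full sequence, up to floating-point rounding.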
oumi.models.layers.zigzag_utils module
- class oumi.models.layers.zigzag_utils.RingComm(process_group: ProcessGroup)
Bases: object
Ring communication over a process group.
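As a rough picture of what ring communication means here, each rank sends its current block to the next rank and receives one from the previous rank; after world_size - 1 such steps, every rank has seen every block. The simulation below is a single-process sketch under that assumption. It does not use `torch.distributed` or the `RingComm` class, and `ring_neighbors` and `simulate_ring` are hypothetical helpers.

```python
def ring_neighbors(rank, world_size):
    """Return (send_to, recv_from) for `rank` in a unidirectional ring."""
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size
    return send_to, recv_from


def simulate_ring(blocks):
    """Rotate blocks around the ring and record what each rank sees."""
    world_size = len(blocks)
    held = list(blocks)             # held[r] is the block currently on rank r
    seen = [[b] for b in blocks]    # each rank starts with its own block
    for _ in range(world_size - 1):
        # All ranks exchange at once: rank r receives from its predecessor.
        held = [held[ring_neighbors(r, world_size)[1]] for r in range(world_size)]
        for r in range(world_size):
            seen[r].append(held[r])
    return seen


seen = simulate_ring(["kv0", "kv1", "kv2"])
print(seen[0])  # ['kv0', 'kv2', 'kv1']
```

In a real distributed setting the exchange would be a paired send and receive on the process group, typically overlapped with the attention computation on the block already held.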