oumi.models.layers#

Submodules#

oumi.models.layers.ring_attention module#

oumi.models.layers.ring_attention.apply_zigzag_ring_attn_monkey_patch_llama()[source]#

Apply the zigzag ring attention monkey patch to the Hugging Face Llama implementation.
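
A minimal usage sketch, assuming the patch is applied once per process before the Llama model runs its forward pass; the model name is illustrative:

    from transformers import AutoModelForCausalLM
    from oumi.models.layers.ring_attention import (
        apply_zigzag_ring_attn_monkey_patch_llama,
    )

    # Patch the Hugging Face Llama attention/decoder forward in place.
    apply_zigzag_ring_attn_monkey_patch_llama()

    # Models loaded afterwards pick up the patched attention path.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")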

oumi.models.layers.ring_attention.extract_local(value, rank, world_size, device, dim=1)[source]#

Extract this rank's local slice of the global tensor along dim for zigzag ring attention.
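
A sketch of extracting one rank's slice of a full-length sequence; the shapes and the CPU device are illustrative assumptions:

    import torch
    from oumi.models.layers.ring_attention import extract_local

    world_size, rank = 4, 1   # sequence-parallel group size and this rank
    seq_len = 8 * world_size

    # Global token ids, shape (batch, seq_len).
    global_input_ids = torch.arange(seq_len).unsqueeze(0)

    # Slice along the sequence dimension (dim=1) owned by this rank.
    local_input_ids = extract_local(
        global_input_ids, rank, world_size, device="cpu", dim=1
    )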

oumi.models.layers.ring_attention.new_decoder_forward(self, hidden_states: Tensor, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_value: Cache | None = None, output_attentions: bool | None = False, use_cache: bool | None = False, cache_position: LongTensor | None = None, position_embeddings: tuple[Tensor, Tensor] | None = None, **kwargs) tuple[FloatTensor, tuple[FloatTensor, FloatTensor] | None][source]#

Replacement decoder-layer forward installed by the monkey patch.

oumi.models.layers.ring_attention.new_flash_attn_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, use_sliding_windows=False)[source]#

Replacement flash-attention forward that routes attention through zigzag ring attention.

oumi.models.layers.ring_attention.prepare_zigzag_ring_attn_inputs(input_ids, position_ids, target_ids, rank, world_size, device)[source]#

Prepare the inputs for zigzag ring attention.
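
A sketch of sharding one batch for a single rank before calling the patched model; the shapes are assumptions, and target_ids are typically the labels aligned with input_ids:

    import torch
    from oumi.models.layers.ring_attention import prepare_zigzag_ring_attn_inputs

    world_size, rank = 4, 0
    seq_len = 8 * world_size

    input_ids = torch.randint(0, 32000, (1, seq_len))
    position_ids = torch.arange(seq_len).unsqueeze(0)
    target_ids = input_ids.clone()

    # Bundles the per-rank (local) shards of the three tensors.
    local_inputs = prepare_zigzag_ring_attn_inputs(
        input_ids, position_ids, target_ids, rank, world_size, device="cpu"
    )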

oumi.models.layers.zigzag module#

class oumi.models.layers.zigzag.ZigZagRingFlashAttnFunc(*args, **kwargs)[source]#

Bases: Function

Autograd function implementing zigzag ring flash attention.

static backward(ctx, dout, *args)[source]#

Zigzag ring flash attention backward.

static forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, group)[source]#

Zigzag ring flash attention forward.

oumi.models.layers.zigzag.is_zigzag_ring_flash_attn_available() bool[source]#

Indicates whether zigzag ring attention is available.
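
A sketch of guarding on availability; the backend labels are illustrative:

    from oumi.models.layers.zigzag import is_zigzag_ring_flash_attn_available

    if is_zigzag_ring_flash_attn_available():
        attn_backend = "zigzag_ring_flash_attn"   # illustrative label
    else:
        attn_backend = "sdpa"                     # fall back to standard attention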

oumi.models.layers.zigzag.zigzag_ring_flash_attn_backward(process_group, dout: Tensor, q: Tensor, k: Tensor, v: Tensor, out: Tensor, softmax_lse: Tensor, softmax_scale, dropout_p=0, causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=False)[source]#

Zigzag ring flash attention backward.

oumi.models.layers.zigzag.zigzag_ring_flash_attn_forward(process_group, q: Tensor, k: Tensor, v: Tensor, softmax_scale, dropout_p=0, causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=False)[source]#

Zigzag ring flash attention forward.

oumi.models.layers.zigzag.zigzag_ring_flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False, return_attn_probs=False, group=None)[source]#

Zigzag ring flash attention.
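
A sketch of calling the kernel directly on per-rank shards; the (batch, local_seq_len, nheads, headdim) layout follows the flash-attn convention and is an assumption, as is an initialized distributed environment with bf16 CUDA tensors:

    import torch
    from oumi.models.layers.zigzag import zigzag_ring_flash_attn_func

    batch, local_seq_len, nheads, headdim = 1, 1024, 8, 64
    q = torch.randn(
        batch, local_seq_len, nheads, headdim, device="cuda", dtype=torch.bfloat16
    )
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # group=None is assumed to mean the default (world) process group.
    out = zigzag_ring_flash_attn_func(q, k, v, causal=True)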

oumi.models.layers.zigzag.zigzag_ring_flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False, return_attn_probs=False, group=None)[source]#

Zigzag ring flash attention with packed key/value tensors.

oumi.models.layers.zigzag.zigzag_ring_flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False, return_attn_probs=False, group=None)[source]#

Zigzag ring flash attention with a packed query/key/value tensor.
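
A sketch of the qkv-packed variant, assuming it mirrors flash-attn's qkvpacked layout where q, k, and v are stacked along a size-3 dimension:

    import torch
    from oumi.models.layers.zigzag import zigzag_ring_flash_attn_qkvpacked_func

    batch, local_seq_len, nheads, headdim = 1, 1024, 8, 64
    qkv = torch.randn(
        batch, local_seq_len, 3, nheads, headdim, device="cuda", dtype=torch.bfloat16
    )

    out = zigzag_ring_flash_attn_qkvpacked_func(qkv, causal=True)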

oumi.models.layers.zigzag_utils module#

class oumi.models.layers.zigzag_utils.RingComm(process_group: ProcessGroup)[source]#

Bases: object

Ring communication helper for point-to-point exchanges between neighboring ranks.

commit()[source]#

Commit (launch) the queued send/receive operations.

send_recv(to_send: Tensor, recv_tensor: Tensor | None = None) Tensor[source]#

Send and receive a tensor.

wait()[source]#

Wait for the operations to complete.
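
A sketch of the send_recv / commit / wait pattern, assuming torch.distributed has been initialized and every rank holds a block of the same shape:

    import torch
    import torch.distributed as dist
    from oumi.models.layers.zigzag_utils import RingComm

    # Assumes dist.init_process_group(...) has already been called.
    comm = RingComm(dist.group.WORLD)

    block = torch.randn(1024, 64, device="cuda")

    # Queue a send to the next rank and a receive from the previous rank.
    next_block = comm.send_recv(block)
    comm.commit()   # launch the queued point-to-point operations
    # ... overlap computation on `block` here ...
    comm.wait()     # block until `next_block` has arrived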

oumi.models.layers.zigzag_utils.get_default_args(func) dict[str, Any][source]#

Get the default arguments of a function.
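
A short sketch; the function passed in below is hypothetical:

    from oumi.models.layers.zigzag_utils import get_default_args

    def attn(q, k, v, dropout_p=0.0, causal=False):   # hypothetical function
        ...

    defaults = get_default_args(attn)
    # Expected (assumption): {"dropout_p": 0.0, "causal": False}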

oumi.models.layers.zigzag_utils.update_out_and_lse(out: Tensor | None, lse: Tensor | None, block_out: Tensor, block_lse: Tensor, slice_=None) tuple[Tensor, Tensor][source]#

Update the output and log-sum-exp of the attention.
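
This is the streaming-softmax recurrence used by flash/ring attention to merge per-block partial results. A minimal sketch of that recurrence (not necessarily the library's exact implementation), assuming lse broadcasts against out:

    import torch

    def combine(out, lse, block_out, block_lse):
        # New per-row normalizer: logaddexp of the two log-sum-exp values.
        new_lse = torch.logaddexp(lse, block_lse)
        # Rescale each partial output by its share of the new normalizer.
        new_out = (
            torch.exp(lse - new_lse) * out
            + torch.exp(block_lse - new_lse) * block_out
        )
        return new_out, new_lse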