Skip to content

vllm.model_executor.layers.attention.sparse_mla_attention

Common implementation for sparse MLA attention prefill (forward_mha).

Parallel to MLACommonImpl (in mla_attention.py) which provides forward_mha for non-sparse MLA backends, this module provides SparseMLACommonImpl which gives all sparse MLA backends a shared forward_mha implementation.

Decode (forward_mqa) is left abstract for each sparse backend to implement with its own sparse decode kernel.

SparseMLACommonImpl

Bases: SparseMLAAttentionImpl[T], Generic[T]

Common sparse MLA implementation providing forward_mha for prefill.

Subclasses must implement forward_mqa() for decode.

Source code in vllm/model_executor/layers/attention/sparse_mla_attention.py
class SparseMLACommonImpl(SparseMLAAttentionImpl[T], Generic[T]):
    """Common sparse MLA implementation providing forward_mha for prefill.

    Subclasses must implement forward_mqa() for decode. This base class
    supplies the shared prefill path: gather and decompress the full
    context from the paged KV cache, build a dense mask from the
    indexer's topk indices, and run FA4 varlen attention with that mask.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: "ColumnParallelLinear",
        indexer: object | None = None,
        q_pad_num_heads: int | None = None,
    ) -> None:
        """Store MLA geometry and probe backend capabilities.

        NOTE: alibi_slopes, sliding_window, logits_soft_cap, attn_type,
        kv_sharing_target_layer_name, q_lora_rank and q_pad_num_heads are
        accepted for signature compatibility with other attention impls
        but are not stored by this base class.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # MLA head geometry. qk_head_dim is presumably
        # qk_nope_head_dim + qk_rope_head_dim (k is built by concatenating
        # the two parts below) — TODO confirm against the model config.
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        # Projection that decompresses the latent kv_c into per-head
        # [k_nope | v]; used in _gather_and_decompress_all.
        self.kv_b_proj = kv_b_proj

        # The indexer supplies per-token topk KV indices; forward_mha reads
        # them from this buffer. getattr keeps this None when the indexer
        # exposes no buffer, which then trips the assert in forward_mha.
        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = getattr(
            indexer, "topk_indices_buffer", None
        )

        # FA4 is required for mask_mod in forward_mha; cache availability
        # once here so forward_mha can fail fast with a clear error.
        fa_version = get_flash_attn_version(head_size=qk_head_dim)
        self._fa4_available = fa_version is not None and fa_version >= 4

        # DCP (context parallelism) — lazily initialized by the caller
        # (-1 means "not yet initialized").
        self.dcp_world_size: int = -1
        self.cp_kv_cache_interleave_size: int = (
            get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
        )

    def _concat_k_nope_k_pe(
        self, k_nope: torch.Tensor, k_pe: torch.Tensor
    ) -> torch.Tensor:
        """Concatenate k_nope and k_pe along the last dimension.

        Writes both parts into a preallocated tensor of k_nope's dtype, so
        k_pe is implicitly cast to k_nope.dtype if the dtypes differ.
        """
        k = torch.empty(
            (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
            dtype=k_nope.dtype,
            device=k_nope.device,
        )
        k[..., : k_nope.shape[-1]] = k_nope
        k[..., k_nope.shape[-1] :] = k_pe
        return k

    # ------------------------------------------------------------------
    # Context gathering helpers
    # ------------------------------------------------------------------

    def _gather_and_decompress_all(
        self,
        kv_c_and_k_pe_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens_t: torch.Tensor,
        cu_seq_lens: torch.Tensor,
        total_tokens: int,
        k_scale: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather all positions from paged cache and decompress.

        Uses the same fused CUDA gather kernel as standard MLA prefill.

        Returns:
            k: (total_tokens, num_heads, qk_head_dim)
            v: (total_tokens, num_heads, v_head_dim)
        """
        device = kv_c_and_k_pe_cache.device
        # Width of one packed cache entry: [kv_c (kv_lora_rank) | k_pe].
        head_size = kv_c_and_k_pe_cache.shape[-1]
        num_reqs = seq_lens_t.shape[0]

        # No context at all: return empty tensors with the correct
        # trailing dimensions so downstream shape logic still works.
        if total_tokens == 0:
            empty_k = torch.empty(
                0,
                self.num_heads,
                self.qk_head_dim,
                dtype=kv_c_and_k_pe_cache.dtype,
                device=device,
            )
            empty_v = torch.empty(
                0,
                self.num_heads,
                self.v_head_dim,
                dtype=kv_c_and_k_pe_cache.dtype,
                device=device,
            )
            return empty_k, empty_v

        # Flat workspace holding one packed [kv_c | k_pe] row per token.
        workspace = torch.empty(
            total_tokens,
            head_size,
            dtype=kv_c_and_k_pe_cache.dtype,
            device=device,
        )
        # Map each flat token index to the request it belongs to
        # (request i contributes seq_lens_t[i] consecutive rows).
        token_to_seq = torch.repeat_interleave(
            torch.arange(num_reqs, dtype=torch.int32, device=device),
            seq_lens_t,
        )

        # Fused gather (+ optional dequant) from the paged cache into the
        # contiguous workspace.
        ops.gather_and_maybe_dequant_cache(
            src_cache=kv_c_and_k_pe_cache,
            dst=workspace,
            block_table=block_table,
            cu_seq_lens=cu_seq_lens,
            token_to_seq=token_to_seq,
            num_tokens=total_tokens,
            kv_cache_dtype=self.kv_cache_dtype,
            scale=k_scale,
        )

        # Split each packed row back into the latent and rope parts.
        kv_c = workspace[..., : self.kv_lora_rank]
        k_pe = workspace[..., self.kv_lora_rank :]

        # Decompress: project latent kv_c up to per-head [k_nope | v].
        kv_nope = self.kv_b_proj(kv_c)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # k_pe is shared across heads; broadcast before fusing with k_nope.
        k_pe = k_pe.unsqueeze(1).expand(-1, self.num_heads, -1)
        k = self._concat_k_nope_k_pe(k_nope, k_pe)

        return k, v

    # ------------------------------------------------------------------
    # forward_mha — FA4 masked MHA over topk positions
    # ------------------------------------------------------------------

    def forward_mha(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Sparse prefill attention via FA4 with a dense topk mask.

        Gathers the full context of each prefill request from the paged
        cache, decompresses it, and runs varlen attention in which each
        query token may attend only to the KV positions selected by the
        indexer. The result is written into ``output`` in place.

        Args:
            q: queries for the prefill tokens; q.shape[0] is taken as the
                number of prefill tokens (presumably
                (num_prefill_tokens, num_heads, qk_head_dim) — confirm).
            kv_c_normed, k_pe: unused here; context is re-read from the
                paged cache instead.
            kv_c_and_k_pe_cache: paged KV cache of packed [kv_c | k_pe].
            attn_metadata: must expose prefill_query_start_loc, seq_lens
                and block_table (plus the decode/prefill counts).
            k_scale: dequantization scale forwarded to the gather kernel.
            output: destination buffer; receives the attention output
                flattened over its last two dims.

        Raises:
            NotImplementedError: for quantized KV caches or when FA4 is
                unavailable.
        """
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "Sparse MLA forward_mha with FP8 KV cache not yet supported"
            )
        if not self._fa4_available:
            raise NotImplementedError(
                "Sparse MLA forward_mha requires FA4 (SM100+). "
                "On SM90, all tokens are routed through forward_mqa."
            )

        device = q.device
        prefill_qsl = getattr(attn_metadata, "prefill_query_start_loc", None)
        assert prefill_qsl is not None, (
            "Metadata must provide prefill_query_start_loc for forward_mha"
        )

        # Per-request metadata is laid out decode-first: prefill rows start
        # at num_decodes (request granularity) / num_decode_tokens (token
        # granularity).
        num_decodes: int = getattr(attn_metadata, "num_decodes", 0)
        num_prefills: int = getattr(attn_metadata, "num_prefills", 0)
        num_decode_tokens: int = getattr(attn_metadata, "num_decode_tokens", 0)

        all_seq_lens = getattr(attn_metadata, "seq_lens", None)
        assert all_seq_lens is not None
        prefill_seq_lens = all_seq_lens[num_decodes : num_decodes + num_prefills]

        # Per-request query and sequence lengths.
        # NOTE(review): the .cpu() transfer plus per-element .item() loop
        # is a host sync; fine for typical prefill batch sizes but worth
        # revisiting if num_prefills grows large.
        prefill_qsl_cpu = prefill_qsl.cpu()
        q_lens = [
            (prefill_qsl_cpu[i + 1] - prefill_qsl_cpu[i]).item()
            for i in range(num_prefills)
        ]
        max_q_len = max(q_lens) if q_lens else 0

        # Build cu_seqlens_k (cumulative KV lengths, leading 0) from seq_lens.
        seq_lens_t = prefill_seq_lens.to(torch.int32)
        cu_seqlens_k = torch.zeros(num_prefills + 1, dtype=torch.int32, device=device)
        torch.cumsum(seq_lens_t, dim=0, out=cu_seqlens_k[1:])
        total_kv_tokens = cu_seqlens_k[-1].item()
        max_kv_len = seq_lens_t.max().item() if num_prefills > 0 else 0

        # Get topk indices for prefill tokens; the buffer is also laid out
        # decode-first, so skip the first num_decode_tokens rows.
        assert self.topk_indices_buffer is not None
        num_prefill_tokens = q.shape[0]
        topk_all = self.topk_indices_buffer[
            num_decode_tokens : num_decode_tokens + num_prefill_tokens
        ]
        # Split the flat topk rows back into per-request chunks.
        topk_per_req: list[torch.Tensor] = []
        ti_offset = 0
        for i in range(num_prefills):
            ql = q_lens[i]
            topk_per_req.append(topk_all[ti_offset : ti_offset + ql])
            ti_offset += ql

        # Gather all positions from cache and decompress into full k/v.
        block_table = getattr(attn_metadata, "block_table", None)
        assert block_table is not None
        prefill_block_table = block_table[num_decodes : num_decodes + num_prefills]

        k, v = self._gather_and_decompress_all(
            kv_c_and_k_pe_cache,
            prefill_block_table,
            seq_lens_t,
            cu_seqlens_k,
            total_kv_tokens,
            k_scale,
        )

        # Dense (B, max_Q, max_S) mask with 1s at each token's topk positions.
        dense_mask = _build_topk_mask(
            topk_per_req,
            q_lens,
            max_q_len,
            max_kv_len,
            device,
        )

        # causal=False: visibility is governed entirely by dense_mask (the
        # indexer is presumed to select only causally valid positions —
        # confirm against the indexer contract).
        attn_out, _ = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill_qsl,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_q_len,
            max_seqlen_k=max_kv_len,
            softmax_scale=self.scale,
            causal=False,
            return_softmax_lse=True,
            fa_version=4,
            dense_mask=dense_mask,
        )

        # Keep only v_head_dim of the output's trailing dim (presumably the
        # kernel pads the value dim up to the q head dim — confirm), then
        # flatten (num_heads, v_head_dim) into the caller's output layout.
        attn_out = attn_out[..., : self.v_head_dim]
        output.copy_(attn_out.flatten(start_dim=-2))

_concat_k_nope_k_pe

_concat_k_nope_k_pe(k_nope: Tensor, k_pe: Tensor) -> Tensor

Concatenate k_nope and k_pe along the last dimension.

Source code in vllm/model_executor/layers/attention/sparse_mla_attention.py
def _concat_k_nope_k_pe(
    self, k_nope: torch.Tensor, k_pe: torch.Tensor
) -> torch.Tensor:
    """Concatenate k_nope and k_pe along the last dimension."""
    k = torch.empty(
        (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
        dtype=k_nope.dtype,
        device=k_nope.device,
    )
    k[..., : k_nope.shape[-1]] = k_nope
    k[..., k_nope.shape[-1] :] = k_pe
    return k

_gather_and_decompress_all

_gather_and_decompress_all(
    kv_c_and_k_pe_cache: Tensor,
    block_table: Tensor,
    seq_lens_t: Tensor,
    cu_seq_lens: Tensor,
    total_tokens: int,
    k_scale: Tensor,
) -> tuple[Tensor, Tensor]

Gather all positions from paged cache and decompress.

Uses the same fused CUDA gather kernel as standard MLA prefill.

Returns:

Name Type Description
k Tensor

(total_tokens, num_heads, qk_head_dim)

v Tensor

(total_tokens, num_heads, v_head_dim)

Source code in vllm/model_executor/layers/attention/sparse_mla_attention.py
def _gather_and_decompress_all(
    self,
    kv_c_and_k_pe_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens_t: torch.Tensor,
    cu_seq_lens: torch.Tensor,
    total_tokens: int,
    k_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Gather all positions from paged cache and decompress.

    Uses the same fused CUDA gather kernel as standard MLA prefill.

    Returns:
        k: (total_tokens, num_heads, qk_head_dim)
        v: (total_tokens, num_heads, v_head_dim)
    """
    device = kv_c_and_k_pe_cache.device
    head_size = kv_c_and_k_pe_cache.shape[-1]
    num_reqs = seq_lens_t.shape[0]

    if total_tokens == 0:
        empty_k = torch.empty(
            0,
            self.num_heads,
            self.qk_head_dim,
            dtype=kv_c_and_k_pe_cache.dtype,
            device=device,
        )
        empty_v = torch.empty(
            0,
            self.num_heads,
            self.v_head_dim,
            dtype=kv_c_and_k_pe_cache.dtype,
            device=device,
        )
        return empty_k, empty_v

    workspace = torch.empty(
        total_tokens,
        head_size,
        dtype=kv_c_and_k_pe_cache.dtype,
        device=device,
    )
    token_to_seq = torch.repeat_interleave(
        torch.arange(num_reqs, dtype=torch.int32, device=device),
        seq_lens_t,
    )

    ops.gather_and_maybe_dequant_cache(
        src_cache=kv_c_and_k_pe_cache,
        dst=workspace,
        block_table=block_table,
        cu_seq_lens=cu_seq_lens,
        token_to_seq=token_to_seq,
        num_tokens=total_tokens,
        kv_cache_dtype=self.kv_cache_dtype,
        scale=k_scale,
    )

    kv_c = workspace[..., : self.kv_lora_rank]
    k_pe = workspace[..., self.kv_lora_rank :]

    kv_nope = self.kv_b_proj(kv_c)[0].view(
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
    )
    k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k_pe = k_pe.unsqueeze(1).expand(-1, self.num_heads, -1)
    k = self._concat_k_nope_k_pe(k_nope, k_pe)

    return k, v

_build_topk_mask

_build_topk_mask(
    topk_indices_per_req: list[Tensor],
    q_lens: list[int],
    max_q_len: int,
    max_seq_len: int,
    device: device,
) -> Tensor

Build a dense (B, max_Q, max_S) int32 mask from topk indices.

Uses a single Triton kernel to scatter 1s at each valid topk position.

Source code in vllm/model_executor/layers/attention/sparse_mla_attention.py
def _build_topk_mask(
    topk_indices_per_req: list[torch.Tensor],
    q_lens: list[int],
    max_q_len: int,
    max_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """Build a dense (B, max_Q, max_S) int32 mask from topk indices.

    Uses a single Triton kernel to scatter 1s at each valid topk position.
    """
    B = len(q_lens)
    mask = torch.zeros(B, max_q_len, max_seq_len, dtype=torch.int32, device=device)

    total_q = sum(q_lens)
    if total_q == 0:
        return mask

    topk_packed = torch.cat(topk_indices_per_req, dim=0)  # (total_q, topk)
    topk_k = topk_packed.shape[1]

    q_lens_t = torch.tensor(q_lens, dtype=torch.int32, device=device)
    cu_q_lens = torch.zeros(B + 1, dtype=torch.int32, device=device)
    torch.cumsum(q_lens_t, dim=0, out=cu_q_lens[1:])

    BLOCK_TOPK = triton.next_power_of_2(topk_k)
    _scatter_topk_kernel[(total_q,)](
        mask,
        topk_packed,
        cu_q_lens,
        max_seq_len=max_seq_len,
        topk=topk_k,
        topk_stride=topk_packed.stride(0),
        max_q_len=max_q_len,
        BLOCK_TOPK=BLOCK_TOPK,
        NUM_REQS=B,
    )

    return mask

_scatter_topk_kernel

_scatter_topk_kernel(
    mask_ptr,
    topk_ptr,
    cu_q_lens_ptr,
    max_seq_len: constexpr,
    topk: constexpr,
    topk_stride: constexpr,
    max_q_len: constexpr,
    BLOCK_TOPK: constexpr,
    NUM_REQS: constexpr,
)

Write 1 at each valid topk position in the mask.

Grid: (total_topk_rows,) where total_topk_rows = sum(q_lens). Each program handles one query token's topk entries.

Source code in vllm/model_executor/layers/attention/sparse_mla_attention.py
@triton.jit
def _scatter_topk_kernel(
    mask_ptr,  # (B, max_Q, max_S) int32 output
    topk_ptr,  # (total_topk_rows, topk_stride) int32 packed topk indices
    cu_q_lens_ptr,  # (B+1,) int32 cumulative query lengths
    max_seq_len: tl.constexpr,
    topk: tl.constexpr,
    topk_stride: tl.constexpr,
    max_q_len: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
    NUM_REQS: tl.constexpr,
):
    """Write 1 at each valid topk position in the mask.

    Grid: (total_topk_rows,) where total_topk_rows = sum(q_lens).
    Each program handles one query token's topk entries. Positions not
    listed in any topk row are left untouched (the caller zero-fills the
    mask beforehand).
    """
    row_idx = tl.program_id(0)

    # Find batch index: count how many cu_q_lens boundaries are <= row_idx.
    # Statically unrolled linear scan over NUM_REQS boundaries (NUM_REQS is
    # a compile-time constant, so no data-dependent loop bound is needed).
    b: tl.int32 = 0
    for i in tl.static_range(NUM_REQS):
        next_start = tl.load(cu_q_lens_ptr + i + 1)
        b += tl.where(next_start <= row_idx, 1, 0)

    # Position of this query token within its own request.
    q_start = tl.load(cu_q_lens_ptr + b)
    q_local = row_idx - q_start

    # Load topk indices for this query token. BLOCK_TOPK >= topk is the
    # padded power-of-two lane count; out-of-range lanes read -1.
    topk_row_ptr = topk_ptr + row_idx * topk_stride
    offsets = tl.arange(0, BLOCK_TOPK)
    in_range = offsets < topk
    indices = tl.load(topk_row_ptr + offsets, mask=in_range, other=-1)

    # Only write valid indices (-1 marks invalid/padding). Masked lanes do
    # not touch memory, so padding entries leave the mask at 0.
    valid = in_range & (indices >= 0)
    mask_row_ptr = mask_ptr + (b * max_q_len + q_local) * max_seq_len
    tl.store(mask_row_ptr + indices, tl.where(valid, 1, 0), mask=valid)