vllm.attention.ops.rocm_aiter_mla_sparse

logger module-attribute

logger = init_logger(__name__)

fp8_mqa_logits_torch

fp8_mqa_logits_torch(
    q: Tensor,
    kv: tuple[Tensor, Tensor],
    weights: Tensor,
    cu_seqlen_ks: Tensor,
    cu_seqlen_ke: Tensor,
) -> Tensor

Compute FP8 MQA logits for a single sequence without KV paging.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| q | Tensor | Query tensor of shape [M, H, D]. Cast to torch.float8_e4m3fn by the caller. | required |
| kv | tuple[Tensor, Tensor] | Tuple (k_fp8, k_scales) where k_fp8 has shape [N, D] with dtype torch.float8_e4m3fn and k_scales has shape [N] (or [N, 1]) with dtype torch.float32. | required |
| weights | Tensor | Weights of shape [M, H], dtype torch.float32. | required |
| cu_seqlen_ks | Tensor | Start indices (inclusive) of valid K per query position, shape [M], dtype int32. | required |
| cu_seqlen_ke | Tensor | End indices (exclusive) of valid K per query position, shape [M], dtype int32. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Logits tensor of shape [M, N], dtype torch.float32. |

Source code in vllm/attention/ops/rocm_aiter_mla_sparse.py
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
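
The start and end index tensors define, per query row, a half-open window [cu_seqlen_ks[m], cu_seqlen_ke[m]) of valid key positions. Below is a minimal sketch, assuming the common case where the M queries are the last M positions of an N-token sequence attending causally (sizes are illustrative only), of how that window translates into the [M, N] validity mask applied by the reference implementation:

```python
import torch

M, N = 4, 10  # illustrative sizes, not required by the kernel

# Each query m attends to keys in [cu_seqlen_ks[m], cu_seqlen_ke[m]).
cu_seqlen_ks = torch.zeros(M, dtype=torch.int32)
cu_seqlen_ke = torch.arange(N - M + 1, N + 1, dtype=torch.int32)

pos = torch.arange(N)[None, :]
mask = (pos >= cu_seqlen_ks[:, None]) & (pos < cu_seqlen_ke[:, None])
# mask[m, n] is True exactly where fp8_mqa_logits_torch keeps the logit;
# all other entries of the [M, N] output are set to -inf.
```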

fp8_paged_mqa_logits_torch

fp8_paged_mqa_logits_torch(
    q: Tensor,
    kv_cache: Tensor,
    weights: Tensor,
    context_lens: Tensor,
    block_tables: Tensor,
    max_model_len: int,
)
Source code in vllm/attention/ops/rocm_aiter_mla_sparse.py
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    """Reference PyTorch implementation of FP8 MQA logits over a paged KV-cache.

    Used as the fallback path in `rocm_fp8_paged_mqa_logits` when AITER is not
    enabled. Takes the same arguments as that wrapper, minus `schedule_metadata`.
    """
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    # Split the packed cache into FP8 key bytes and their float32 scales,
    # then dequantize to float for the reference computation.
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits
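
A minimal sketch of how a cache in the packed FP8+scale layout can be assembled and fed to this reference implementation. All sizes are illustrative, one float32 scale per cached token is assumed, and torch.float8_e4m3fn stands in for whatever current_platform.fp8_dtype() returns on the target platform (some ROCm devices use the fnuz variant, whose value range differs):

```python
import torch

from vllm.attention.ops.rocm_aiter_mla_sparse import fp8_paged_mqa_logits_torch

num_blocks, block_size, dim = 4, 64, 128  # illustrative sizes
batch_size, next_n, heads = 2, 1, 16

# Quantize per-token keys to FP8 and keep a float32 dequant scale per slot.
k = torch.randn(num_blocks, block_size, 1, dim, device="cuda")
k_scales = k.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 448.0
k_fp8 = (k / k_scales).to(torch.float8_e4m3fn)

# Pack FP8 key bytes plus the 4 scale bytes into [num_blocks, block_size, 1, dim + 4].
kv_cache = torch.cat([k_fp8.view(torch.uint8), k_scales.view(torch.uint8)], dim=-1)

q_fp8 = torch.randn(batch_size, next_n, heads, dim, device="cuda").to(
    torch.float8_e4m3fn
)
weights = torch.rand(batch_size * next_n, heads, device="cuda")
context_lens = torch.tensor([150, 200], dtype=torch.int32, device="cuda")
# Both batch elements map to the same physical blocks in this toy example.
block_tables = (
    torch.arange(num_blocks, dtype=torch.int32, device="cuda").repeat(batch_size, 1)
)

logits = fp8_paged_mqa_logits_torch(
    q_fp8, kv_cache, weights, context_lens, block_tables, max_model_len=256
)
print(logits.shape)  # torch.Size([2, 256])
```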

rocm_fp8_mqa_logits

rocm_fp8_mqa_logits(
    q: Tensor,
    kv: tuple[Tensor, Tensor],
    weights: Tensor,
    cu_seqlen_ks: Tensor,
    cu_seqlen_ke: Tensor,
) -> Tensor

Compute FP8 MQA logits for a single sequence without KV paging.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| q | Tensor | Query tensor of shape [M, H, D]. Cast to torch.float8_e4m3fn by the caller. | required |
| kv | tuple[Tensor, Tensor] | Tuple (k_fp8, k_scales) where k_fp8 has shape [N, D] with dtype torch.float8_e4m3fn and k_scales has shape [N] (or [N, 1]) with dtype torch.float32. | required |
| weights | Tensor | Weights of shape [M, H], dtype torch.float32. | required |
| cu_seqlen_ks | Tensor | Start indices (inclusive) of valid K per query position, shape [M], dtype int32. | required |
| cu_seqlen_ke | Tensor | End indices (exclusive) of valid K per query position, shape [M], dtype int32. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Logits tensor of shape [M, N], dtype torch.float32. |

Source code in vllm/attention/ops/rocm_aiter_mla_sparse.py
def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; remove the module check and reference
    # path after aiter merges this kernel into main.
    @lru_cache
    def has_mqa_logits_module():
        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None

    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
        from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits

        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
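
A short usage sketch with illustrative sizes. The per-token quantization scheme and the torch.float8_e4m3fn dtype below are assumptions for demonstration, not requirements of the kernel; the same arguments also work for the fp8_mqa_logits_torch fallback:

```python
import torch

from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits

M, H, D, N = 8, 16, 128, 64  # illustrative sizes

q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
weights = torch.rand(M, H, device="cuda", dtype=torch.float32)

# Per-token FP8 keys, each with a float32 dequant scale.
k = torch.randn(N, D, device="cuda")
k_scales = (k.abs().amax(dim=-1) / 448.0).clamp(min=1e-4)
k_fp8 = (k / k_scales[:, None]).to(torch.float8_e4m3fn)

# Queries are the last M positions of the N-token sequence, attending causally.
cu_seqlen_ks = torch.zeros(M, dtype=torch.int32, device="cuda")
cu_seqlen_ke = torch.arange(N - M + 1, N + 1, dtype=torch.int32, device="cuda")

logits = rocm_fp8_mqa_logits(q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke)
print(logits.shape)  # torch.Size([8, 64])
```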

rocm_fp8_paged_mqa_logits

rocm_fp8_paged_mqa_logits(
    q_fp8: Tensor,
    kv_cache_fp8: Tensor,
    weights: Tensor,
    context_lens: Tensor,
    block_tables: Tensor,
    schedule_metadata: Tensor,
    max_model_len: int,
) -> Tensor

Compute FP8 MQA logits using paged KV-cache.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| q_fp8 | Tensor | Query tensor of shape [B, next_n, H, D]. Cast to torch.float8_e4m3fn by the caller. | required |
| kv_cache_fp8 | Tensor | Paged KV-cache in packed FP8+scale layout with shape [num_blocks, block_size, 1, D+4], dtype torch.uint8. The last 4 bytes per (block, pos) store the float dequant scale. | required |
| weights | Tensor | Tensor of shape [B * next_n, H], dtype torch.float32. | required |
| context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length for each batch element. | required |
| block_tables | Tensor | Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache. | required |
| schedule_metadata | Tensor | Returned by get_paged_mqa_logits_metadata; used to distribute work across SMs. | required |
| max_model_len | int | Maximum sequence length used to size the logits output. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Logits tensor of shape [B * next_n, max_model_len], dtype torch.float32. |

Source code in vllm/attention/ops/rocm_aiter_mla_sparse.py
def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """

    if rocm_aiter_ops.is_enabled():
        from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1

        batch_size, next_n, heads, _ = q_fp8.shape
        out_qk = torch.full(
            (heads, batch_size * next_n, max_model_len),
            float("-inf"),
            device="cuda",
            dtype=torch.float32,
        )
        deepgemm_fp8_paged_mqa_logits_stage1(
            q_fp8,
            kv_cache_fp8,
            weights,
            out_qk,
            context_lens,
            block_tables,
            max_model_len,
        )
        return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )