vllm.attention.ops.rocm_aiter_mla_sparse ¶
fp8_mqa_logits_torch ¶
fp8_mqa_logits_torch(
q: Tensor,
kv: tuple[Tensor, Tensor],
weights: Tensor,
cu_seqlen_ks: Tensor,
cu_seqlen_ke: Tensor,
) -> Tensor
Compute FP8 MQA logits for a single sequence without KV paging.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | Query tensor of shape [M, H, D]; cast to FP8 (e4m3) internally. | required |
| kv | tuple[Tensor, Tensor] | Tuple of the FP8 key/value tensor of shape [N, D] and its per-token float32 scales of shape [N]. | required |
| weights | Tensor | Weights of shape [M, H], dtype float32. | required |
| cu_seqlen_ks | Tensor | Start indices (inclusive) for valid K per query position, shape [M], dtype int32. | required |
| cu_seqlen_ke | Tensor | End indices (exclusive) for valid K per query position, shape [M], dtype int32. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Logits tensor of shape [M, N], dtype float32. |
Source code in vllm/attention/ops/rocm_aiter_mla_sparse.py
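To make the masking semantics of the parameter table concrete, here is a minimal pure-Python reference sketch of the computation (the function name `mqa_logits_ref` is illustrative, not part of the vLLM API). It ignores FP8 quantization and per-token scales and works on plain nested lists of floats; only the `cu_seqlen_ks`/`cu_seqlen_ke` masking and the head-weighted dot products follow the documented behavior.

```python
def mqa_logits_ref(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke):
    """Reference sketch of masked MQA logits (FP8 details omitted).

    For each query position m:
        logits[m][n] = sum_h weights[m][h] * dot(q[m][h], kv[n])
    valid only where cu_seqlen_ks[m] <= n < cu_seqlen_ke[m];
    all other positions are filled with -inf.

    q: [M][H][D] floats, kv: [N][D] floats, weights: [M][H] floats.
    """
    M, N = len(q), len(kv)
    neg_inf = float("-inf")
    out = [[neg_inf] * N for _ in range(M)]
    for m in range(M):
        for n in range(cu_seqlen_ks[m], cu_seqlen_ke[m]):
            s = 0.0
            for h, w in enumerate(weights[m]):
                # Head-weighted dot product of the query head with key n.
                s += w * sum(qd * kd for qd, kd in zip(q[m][h], kv[n]))
            out[m][n] = s
    return out
```

For example, with M=1, H=1, D=2 and two keys, a query `[1, 2]` against keys `[1, 0]` and `[0, 1]` with unit weight and a full valid range produces logits `[1.0, 2.0]`.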
fp8_paged_mqa_logits_torch ¶
fp8_paged_mqa_logits_torch(
q: Tensor,
kv_cache: Tensor,
weights: Tensor,
context_lens: Tensor,
block_tables: Tensor,
max_model_len: int,
)
Source code in vllm/attention/ops/rocm_aiter_mla_sparse.py
rocm_fp8_mqa_logits ¶
rocm_fp8_mqa_logits(
q: Tensor,
kv: tuple[Tensor, Tensor],
weights: Tensor,
cu_seqlen_ks: Tensor,
cu_seqlen_ke: Tensor,
) -> Tensor
Compute FP8 MQA logits for a single sequence without KV paging.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | Query tensor of shape [M, H, D]; cast to FP8 (e4m3) internally. | required |
| kv | tuple[Tensor, Tensor] | Tuple of the FP8 key/value tensor of shape [N, D] and its per-token float32 scales of shape [N]. | required |
| weights | Tensor | Weights of shape [M, H], dtype float32. | required |
| cu_seqlen_ks | Tensor | Start indices (inclusive) for valid K per query position, shape [M], dtype int32. | required |
| cu_seqlen_ke | Tensor | End indices (exclusive) for valid K per query position, shape [M], dtype int32. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Logits tensor of shape [M, N], dtype float32. |
Source code in vllm/attention/ops/rocm_aiter_mla_sparse.py
rocm_fp8_paged_mqa_logits ¶
rocm_fp8_paged_mqa_logits(
q_fp8: Tensor,
kv_cache_fp8: Tensor,
weights: Tensor,
context_lens: Tensor,
block_tables: Tensor,
schedule_metadata: Tensor,
max_model_len: int,
) -> Tensor
Compute FP8 MQA logits using paged KV-cache.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q_fp8 | Tensor | Query tensor of shape [B, next_n, H, D]; cast to FP8 (e4m3) internally. | required |
| kv_cache_fp8 | Tensor | Paged KV-cache in packed FP8+scale layout with shape [num_blocks, block_size, 1, D+4], dtype uint8; the trailing 4 bytes of each entry hold the float32 per-token scale. | required |
| weights | Tensor | Tensor of shape [B * next_n, H], dtype float32. | required |
| context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length for each batch element. | required |
| block_tables | Tensor | Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache. | required |
| schedule_metadata | Tensor | Scheduling metadata returned by the paged-MQA metadata helper. | required |
| max_model_len | int | Maximum sequence length used to size the logits output. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Logits tensor of shape [B * next_n, max_model_len], dtype float32. |
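The block-table indirection described above can be sketched in pure Python as follows (the name `paged_mqa_logits_ref` is illustrative, not part of the vLLM API). The FP8 packing of the `[num_blocks, block_size, 1, D+4]` layout and the per-token scales are omitted: here the cache is a plain `[num_blocks][block_size][D]` list of floats, and only the logical-to-physical block lookup and the `context_lens` masking follow the documented behavior.

```python
def paged_mqa_logits_ref(q, kv_cache, weights, context_lens, block_tables,
                         max_model_len, block_size):
    """Reference sketch of paged MQA logits (FP8 packing omitted).

    Token n of batch element b lives in physical block
    block_tables[b][n // block_size], slot n % block_size.
    Positions n >= context_lens[b] are filled with -inf.

    q: [B][next_n][H][D] floats, kv_cache: [num_blocks][block_size][D],
    weights: [B * next_n][H]. Returns [B * next_n][max_model_len].
    """
    B, next_n = len(q), len(q[0])
    neg_inf = float("-inf")
    out = []
    for b in range(B):
        for i in range(next_n):
            m = b * next_n + i
            row = [neg_inf] * max_model_len
            for n in range(context_lens[b]):
                # Translate the logical token index into the paged cache.
                blk = block_tables[b][n // block_size]
                kv = kv_cache[blk][n % block_size]
                s = 0.0
                for h, w in enumerate(weights[m]):
                    s += w * sum(qd * kd for qd, kd in zip(q[b][i][h], kv))
                row[n] = s
            out.append(row)
    return out
```

With `block_size=2` and `block_tables=[[1, 0]]`, token 0 of the sequence is read from slot 0 of physical block 1, token 2 from slot 0 of physical block 0, illustrating why the table rather than token order determines cache placement.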