
vllm.v1.worker.kv_connector_model_runner_mixin

Define KV connector functionality mixin for model runners.

logger module-attribute

logger = init_logger(__name__)

KVConnectorModelRunnerMixin

Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
class KVConnectorModelRunnerMixin:
    @staticmethod
    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
        # Update the KVConnector with the KVConnector metadata before forward().
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase)
            assert scheduler_output.kv_connector_metadata is not None
            kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

            # Background KV cache transfers happen here.
            # These transfers are designed to be async and the requests
            # involved may be disjoint from the running requests.
            # Do this here to save a collective_rpc.
            kv_connector.start_load_kv(get_forward_context())

    @staticmethod
    def ensure_kv_transfer_shutdown() -> None:
        # has_kv_transfer_group can be None during interpreter shutdown.
        if has_kv_transfer_group and has_kv_transfer_group():  # type: ignore[truthy-function]
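            # Note: this resolves to the module-level ensure_kv_transfer_shutdown()
            # imported by this file, not to this static method, since class scope is
            # not searched from within method bodies.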
            ensure_kv_transfer_shutdown()

    @staticmethod
    def maybe_wait_for_kv_save() -> None:
        if has_kv_transfer_group():
            get_kv_transfer_group().wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[set[str] | None, set[str] | None]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_finished(
                scheduler_output.finished_req_ids
            )
        return None, None

    @staticmethod
    def kv_connector_no_forward(
        scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
    ) -> ModelRunnerOutput:
        # KV send/recv even if no work to do.
        with (
            set_forward_context(None, vllm_config),
            KVConnectorModelRunnerMixin._get_kv_connector_output(
                scheduler_output, wait_for_save=False
            ) as kv_connector_output,
        ):
            pass

        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    @staticmethod
    def maybe_get_kv_connector_output(
        scheduler_output: "SchedulerOutput",
    ) -> AbstractContextManager[KVConnectorOutput | None]:
        return (
            KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
            if has_kv_transfer_group()
            else nullcontext()
        )

    # This context manager must be used within an active forward context.
    # It encapsulates the entire KV connector lifecycle within execute_model
    @staticmethod
    @contextmanager
    def _get_kv_connector_output(
        scheduler_output: "SchedulerOutput", wait_for_save: bool = True
    ) -> Generator[KVConnectorOutput, None, None]:
        output = KVConnectorOutput()

        # Update the KVConnector with the KVConnector metadata before forward().
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase)
        assert scheduler_output.kv_connector_metadata is not None
        kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        kv_connector.start_load_kv(get_forward_context())
        try:
            yield output
        finally:
            if wait_for_save:
                kv_connector.wait_for_save()

            output.finished_sending, output.finished_recving = (
                kv_connector.get_finished(scheduler_output.finished_req_ids)
            )
            output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()

            output.kv_connector_stats = (
                KVConnectorModelRunnerMixin.get_kv_connector_stats()
            )
            kv_connector.clear_connector_metadata()

    @staticmethod
    def get_kv_connector_stats() -> KVConnectorStats | None:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_kv_connector_stats()
        return None

    @staticmethod
    def use_uniform_kv_cache(
        attn_groups: list[list[AttentionGroup]],
        cache_dtype: CacheDType,
    ) -> bool:
        """
        Determines whether a uniform KV layout should be used.
        A uniform layout means all layers KV caches will share the same
        underlying tensor, where for a given block number, the respective
        KV data for all layers will be contiguous.
        This will allow efficient KV transfer of per-block KV data for all
        layers at once.
        Note this layout will only be applied given 3 conditions:
        1. The KV Cache config contains just a single group where all layers
            have the same page size.
        2. A KV connector is configured, and the KV connector instance prefers
            to use this layout (prefer_cross_layer_blocks() returns True)
        3. The flash attention backend supports this layout
            (get_kv_cache_stride_order(True) includes a placement for a
            num_layers dimension)

        Note that the actual placement of the num_layers dimensions
        in the unified layers tensors will be determined by the attention
        backend.
        Thus, the layers KV data may still not be contiguous per block
        if the attention backend does not support it.

        Args:
            attn_groups: The list of attention groups for this model
            cache_dtype: The KV cache dtype
        Returns:
            True if we should use a uniform KV cache layout.
        """

        if not has_kv_transfer_group():
            return False
        if not get_kv_transfer_group().prefer_cross_layer_blocks:
            return False

        if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
            return False

        attn_group = attn_groups[0][0]
        kv_cache_spec = attn_group.kv_cache_spec
        if not isinstance(kv_cache_spec, AttentionSpec):
            return False

        attn_backend = attn_group.backend
        kv_cache_shape = attn_backend.get_kv_cache_shape(
            1234,
            kv_cache_spec.block_size,
            kv_cache_spec.num_kv_heads,
            kv_cache_spec.head_size,
            cache_dtype_str=cache_dtype,
        )

        try:
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
                include_num_layers_dimension=True
            )
        except (AttributeError, NotImplementedError):
            return False

        # check that the attention backend includes a layers dimension
        return len(kv_cache_stride_order) == len(kv_cache_shape) + 1

    @staticmethod
    def allocate_uniform_kv_caches(
        kv_cache_config: KVCacheConfig,
        attn_groups: list[list[AttentionGroup]],
        cache_dtype: CacheDType,
        device: torch.device,
        kernel_block_sizes: list[int],
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
        """
        Initializes and reshapes KV caches for the simple case where all
        layers have the same layout.

        This function assumes use_uniform_kv_cache() returned True.

        Args:
            kv_cache_config: The KV cache config
            attn_groups: The list of attention groups for this model
            cache_dtype: The KV cache dtype
            device: The torch device to allocate on.
            kernel_block_sizes: The kernel block sizes for each KV cache group.
        Returns:
            A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
                kv_caches is a dict mapping between layer names to their
                    corresponding memory buffer for KV cache.
                cross_layers_kv_cache is the cross layers kv cache tensor
                attn_backend is the attention backend matching this tensor
        """
        attn_group = attn_groups[0][0]
        kv_cache_spec = attn_group.kv_cache_spec
        assert isinstance(kv_cache_spec, AttentionSpec)

        tensor_sizes = set(
            kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
        )
        assert len(tensor_sizes) == 1
        tensor_size = tensor_sizes.pop()

        page_size = kv_cache_spec.page_size_bytes
        assert tensor_size % page_size == 0
        num_blocks = tensor_size // page_size
        num_layers = len(kv_cache_config.kv_cache_tensors)
        total_size = tensor_size * num_layers

        assert len(kernel_block_sizes) == 1
        kernel_block_size = kernel_block_sizes[0]
        num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
        kernel_num_blocks = num_blocks * num_blocks_per_kv_block

        attn_backend = attn_group.backend
        kv_cache_shape = attn_backend.get_kv_cache_shape(
            kernel_num_blocks,
            kernel_block_size,
            kv_cache_spec.num_kv_heads,
            kv_cache_spec.head_size,
            cache_dtype_str=cache_dtype,
        )

        # prepend a num_layers dimension into the shape
        kv_cache_shape = (num_layers,) + kv_cache_shape

        try:
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
                include_num_layers_dimension=True
            )
            assert len(kv_cache_stride_order) == len(kv_cache_shape)
        except (AttributeError, NotImplementedError):
            kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

        kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

        logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)

        # allocate one contiguous buffer for all layers
        cross_layers_kv_cache = (
            torch.zeros(total_size, dtype=torch.int8, device=device)
            .view(kv_cache_spec.dtype)
            .view(kv_cache_shape)
        )

        # Maintain original KV shape view.
        inv_order = [
            kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
        ]
        permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)

        kv_caches = {}
        for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
            tensor = permuted_kv_cache[i]
            for layer_name in kv_cache_tensor.shared_by:
                kv_caches[layer_name] = tensor

        return kv_caches, cross_layers_kv_cache, attn_backend
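
Usage sketch (not part of the module): the following shows how a model runner might combine these hooks in its execute_model step. The MyModelRunner class and the _prepare_inputs, _model_forward, and _make_model_runner_output helpers are hypothetical, and import paths may vary between vLLM versions; only the mixin calls themselves come from this file.

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin,
)


class MyModelRunner(KVConnectorModelRunnerMixin):
    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config

    def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
        if scheduler_output.total_num_scheduled_tokens == 0:
            # Nothing to run this step, but KV sends/receives may still be
            # pending, so let the connector work without a forward pass.
            return self.kv_connector_no_forward(scheduler_output, self.vllm_config)

        attn_metadata = self._prepare_inputs(scheduler_output)  # hypothetical helper
        with (
            set_forward_context(attn_metadata, self.vllm_config),
            self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
        ):
            hidden_states = self._model_forward(scheduler_output)  # hypothetical helper

        output = self._make_model_runner_output(hidden_states)  # hypothetical helper
        # kv_connector_output is None when no KV transfer group is configured.
        output.kv_connector_output = kv_connector_output
        return output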

_get_kv_connector_output staticmethod

_get_kv_connector_output(
    scheduler_output: SchedulerOutput,
    wait_for_save: bool = True,
) -> Generator[KVConnectorOutput, None, None]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
@contextmanager
def _get_kv_connector_output(
    scheduler_output: "SchedulerOutput", wait_for_save: bool = True
) -> Generator[KVConnectorOutput, None, None]:
    output = KVConnectorOutput()

    # Update the KVConnector with the KVConnector metadata before forward().
    kv_connector = get_kv_transfer_group()
    assert isinstance(kv_connector, KVConnectorBase)
    assert scheduler_output.kv_connector_metadata is not None
    kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

    # Background KV cache transfers happen here.
    # These transfers are designed to be async and the requests
    # involved may be disjoint from the running requests.
    # Do this here to save a collective_rpc.
    kv_connector.start_load_kv(get_forward_context())
    try:
        yield output
    finally:
        if wait_for_save:
            kv_connector.wait_for_save()

        output.finished_sending, output.finished_recving = (
            kv_connector.get_finished(scheduler_output.finished_req_ids)
        )
        output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()

        output.kv_connector_stats = (
            KVConnectorModelRunnerMixin.get_kv_connector_stats()
        )
        kv_connector.clear_connector_metadata()

allocate_uniform_kv_caches staticmethod

allocate_uniform_kv_caches(
    kv_cache_config: KVCacheConfig,
    attn_groups: list[list[AttentionGroup]],
    cache_dtype: CacheDType,
    device: device,
    kernel_block_sizes: list[int],
) -> tuple[
    dict[str, Tensor], Tensor, type[AttentionBackend]
]

Initializes and reshapes KV caches for the simple case where all layers have the same layout.

This function assumes use_uniform_kv_cache() returned True.

Parameters:

    kv_cache_config (KVCacheConfig): The KV cache config. Required.
    attn_groups (list[list[AttentionGroup]]): The list of attention groups for this model. Required.
    cache_dtype (CacheDType): The KV cache dtype. Required.
    device (torch.device): The torch device to allocate on. Required.
    kernel_block_sizes (list[int]): The kernel block sizes for each KV cache group. Required.

Returns:

    A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
        kv_caches is a dict mapping layer names to their corresponding KV cache memory buffers.
        cross_layers_kv_cache is the cross-layer KV cache tensor.
        attn_backend is the attention backend matching this tensor.

Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def allocate_uniform_kv_caches(
    kv_cache_config: KVCacheConfig,
    attn_groups: list[list[AttentionGroup]],
    cache_dtype: CacheDType,
    device: torch.device,
    kernel_block_sizes: list[int],
) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
    """
    Initializes and reshapes KV caches for the simple case where all
    layers have the same layout.

    This function assumes use_uniform_kv_cache() returned True.

    Args:
        kv_cache_config: The KV cache config
        attn_groups: The list of attention groups for this model
        cache_dtype: The KV cache dtype
        device: The torch device to allocate on.
        kernel_block_sizes: The kernel block sizes for each KV cache group.
    Returns:
        A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
            kv_caches is a dict mapping between layer names to their
                corresponding memory buffer for KV cache.
            cross_layers_kv_cache is the cross layers kv cache tensor
            attn_backend is the attention backend matching this tensor
    """
    attn_group = attn_groups[0][0]
    kv_cache_spec = attn_group.kv_cache_spec
    assert isinstance(kv_cache_spec, AttentionSpec)

    tensor_sizes = set(
        kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
    )
    assert len(tensor_sizes) == 1
    tensor_size = tensor_sizes.pop()

    page_size = kv_cache_spec.page_size_bytes
    assert tensor_size % page_size == 0
    num_blocks = tensor_size // page_size
    num_layers = len(kv_cache_config.kv_cache_tensors)
    total_size = tensor_size * num_layers

    assert len(kernel_block_sizes) == 1
    kernel_block_size = kernel_block_sizes[0]
    num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
    kernel_num_blocks = num_blocks * num_blocks_per_kv_block

    attn_backend = attn_group.backend
    kv_cache_shape = attn_backend.get_kv_cache_shape(
        kernel_num_blocks,
        kernel_block_size,
        kv_cache_spec.num_kv_heads,
        kv_cache_spec.head_size,
        cache_dtype_str=cache_dtype,
    )

    # prepend a num_layers dimension into the shape
    kv_cache_shape = (num_layers,) + kv_cache_shape

    try:
        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
            include_num_layers_dimension=True
        )
        assert len(kv_cache_stride_order) == len(kv_cache_shape)
    except (AttributeError, NotImplementedError):
        kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

    kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

    logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)

    # allocate one contiguous buffer for all layers
    cross_layers_kv_cache = (
        torch.zeros(total_size, dtype=torch.int8, device=device)
        .view(kv_cache_spec.dtype)
        .view(kv_cache_shape)
    )

    # Maintain original KV shape view.
    inv_order = [
        kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
    ]
    permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)

    kv_caches = {}
    for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
        tensor = permuted_kv_cache[i]
        for layer_name in kv_cache_tensor.shared_by:
            kv_caches[layer_name] = tensor

    return kv_caches, cross_layers_kv_cache, attn_backend
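
To see why the per-layer entries in kv_caches alias the single cross-layer buffer, here is a small standalone illustration of the permute/inverse-permute step above. The shapes and the stride order are made up for the example and are not vLLM's actual layout; only torch is required.

# Standalone illustration (assumed shapes): the physical buffer is laid out in
# stride order, and the inverse permutation restores the logical
# (num_layers, ...) view so that per-layer slices alias the same memory.
import torch

num_layers, num_blocks, block_size, num_heads, head_size = 4, 8, 16, 2, 64
logical_shape = (num_layers, 2, num_blocks, block_size, num_heads, head_size)
# Assumed stride order placing num_blocks outermost, then num_layers.
stride_order = (2, 0, 1, 3, 4, 5)

physical_shape = tuple(logical_shape[i] for i in stride_order)
buffer = torch.zeros(physical_shape, dtype=torch.bfloat16)

# Invert the permutation to recover the logical view.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
logical_view = buffer.permute(*inv_order)
assert logical_view.shape == logical_shape

# Per-layer views alias the shared buffer: writing through one is visible
# through the other.
layer0 = logical_view[0]
layer0.fill_(1.0)
assert buffer.select(1, 0).eq(1.0).all()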

ensure_kv_transfer_shutdown staticmethod

ensure_kv_transfer_shutdown() -> None
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
    # has_kv_transfer_group can be None during interpreter shutdown.
    if has_kv_transfer_group and has_kv_transfer_group():  # type: ignore[truthy-function]
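        # Note: this resolves to the module-level ensure_kv_transfer_shutdown()
        # imported by this file, not to this static method, since class scope is
        # not searched from within method bodies.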
        ensure_kv_transfer_shutdown()

get_finished_kv_transfers staticmethod

get_finished_kv_transfers(
    scheduler_output: SchedulerOutput,
) -> tuple[set[str] | None, set[str] | None]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def get_finished_kv_transfers(
    scheduler_output: "SchedulerOutput",
) -> tuple[set[str] | None, set[str] | None]:
    if has_kv_transfer_group():
        return get_kv_transfer_group().get_finished(
            scheduler_output.finished_req_ids
        )
    return None, None

get_kv_connector_stats staticmethod

get_kv_connector_stats() -> KVConnectorStats | None
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def get_kv_connector_stats() -> KVConnectorStats | None:
    if has_kv_transfer_group():
        return get_kv_transfer_group().get_kv_connector_stats()
    return None

kv_connector_no_forward staticmethod

kv_connector_no_forward(
    scheduler_output: SchedulerOutput,
    vllm_config: VllmConfig,
) -> ModelRunnerOutput
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def kv_connector_no_forward(
    scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
) -> ModelRunnerOutput:
    # KV send/recv even if no work to do.
    with (
        set_forward_context(None, vllm_config),
        KVConnectorModelRunnerMixin._get_kv_connector_output(
            scheduler_output, wait_for_save=False
        ) as kv_connector_output,
    ):
        pass

    if kv_connector_output.is_empty():
        return EMPTY_MODEL_RUNNER_OUTPUT

    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = kv_connector_output
    return output

maybe_get_kv_connector_output staticmethod

maybe_get_kv_connector_output(
    scheduler_output: SchedulerOutput,
) -> AbstractContextManager[KVConnectorOutput | None]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def maybe_get_kv_connector_output(
    scheduler_output: "SchedulerOutput",
) -> AbstractContextManager[KVConnectorOutput | None]:
    return (
        KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
        if has_kv_transfer_group()
        else nullcontext()
    )

maybe_setup_kv_connector staticmethod

maybe_setup_kv_connector(scheduler_output: SchedulerOutput)
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
    # Update the KVConnector with the KVConnector metadata before forward().
    if has_kv_transfer_group():
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase)
        assert scheduler_output.kv_connector_metadata is not None
        kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        kv_connector.start_load_kv(get_forward_context())

maybe_wait_for_kv_save staticmethod

maybe_wait_for_kv_save() -> None
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def maybe_wait_for_kv_save() -> None:
    if has_kv_transfer_group():
        get_kv_transfer_group().wait_for_save()

use_uniform_kv_cache staticmethod

use_uniform_kv_cache(
    attn_groups: list[list[AttentionGroup]],
    cache_dtype: CacheDType,
) -> bool

Determines whether a uniform KV layout should be used. A uniform layout means all layers' KV caches will share the same underlying tensor, where for a given block number, the respective KV data for all layers will be contiguous. This allows efficient KV transfer of per-block KV data for all layers at once. Note this layout will only be applied given 3 conditions:

1. The KV cache config contains just a single group where all layers have the same page size.
2. A KV connector is configured, and the KV connector instance prefers to use this layout (prefer_cross_layer_blocks() returns True).
3. The flash attention backend supports this layout (get_kv_cache_stride_order(True) includes a placement for a num_layers dimension).

Note that the actual placement of the num_layers dimension in the unified layers tensor is determined by the attention backend, so the layers' KV data may still not be contiguous per block if the backend does not support it.

Parameters:

    attn_groups (list[list[AttentionGroup]]): The list of attention groups for this model. Required.
    cache_dtype (CacheDType): The KV cache dtype. Required.

Returns:

    True if we should use a uniform KV cache layout.

Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def use_uniform_kv_cache(
    attn_groups: list[list[AttentionGroup]],
    cache_dtype: CacheDType,
) -> bool:
    """
    Determines whether a uniform KV layout should be used.
    A uniform layout means all layers KV caches will share the same
    underlying tensor, where for a given block number, the respective
    KV data for all layers will be contiguous.
    This will allow efficient KV transfer of per-block KV data for all
    layers at once.
    Note this layout will only be applied given 3 conditions:
    1. The KV Cache config contains just a single group where all layers
        have the same page size.
    2. A KV connector is configured, and the KV connector instance prefers
        to use this layout (prefer_cross_layer_blocks() returns True)
    3. The flash attention backend supports this layout
        (get_kv_cache_stride_order(True) includes a placement for a
        num_layers dimension)

    Note that the actual placement of the num_layers dimensions
    in the unified layers tensors will be determined by the attention
    backend.
    Thus, the layers KV data may still not be contiguous per block
    if the attention backend does not support it.

    Args:
        attn_groups: The list of attention groups for this model
        cache_dtype: The KV cache dtype
    Returns:
        True if we should use a uniform KV cache layout.
    """

    if not has_kv_transfer_group():
        return False
    if not get_kv_transfer_group().prefer_cross_layer_blocks:
        return False

    if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
        return False

    attn_group = attn_groups[0][0]
    kv_cache_spec = attn_group.kv_cache_spec
    if not isinstance(kv_cache_spec, AttentionSpec):
        return False

    attn_backend = attn_group.backend
    kv_cache_shape = attn_backend.get_kv_cache_shape(
        1234,
        kv_cache_spec.block_size,
        kv_cache_spec.num_kv_heads,
        kv_cache_spec.head_size,
        cache_dtype_str=cache_dtype,
    )

    try:
        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
            include_num_layers_dimension=True
        )
    except (AttributeError, NotImplementedError):
        return False

    # check that the attention backend includes a layers dimension
    return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
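
Taken together, a worker would typically gate the cross-layer allocation on this check. A minimal sketch, assuming kv_cache_config, attn_groups, cache_dtype, device, and kernel_block_sizes have already been built elsewhere in the worker; the fallback branch is only indicative:

# Hypothetical wiring: check first, then allocate the cross-layer buffer.
if KVConnectorModelRunnerMixin.use_uniform_kv_cache(attn_groups, cache_dtype):
    kv_caches, cross_layers_kv_cache, attn_backend = (
        KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
            kv_cache_config=kv_cache_config,
            attn_groups=attn_groups,
            cache_dtype=cache_dtype,
            device=device,
            kernel_block_sizes=kernel_block_sizes,
        )
    )
    # Every per-layer entry in kv_caches is a view into cross_layers_kv_cache,
    # so a block's KV data for all layers can be transferred in one shot.
else:
    ...  # fall back to the regular per-layer KV cache allocation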