44 changes: 39 additions & 5 deletions lmdeploy/serve/processors/multimodal.py
@@ -1,12 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from functools import partial
from typing import Any, Literal

import PIL

from lmdeploy.model import MODELS, BaseChatTemplate
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger
from lmdeploy.utils import await_executor_future, get_logger
from lmdeploy.vl.constants import Modality
from lmdeploy.vl.media.connection import load_from_url
from lmdeploy.vl.media.image import ImageMediaIO
@@ -37,6 +38,8 @@ def __init__(self,
self.chat_template = chat_template
self.vl_encoder = vl_encoder
self.backend = backend
# Gate CPU-heavy prompt prep so waiters yield to the server loop.
self.prompt_lock = asyncio.Lock()

@staticmethod
def merge_message_content(msg: dict) -> dict:
@@ -343,6 +346,31 @@ async def _get_text_prompt_input(self,
chat_template_kwargs: dict | None = None,
**kwargs):
"""Process text-only prompt and return prompt string and input_ids."""
loop = asyncio.get_event_loop()
async with self.prompt_lock:
future = loop.run_in_executor(
None,
partial(self._get_text_prompt_input_sync,
prompt=prompt,
do_preprocess=do_preprocess,
sequence_start=sequence_start,
adapter_name=adapter_name,
tools=tools,
reasoning_effort=reasoning_effort,
chat_template_kwargs=chat_template_kwargs,
**kwargs))
return await await_executor_future(future)

def _get_text_prompt_input_sync(self,
prompt: str | list[dict],
do_preprocess: bool,
sequence_start: bool,
adapter_name: str,
tools: list[object] | None = None,
reasoning_effort: Literal['low', 'medium', 'high'] | None = None,
chat_template_kwargs: dict | None = None,
**kwargs):
"""Render and tokenize a text prompt."""
        # Convert multimodal content to OpenAI-style text messages
if isinstance(prompt, list):
prompt = [self.merge_message_content(msg) for msg in prompt]
@@ -392,10 +420,16 @@ async def _get_multimodal_prompt_input(self,
chat_template_kwargs=chat_template_kwargs)
elif self.backend == 'pytorch':
if self.vl_encoder._uses_new_preprocess:
input_prompt = self.vl_encoder.model.get_input_prompt(messages=messages,
chat_template=chat_template,
sequence_start=sequence_start,
chat_template_kwargs=chat_template_kwargs)
loop = asyncio.get_event_loop()
async with self.prompt_lock:
future = loop.run_in_executor(
None,
partial(self.vl_encoder.model.get_input_prompt,
messages=messages,
chat_template=chat_template,
sequence_start=sequence_start,
chat_template_kwargs=chat_template_kwargs))
input_prompt = await await_executor_future(future)
results = await self.vl_encoder.preprocess(messages, input_prompt, mm_processor_kwargs)
else:
results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs)
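The multimodal.py changes repeat one pattern: push the CPU-heavy render/tokenize step onto a worker thread, serialize entry with prompt_lock, and shield the await so a cancelled request cannot release the lock while the thread is still busy. Below is a minimal, self-contained sketch of that pattern; render_prompt and PromptWorker are illustrative stand-ins, not lmdeploy APIs.

import asyncio
from functools import partial


def render_prompt(text: str) -> list[int]:
    """Stand-in for chat-template rendering plus tokenization."""
    return [ord(ch) for ch in text]


class PromptWorker:

    def __init__(self):
        self.lock = asyncio.Lock()

    async def run(self, text: str) -> list[int]:
        loop = asyncio.get_event_loop()
        async with self.lock:
            future = loop.run_in_executor(None, partial(render_prompt, text))
            try:
                # Shield: cancelling this caller must not cancel the thread job.
                return await asyncio.shield(future)
            except asyncio.CancelledError:
                # Keep the lock until the thread job drains, then re-raise.
                try:
                    await future
                except BaseException:
                    pass
                raise


async def main():
    worker = PromptWorker()
    print(await worker.run('hello'))


asyncio.run(main())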
12 changes: 12 additions & 0 deletions lmdeploy/utils.py
@@ -15,6 +15,18 @@
logger_initialized = {}


async def await_executor_future(future: asyncio.Future):
"""Await executor work without releasing a lock before it finishes."""
try:
return await asyncio.shield(future)
except asyncio.CancelledError:
try:
await future
except BaseException:
pass
raise


class _ASNI_COLOR:
BRIGHT_RED = '\033[91m'
RED = '\033[31m'
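The helper's contract is easiest to see in a standalone demo. This hedged example inlines the same shield-then-drain logic rather than importing lmdeploy: cancelling the awaiting task does not complete it until the executor future resolves, and only then does the CancelledError surface.

import asyncio


async def shielded_await(future):
    # Same shield-then-drain logic as await_executor_future above.
    try:
        return await asyncio.shield(future)
    except asyncio.CancelledError:
        try:
            await future
        except BaseException:
            pass
        raise


async def main():
    loop = asyncio.get_event_loop()
    pending = loop.create_future()
    task = asyncio.create_task(shielded_await(pending))
    await asyncio.sleep(0)

    task.cancel()
    await asyncio.sleep(0)
    assert not task.done()   # cancelled, yet still draining the future

    pending.set_result(42)   # the "executor work" finishes here
    try:
        await task
    except asyncio.CancelledError:
        print('cancellation surfaced only after the work drained')


asyncio.run(main())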
72 changes: 44 additions & 28 deletions lmdeploy/vl/engine.py
@@ -3,12 +3,13 @@
import asyncio
import inspect
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any

import torch

from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig
from lmdeploy.utils import get_logger
from lmdeploy.utils import await_executor_future, get_logger
from lmdeploy.vl.model.builder import load_vl_model

logger = get_logger('lmdeploy')
@@ -44,6 +45,8 @@ def __init__(
self.vision_config = vision_config
self.max_batch_size = vision_config.max_batch_size
self.executor = ThreadPoolExecutor(max_workers=1)
        # Gate VL executor submissions so waiters wait on the loop, where they can be cancelled, instead of queueing thread jobs.
self.executor_lock = asyncio.Lock()
self._uses_new_preprocess = self._is_new_preprocess_api(self.model)
torch.cuda.empty_cache()

@@ -61,14 +64,14 @@ async def preprocess(self,
input_prompt: str | list[int] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
"""Preprocess multimodal data in the messages."""
if self._uses_new_preprocess:
future = asyncio.get_event_loop().run_in_executor(
self.executor, self.model.preprocess, messages, input_prompt, mm_processor_kwargs)
else:
future = asyncio.get_event_loop().run_in_executor(
self.executor, self.model.preprocess, messages)
future.add_done_callback(_raise_exception_on_finish)
outputs = await future
async with self.executor_lock:
if self._uses_new_preprocess:
future = asyncio.get_event_loop().run_in_executor(
self.executor, self.model.preprocess, messages, input_prompt, mm_processor_kwargs)
else:
future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages)
future.add_done_callback(_raise_exception_on_finish)
outputs = await await_executor_future(future)
return outputs

async def async_infer(self, messages: list[dict]) -> list[dict]:
@@ -78,10 +81,11 @@ async def async_infer(self, messages: list[dict]) -> list[dict]:
messages (list[dict]): a list of message, which is the output
of `preprocess()`
"""
future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.forward, messages,
self.max_batch_size)
future.add_done_callback(_raise_exception_on_finish)
outputs = await future
async with self.executor_lock:
future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.forward, messages,
self.max_batch_size)
future.add_done_callback(_raise_exception_on_finish)
outputs = await await_executor_future(future)
return outputs

async def wrap_for_pytorch(
Expand Down Expand Up @@ -112,15 +116,21 @@ async def wrap_for_pytorch(
}
"""
has_input_ids = self.model.has_input_ids(messages)
if not has_input_ids:
result = self.model.to_pytorch(messages,
chat_template,
tokenizer,
sequence_start,
tools=tools,
chat_template_kwargs=chat_template_kwargs)
else:
result = self.model.to_pytorch_with_input_ids(messages)
loop = asyncio.get_event_loop()
async with self.executor_lock:
if not has_input_ids:
future = loop.run_in_executor(
self.executor,
partial(self.model.to_pytorch,
messages,
chat_template,
tokenizer,
sequence_start,
tools=tools,
chat_template_kwargs=chat_template_kwargs))
else:
future = loop.run_in_executor(self.executor, self.model.to_pytorch_with_input_ids, messages)
result = await await_executor_future(future)
# clear data
for i, message in enumerate(messages):
if isinstance(message['content'], list):
@@ -153,12 +163,18 @@ async def wrap_for_turbomind(
...
}
"""
result = self.model.to_turbomind(messages,
chat_template,
tokenizer,
sequence_start,
tools=tools,
chat_template_kwargs=chat_template_kwargs)
loop = asyncio.get_event_loop()
async with self.executor_lock:
future = loop.run_in_executor(
self.executor,
partial(self.model.to_turbomind,
messages,
chat_template,
tokenizer,
sequence_start,
tools=tools,
chat_template_kwargs=chat_template_kwargs))
result = await await_executor_future(future)
# clear data
for i, message in enumerate(messages):
if isinstance(message['content'], list):
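engine.py applies the same recipe to the single-worker vision executor. The asyncio.Lock matters even with max_workers=1: without it, a cancelled request would still leave its job queued inside the thread pool, whereas a waiter parked on the lock can be cancelled before anything is submitted. A rough sketch under those assumptions, with heavy_preprocess standing in for the blocking model.preprocess call:

import asyncio
from concurrent.futures import ThreadPoolExecutor


def heavy_preprocess(item: dict) -> dict:
    """Stand-in for the blocking model.preprocess call."""
    return {'processed': item}


async def submit(executor: ThreadPoolExecutor, lock: asyncio.Lock, item: dict) -> dict:
    loop = asyncio.get_event_loop()
    async with lock:
        future = loop.run_in_executor(executor, heavy_preprocess, item)
        # The real code awaits through await_executor_future so that a
        # cancelled caller drains the thread job before the lock opens.
        return await asyncio.shield(future)


async def main():
    executor = ThreadPoolExecutor(max_workers=1)
    lock = asyncio.Lock()
    jobs = [submit(executor, lock, {'id': i}) for i in range(3)]
    print(await asyncio.gather(*jobs))


asyncio.run(main())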
87 changes: 87 additions & 0 deletions tests/test_lmdeploy/test_executor_cancellation.py
@@ -0,0 +1,87 @@
import asyncio

import pytest

from lmdeploy.serve.processors import MultimodalProcessor
from lmdeploy.vl.engine import ImageEncoder


def test_prompt_lock_waits_for_executor_job_after_cancellation(monkeypatch):
"""Test cancelled prompt prep keeps the lock until executor work ends."""

async def run_case():
loop = asyncio.get_event_loop()
pending = loop.create_future()

class FakeChatTemplate:

def messages2prompt(self, *args, **kwargs):
return 'hello'

class FakeTokenizer:

def encode(self, *args, **kwargs):
return [1, 2, 3]

def fake_run_in_executor(*args, **kwargs):
return pending

monkeypatch.setattr(loop, 'run_in_executor', fake_run_in_executor)
processor = MultimodalProcessor(tokenizer=FakeTokenizer(), chat_template=FakeChatTemplate())

task = asyncio.create_task(
processor._get_text_prompt_input('hello',
do_preprocess=True,
sequence_start=True,
adapter_name=None))
await asyncio.sleep(0)
assert processor.prompt_lock.locked()

task.cancel()
await asyncio.sleep(0)
assert processor.prompt_lock.locked()

pending.set_result({'prompt': 'hello', 'input_ids': [1, 2, 3]})
with pytest.raises(asyncio.CancelledError):
await task
assert not processor.prompt_lock.locked()

asyncio.run(run_case())


def test_image_encoder_lock_waits_for_executor_job_after_cancellation(monkeypatch):
"""Test cancelled VL preprocess keeps the lock until executor work ends."""

async def run_case():
loop = asyncio.get_event_loop()
pending = loop.create_future()

class FakeModel:

def preprocess(self, messages):
return messages

def fake_run_in_executor(*args, **kwargs):
return pending

monkeypatch.setattr(loop, 'run_in_executor', fake_run_in_executor)
encoder = ImageEncoder.__new__(ImageEncoder)
encoder.model = FakeModel()
encoder.executor = None
encoder.executor_lock = asyncio.Lock()
encoder._uses_new_preprocess = False

task = asyncio.create_task(encoder.preprocess([{'content': 'hello'}]))
await asyncio.sleep(0)
assert encoder.executor_lock.locked()

task.cancel()
await asyncio.sleep(0)
assert encoder.executor_lock.locked()

pending.set_result([{'content': 'hello'}])
with pytest.raises(asyncio.CancelledError):
await task
assert not encoder.executor_lock.locked()

asyncio.run(run_case())
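
Both tests follow the same recipe: monkeypatch loop.run_in_executor to return a future that never resolves on its own, cancel the consuming task, and assert the lock stays held until the future is resolved by hand. Assuming a standard checkout with pytest installed, they should run without a GPU or model weights: pytest tests/test_lmdeploy/test_executor_cancellation.py -v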