39 changes: 35 additions & 4 deletions lmdeploy/serve/processors/multimodal.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
+from functools import partial
 from typing import Any, Literal
 
 import PIL
@@ -37,6 +38,7 @@ def __init__(self,
         self.chat_template = chat_template
         self.vl_encoder = vl_encoder
         self.backend = backend
+        self.prompt_lock = asyncio.Lock()
 
     @staticmethod
     def merge_message_content(msg: dict) -> dict:
@@ -343,6 +345,30 @@ async def _get_text_prompt_input(self,
                                      chat_template_kwargs: dict | None = None,
                                      **kwargs):
         """Process text-only prompt and return prompt string and input_ids."""
+        loop = asyncio.get_event_loop()
+        async with self.prompt_lock:
+            return await loop.run_in_executor(
+                None,
+                partial(self._get_text_prompt_input_sync,
+                        prompt=prompt,
+                        do_preprocess=do_preprocess,
+                        sequence_start=sequence_start,
+                        adapter_name=adapter_name,
+                        tools=tools,
+                        reasoning_effort=reasoning_effort,
+                        chat_template_kwargs=chat_template_kwargs,
+                        **kwargs))
+
+    def _get_text_prompt_input_sync(self,
+                                    prompt: str | list[dict],
+                                    do_preprocess: bool,
+                                    sequence_start: bool,
+                                    adapter_name: str,
+                                    tools: list[object] | None = None,
+                                    reasoning_effort: Literal['low', 'medium', 'high'] | None = None,
+                                    chat_template_kwargs: dict | None = None,
+                                    **kwargs):
+        """Render and tokenize a text prompt."""
         # Change multimodal data to openai text messages
         if isinstance(prompt, list):
             prompt = [self.merge_message_content(msg) for msg in prompt]
@@ -392,10 +418,15 @@ async def _get_multimodal_prompt_input(self,
                                                    chat_template_kwargs=chat_template_kwargs)
         elif self.backend == 'pytorch':
             if self.vl_encoder._uses_new_preprocess:
-                input_prompt = self.vl_encoder.model.get_input_prompt(messages=messages,
-                                                                      chat_template=chat_template,
-                                                                      sequence_start=sequence_start,
-                                                                      chat_template_kwargs=chat_template_kwargs)
+                loop = asyncio.get_event_loop()
+                async with self.prompt_lock:
+                    input_prompt = await loop.run_in_executor(
+                        None,
+                        partial(self.vl_encoder.model.get_input_prompt,
+                                messages=messages,
+                                chat_template=chat_template,
+                                sequence_start=sequence_start,
+                                chat_template_kwargs=chat_template_kwargs))
                 results = await self.vl_encoder.preprocess(messages, input_prompt, mm_processor_kwargs)
             else:
                 results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs)
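These hunks move prompt rendering off the event loop: callers serialize on `prompt_lock`, then hand the blocking template/tokenizer work to a pool thread via `run_in_executor`. Below is a minimal, self-contained sketch of that pattern; `PromptRenderer` and `render_sync` are hypothetical stand-ins for illustration, not lmdeploy APIs.

```python
import asyncio
from functools import partial


class PromptRenderer:
    """Minimal sketch of the lock-plus-executor pattern, assuming a
    blocking render step; `render_sync` is a hypothetical placeholder."""

    def __init__(self):
        # One lock per processor: concurrent requests queue here instead
        # of interleaving inside thread-unsafe tokenizer state.
        self.prompt_lock = asyncio.Lock()

    def render_sync(self, prompt: str) -> str:
        # Placeholder for the CPU-bound, blocking rendering work.
        return prompt.strip().lower()

    async def render(self, prompt: str) -> str:
        loop = asyncio.get_event_loop()
        async with self.prompt_lock:
            # Offload the blocking call to the default thread pool so the
            # event loop keeps serving other coroutines meanwhile.
            return await loop.run_in_executor(None, partial(self.render_sync, prompt))


async def main():
    renderer = PromptRenderer()
    results = await asyncio.gather(*(renderer.render(p) for p in ("  A ", " B")))
    print(results)  # ['a', 'b']


asyncio.run(main())
```

Because the lock is held across the await, at most one render runs at a time even though the work happens on a pool thread; other coroutines still make progress while they wait.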
67 changes: 40 additions & 27 deletions lmdeploy/vl/engine.py
@@ -3,6 +3,7 @@
 import asyncio
 import inspect
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from typing import Any
 
 import torch
@@ -44,6 +45,7 @@ def __init__(
         self.vision_config = vision_config
         self.max_batch_size = vision_config.max_batch_size
         self.executor = ThreadPoolExecutor(max_workers=1)
+        self.executor_lock = asyncio.Lock()
         self._uses_new_preprocess = self._is_new_preprocess_api(self.model)
         torch.cuda.empty_cache()
 
@@ -61,14 +63,14 @@ async def preprocess(self,
                          input_prompt: str | list[int] | None = None,
                          mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
         """Preprocess multimodal data in the messages."""
-        if self._uses_new_preprocess:
-            future = asyncio.get_event_loop().run_in_executor(
-                self.executor, self.model.preprocess, messages, input_prompt, mm_processor_kwargs)
-        else:
-            future = asyncio.get_event_loop().run_in_executor(
-                self.executor, self.model.preprocess, messages)
-        future.add_done_callback(_raise_exception_on_finish)
-        outputs = await future
+        async with self.executor_lock:
+            if self._uses_new_preprocess:
+                future = asyncio.get_event_loop().run_in_executor(
+                    self.executor, self.model.preprocess, messages, input_prompt, mm_processor_kwargs)
+            else:
+                future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages)
+            future.add_done_callback(_raise_exception_on_finish)
+            outputs = await future
         return outputs
 
     async def async_infer(self, messages: list[dict]) -> list[dict]:
@@ -78,10 +80,11 @@ async def async_infer(self, messages: list[dict]) -> list[dict]:
             messages (list[dict]): a list of messages, which is the output
                 of `preprocess()`
         """
-        future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.forward, messages,
-                                                          self.max_batch_size)
-        future.add_done_callback(_raise_exception_on_finish)
-        outputs = await future
+        async with self.executor_lock:
+            future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.forward, messages,
+                                                              self.max_batch_size)
+            future.add_done_callback(_raise_exception_on_finish)
+            outputs = await future
         return outputs
 
     async def wrap_for_pytorch(
@@ -112,15 +115,20 @@ async def wrap_for_pytorch(
         }
         """
         has_input_ids = self.model.has_input_ids(messages)
-        if not has_input_ids:
-            result = self.model.to_pytorch(messages,
-                                           chat_template,
-                                           tokenizer,
-                                           sequence_start,
-                                           tools=tools,
-                                           chat_template_kwargs=chat_template_kwargs)
-        else:
-            result = self.model.to_pytorch_with_input_ids(messages)
+        loop = asyncio.get_event_loop()
+        async with self.executor_lock:
+            if not has_input_ids:
+                result = await loop.run_in_executor(
+                    self.executor,
+                    partial(self.model.to_pytorch,
+                            messages,
+                            chat_template,
+                            tokenizer,
+                            sequence_start,
+                            tools=tools,
+                            chat_template_kwargs=chat_template_kwargs))
+            else:
+                result = await loop.run_in_executor(self.executor, self.model.to_pytorch_with_input_ids, messages)
         # clear data
         for i, message in enumerate(messages):
             if isinstance(message['content'], list):
@@ -153,12 +161,17 @@ async def wrap_for_turbomind(
             ...
         }
         """
-        result = self.model.to_turbomind(messages,
-                                         chat_template,
-                                         tokenizer,
-                                         sequence_start,
-                                         tools=tools,
-                                         chat_template_kwargs=chat_template_kwargs)
+        loop = asyncio.get_event_loop()
+        async with self.executor_lock:
+            result = await loop.run_in_executor(
+                self.executor,
+                partial(self.model.to_turbomind,
+                        messages,
+                        chat_template,
+                        tokenizer,
+                        sequence_start,
+                        tools=tools,
+                        chat_template_kwargs=chat_template_kwargs))
         # clear data
         for i, message in enumerate(messages):
             if isinstance(message['content'], list):
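The engine-side changes apply the same discipline: every call into the single-worker `ThreadPoolExecutor` is funneled through `executor_lock`, so submissions stay ordered and blocking model calls never run on the event-loop thread. A standalone sketch follows; `VisionRunner` and `forward_sync` are hypothetical placeholders for the encoder and `model.forward`, not lmdeploy classes.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial


class VisionRunner:
    """Sketch of the `executor_lock` pattern, assuming a blocking
    forward pass; `forward_sync` is a hypothetical placeholder."""

    def __init__(self):
        # A single worker thread runs all blocking model calls ...
        self.executor = ThreadPoolExecutor(max_workers=1)
        # ... and this lock serializes which coroutine may submit to it.
        self.executor_lock = asyncio.Lock()

    def forward_sync(self, batch: list[str], max_batch_size: int) -> list[str]:
        # Placeholder for a GPU-bound forward pass over the batch.
        return [item.upper() for item in batch[:max_batch_size]]

    async def async_infer(self, batch: list[str]) -> list[str]:
        loop = asyncio.get_event_loop()
        async with self.executor_lock:
            # Holding the lock across the await keeps submissions ordered;
            # run_in_executor keeps the blocking call off the event loop.
            return await loop.run_in_executor(
                self.executor, partial(self.forward_sync, batch, 4))


async def main():
    runner = VisionRunner()
    print(await runner.async_infer(["cat", "dog"]))  # ['CAT', 'DOG']


asyncio.run(main())
```

With `max_workers=1` the executor already runs jobs one at a time; the lock adds submission-order fairness and keeps a caller from piling new futures onto the queue while another request is mid-flight.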