From 654dde940df3d02d07d5ee7426add48300ad374a Mon Sep 17 00:00:00 2001
From: ZhijunLStudio <335022969@qq.com>
Date: Thu, 23 Apr 2026 14:51:38 +0800
Subject: [PATCH] feat: Implement /v1/embeddings endpoint for OpenAI-compatible API

Add support for the standard OpenAI embeddings endpoint, which extracts
last hidden states from the model and applies mean pooling. This enables
downstream tools (LangChain, LlamaIndex, RAG pipelines) to use lmdeploy
for text embedding generation.

Changes:
- Replace the stub /v1/embeddings route with a full implementation
  supporting the 'float' and 'base64' encoding formats
- Thread last_hidden_states through the PyTorch engine pipeline
  (BatchedOutputs -> InferOutput -> EngineOutput)
- Capture full-sequence hidden states before postprocessing slices to the
  last token, and mean pool per sequence in the engine
- Pass do_preprocess=False so the chat template is not applied
- Sum prompt_tokens across all inputs instead of overwriting
- Default model to async_engine.model_name when not provided
- Use little-endian layout for base64 encoding ('<' struct prefix)
- Handle finish_reason='error' frames from the engine
- Add unit tests for the /v1/embeddings endpoint
---
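Usage notes (placed below the '---' cut line, so `git am` drops them): a
minimal sketch of calling the new endpoint with `requests`, assuming a
locally running `lmdeploy serve api_server` instance. The host, port and
model name below are placeholders; the request/response fields (`input`,
`encoding_format`, `data[*].embedding`) are the ones this patch implements.

    import base64
    import struct

    import requests

    BASE = 'http://localhost:23333/v1'  # placeholder host/port

    # 'float' (default) encoding: each embedding is a plain list of floats
    resp = requests.post(f'{BASE}/embeddings',
                         json={'model': 'my-model',
                               'input': ['hello world', 'second text']})
    data = resp.json()['data']
    print(len(data), len(data[0]['embedding']))

    # 'base64' encoding: unpack with the same little-endian float32 layout
    # ('<' struct prefix) the server uses to pack the vector; 'model' may be
    # omitted, in which case the server falls back to its own model name
    resp = requests.post(f'{BASE}/embeddings',
                         json={'input': 'hello world',
                               'encoding_format': 'base64'})
    blob = base64.b64decode(resp.json()['data'][0]['embedding'])
    vector = struct.unpack(f'<{len(blob) // 4}f', blob)
    print(len(vector))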
 lmdeploy/pytorch/engine/engine.py            |   1 +
 lmdeploy/pytorch/engine/engine_instance.py   |   4 +-
 lmdeploy/pytorch/engine/engine_loop.py       |   5 +
 lmdeploy/pytorch/engine/inputs_maker.py      |   6 +
 lmdeploy/pytorch/engine/model_agent/agent.py |  27 +++
 lmdeploy/pytorch/messages.py                 |  12 +-
 lmdeploy/serve/openai/api_server.py          |  75 ++++++-
 lmdeploy/serve/openai/protocol.py            |   1 +
 tests/test_lmdeploy/serve/__init__.py        |   0
 tests/test_lmdeploy/serve/test_api_server.py | 206 +++++++++++++++++++
 10 files changed, 332 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_lmdeploy/serve/__init__.py
 create mode 100644 tests/test_lmdeploy/serve/test_api_server.py

diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index ed667fe7c2..cfc6d2d6e6 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -47,6 +47,7 @@ class InferOutput:
     finish: bool = False
     logits: torch.Tensor = None
     logprobs: torch.Tensor = None
+    last_hidden_states: torch.Tensor = None

     # send cache blocks back for migration in Disaggregated LLM Serving
     # when Prefill Engine is Done.
diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index c08595deaf..817a095a35 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -240,22 +240,24 @@ async def async_stream_infer(self,
             resp_data = resp.data
             token_ids = []
             logits = None
+            last_hidden_states = None
             if resp_data is not None:
                 # request might be cancelled before any output
                 logits = resp_data.get('logits', None)
                 gen_token_ids = resp_data.get('token_ids', None)
                 if gen_token_ids is not None:
                     token_ids = gen_token_ids[output_offset:].tolist()
+                last_hidden_states = resp_data.get('last_hidden_states', None)

             num_ids = len(token_ids)
             num_all_ids = prompt_ids_len + output_offset + num_ids
             extra_outputs = self._get_extra_outputs(resp, num_all_ids)
             routed_experts = extra_outputs.get('routed_experts', None)

-            logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
             yield EngineOutput(resp.type,
                                token_ids,
                                logits=logits,
+                               last_hidden_state=last_hidden_states,
                                cache_block_ids=cache_block_ids,
                                req_metrics=req_metrics,
                                routed_experts=routed_experts,
diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py
index 96367969db..fbea8a556b 100644
--- a/lmdeploy/pytorch/engine/engine_loop.py
+++ b/lmdeploy/pytorch/engine/engine_loop.py
@@ -207,6 +207,7 @@ def _send_resp(self, out: InferOutput):
                        resp_type,
                        data=dict(token_ids=out.token_ids,
                                  logits=out.logits,
+                                 last_hidden_states=out.last_hidden_states,
                                  cache_block_ids=out.cache_block_ids,
                                  req_metrics=out.req_metrics,
                                  routed_experts=out.routed_experts,
@@ -297,6 +298,7 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):

         logits = batched_outputs.logits
         all_routed_experts = batched_outputs.all_routed_experts
+        all_hidden_states = batched_outputs.last_hidden_states

         if model_inputs is not None and (model_inputs.is_chunk and not model_inputs.is_last_chunk):
             # chunk long context does not need to update seqs and outputs
@@ -363,6 +365,9 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
             if msg.return_logits:
                 logit = __get_logit(msg, logits, seq_length, idx)
                 outputs[session_id].logits = logit
+
+            if msg.return_last_hidden_states and all_hidden_states is not None:
+                outputs[session_id].last_hidden_states = all_hidden_states[idx]
         return outputs

     async def _main_loop_try_send_next_inputs(self):
diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py
index aa4fccab7e..31ff09bb39 100644
--- a/lmdeploy/pytorch/engine/inputs_maker.py
+++ b/lmdeploy/pytorch/engine/inputs_maker.py
@@ -629,6 +629,10 @@ def __need_routed_experts(seqs: 'SeqList'):
             """Need routed experts."""
             return any(seq.return_routed_experts for seq in seqs)

+        def __need_hidden_states(seqs: 'SeqList'):
+            """Need last hidden states."""
+            return any(seq.return_last_hidden_states for seq in seqs)
+
         def __create_model_inputs(seqs):
             """Createe model inputs."""
             inputs = self.create_model_inputs(seqs, True)
@@ -728,6 +732,7 @@ def __create_inputs_prefill():

         return_logits = __need_logits(running)
         return_routed_experts = __need_routed_experts(running)
+        return_last_hidden_states = __need_hidden_states(running)

         return dict(
             running=running,
@@ -740,6 +745,7 @@ def __create_inputs_prefill():
             return_logits=return_logits,
             extra_inputs=extra_inputs,
             return_routed_experts=return_routed_experts,
+            return_last_hidden_states=return_last_hidden_states,
         )

     def do_prefill_pnode(self):
diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index ce506e6b21..9e5e4fd152 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -82,6 +82,7 @@ class BatchedOutputs:
     new_token_timestamp: int = 0
     extra_outputs: ExtraOutputs | None = None
     all_routed_experts: torch.Tensor | None = None
+    last_hidden_states: torch.Tensor | None = None

     def to_cpu(self):
         """To cpu."""
@@ -440,11 +441,27 @@ async def _async_model_forward(
         self,
         inputs: ModelInputs,
         return_logits: bool,
+        return_last_hidden_states: bool = False,
     ):
         """Model forward."""
         origin_inputs = inputs
         ret = await self.async_forward(inputs)

+        # capture full hidden states before postprocessing slices to last token
+        full_hidden_states = None
+        if return_last_hidden_states:
+            raw_hidden = ret['hidden_states']
+            raw_seq_length = ret.get('seq_length', inputs.seq_length)
+            # raw_hidden shape: [1, total_tokens, hidden_dim] or [total_tokens, hidden_dim]
+            if raw_hidden.dim() == 3:
+                raw_hidden = raw_hidden[0]  # [total_tokens, hidden_dim]
+            # slice per-sequence and mean pool
+            if raw_seq_length.numel() == 1:
+                full_hidden_states = raw_hidden.mean(dim=0, keepdim=True)  # [1, hidden_dim]
+            else:
+                parts = raw_hidden.split(raw_seq_length.tolist(), dim=0)
+                full_hidden_states = torch.stack([p.mean(dim=0) for p in parts], dim=0)  # [bs, hidden_dim]
+
         if not return_logits:
             ret = self._postprocess_forward_output(ret, origin_inputs)

@@ -452,6 +469,8 @@ async def _async_model_forward(
         logits = self.get_logits(hidden_states)
         ret['logits'] = logits

+        ret['_hidden_states'] = hidden_states
+        ret['_full_hidden_states'] = full_hidden_states
         return ret

     async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs):
@@ -603,6 +622,7 @@ async def _step_postprocess_with_output(self,
                                             need_broadcast_next: bool,
                                             return_logits: bool = False,
                                             all_routed_experts: Any = None,
+                                            last_hidden_states: torch.Tensor = None,
                                             extra_inputs: ExtraInputs = None):
         """Step postprocess with output."""
         rank = self.rank
@@ -645,6 +665,7 @@ async def _step_postprocess_with_output(self,
                 model_metas=model_metas,
                 logprobs=logprobs,
                 all_routed_experts=all_routed_experts,
+                last_hidden_states=last_hidden_states,
                 extra_outputs=extra_outputs))

         return inputs, extra_inputs, stopping_criteria, extra_outputs, next_token_ids
@@ -679,6 +700,7 @@ async def _async_step(
         stopping_criteria: StoppingCriteria = None,
         return_logits: bool = False,
         return_routed_experts: bool = False,
+        return_last_hidden_states: bool = False,
         extra_inputs: ExtraInputs = None,
     ):
         """Asyc forward task."""
@@ -739,6 +761,7 @@ async def _async_step(
             output = await self._async_model_forward(
                 inputs,
                 return_logits=return_logits,
+                return_last_hidden_states=return_last_hidden_states,
             )
             # recovery is_decoding
             inputs.is_decoding = is_decoding
@@ -747,6 +770,9 @@ async def _async_step(
             # skip dummy forward output
             return

+        # get pre-pooled hidden states for embeddings
+        last_hidden_states = output.get('_full_hidden_states', None)
+
         logits = output['logits'][0]  # [bs, seq, prob] -> [seq, prob]
         seq_length = output.get('seq_length', inputs.seq_length)
         last_logits = self._slice_outs(logits, seq_length)  # [bs, 1, prob] -> [bs, prob]
@@ -778,6 +804,7 @@ async def _async_step(
                     need_broadcast_next,
                     return_logits=return_logits,
                     all_routed_experts=all_routed_experts,
+                    last_hidden_states=last_hidden_states,
                     extra_inputs=extra_inputs,
                 ))
         else:
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index e0c7bf77f5..cdd21fdd7f 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -60,6 +60,7 @@ class SamplingParam:
     logits_processors: None | list[LogitsProcessor] = None
     out_logits: bool = False
     out_last_hidden_states: bool = False
+    output_last_hidden_state: str = None
     num_logprobs: int = -1
     return_routed_experts: bool = False

@@ -92,8 +93,10 @@ def from_gen_config(cls, gen_config: GenerationConfig):
                 output_logits = None
                 logger.warning('Pytorch Engine only support output_logits="all"'
                                ' with max_new_tokens=0')
-        if gen_config.output_last_hidden_state is not None:
-            logger.warning('Pytorch Engine does not support output last hidden states.')
+        output_last_hidden_state = gen_config.output_last_hidden_state
+        if output_last_hidden_state and output_last_hidden_state != 'all':
+            logger.warning('Pytorch Engine only supports output_last_hidden_state="all"')
+            output_last_hidden_state = None
         if top_p < 0 or top_p > 1.0:
             logger.warning('`top_p` has to be a float > 0 and < 1'
                            f' but is {top_p}')
@@ -156,6 +159,7 @@ def from_gen_config(cls, gen_config: GenerationConfig):
             min_new_tokens=min_new_tokens,
             logits_processors=gen_config.logits_processors,
             out_logits=(output_logits is not None),
+            output_last_hidden_state=output_last_hidden_state,
             num_logprobs=logprobs,
             return_routed_experts=gen_config.return_routed_experts,
             repetition_ngram_size=repetition_ngram_size,
@@ -790,6 +794,10 @@ def status(self):
     def return_logits(self):
         return self.sampling_param.out_logits

+    @property
+    def return_last_hidden_states(self):
+        return self.sampling_param.output_last_hidden_state is not None
+
     @property
     def logits(self):
         """Get logits."""
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index e10f20f44e..0c16e3e5d0 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -3,9 +3,11 @@
 from __future__ import annotations

 import asyncio
+import base64
 import json
 import os
 import re
+import struct
 import time
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
@@ -64,6 +66,7 @@
     CompletionStreamResponse,
     DeltaMessage,
     EmbeddingsRequest,
+    EmbeddingsResponse,
     EncodeRequest,
     EncodeResponse,
     ErrorResponse,
@@ -987,10 +990,78 @@ async def _inner_call():
     return response


-@router.post('/v1/embeddings', tags=['unsupported'])
+@router.post('/v1/embeddings', dependencies=[Depends(validate_json_request)])
 async def create_embeddings(request: EmbeddingsRequest, raw_request: Request = None):
     """Creates embeddings for the text."""
-    return create_error_response(HTTPStatus.BAD_REQUEST, 'Unsupported by turbomind.')
+    if isinstance(request.input, str):
+        inputs = [request.input]
+    else:
+        inputs = request.input
+
+    if not inputs:
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'Input must not be empty.')
+
+    async_engine = VariableInterface.async_engine
+    embedding_data = []
+    total_prompt_tokens = 0
+    for idx, text in enumerate(inputs):
+        if not text:
+            return create_error_response(HTTPStatus.BAD_REQUEST, 'Input text must not be empty.')
+
+        session = VariableInterface.create_session(-1)
+        gen_config = GenerationConfig(
+            max_new_tokens=1,
+            output_last_hidden_state='all',
+        )
+        result_generator = async_engine.generate(
+            messages=text,
+            session_id=session.session_id,
+            gen_config=gen_config,
+            stream_response=True,
+            sequence_start=True,
+            sequence_end=True,
+            do_preprocess=False,
+        )
+
+        last_hidden_state = None
+        prompt_tokens = 0
+        async for res in result_generator:
+            if res.finish_reason == 'error':
+                return create_error_response(
+                    HTTPStatus.INTERNAL_SERVER_ERROR,
+                    getattr(res, 'text', 'Internal error during embedding generation.'),
+                )
+            if res.last_hidden_state is not None:
+                last_hidden_state = res.last_hidden_state
+            prompt_tokens = res.input_token_len
+
+        total_prompt_tokens += prompt_tokens
+
+        if last_hidden_state is None:
+            return create_error_response(
+                HTTPStatus.INTERNAL_SERVER_ERROR,
+                'Model does not support hidden states output for embeddings.',
+            )
+
+        # Convert to list (hidden states are already mean-pooled per sequence)
+        if last_hidden_state.dim() > 1:
+            # multi-token: mean pool across sequence dimension
+            emb_list = last_hidden_state.mean(dim=0).tolist()
+        else:
+            emb_list = last_hidden_state.tolist()
+
+        if request.encoding_format == 'base64':
+            packed = struct.pack(f'<{len(emb_list)}f', *emb_list)
+            encoded = base64.b64encode(packed).decode('utf-8')
+            embedding_data.append({'object': 'embedding', 'embedding': encoded, 'index': idx})
+        else:
+            embedding_data.append({'object': 'embedding', 'embedding': emb_list, 'index': idx})
+
+    return EmbeddingsResponse(
+        data=embedding_data,
+        model=request.model or async_engine.model_name,
+        usage=UsageInfo(prompt_tokens=total_prompt_tokens, total_tokens=total_prompt_tokens, completion_tokens=0),
+    )


 @router.post('/v1/encode', dependencies=[Depends(validate_json_request)])
diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py
index 14edd15239..11889475d0 100644
--- a/lmdeploy/serve/openai/protocol.py
+++ b/lmdeploy/serve/openai/protocol.py
@@ -370,6 +370,7 @@ class EmbeddingsRequest(BaseModel):
     """Embedding request."""
     model: str = None
     input: str | list[str]
+    encoding_format: Literal['float', 'base64'] = 'float'
     user: str | None = None


diff --git a/tests/test_lmdeploy/serve/__init__.py b/tests/test_lmdeploy/serve/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_lmdeploy/serve/test_api_server.py b/tests/test_lmdeploy/serve/test_api_server.py
new file mode 100644
index 0000000000..d6cd0aaa40
--- /dev/null
+++ b/tests/test_lmdeploy/serve/test_api_server.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import base64
+import struct
+from http import HTTPStatus
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from lmdeploy.messages import Response
+from lmdeploy.serve.openai.api_server import VariableInterface, create_embeddings
+from lmdeploy.serve.openai.protocol import EmbeddingsRequest
+
+
+def _async_gen(items):
+    """Helper to create an async generator from a list of items."""
+
+    async def gen():
+        for item in items:
+            yield item
+
+    return gen()
+
+
+def _mock_response(text='', finish_reason='stop', input_token_len=5,
+                   last_hidden_state=None, token_ids=None):
+    return Response(
+        text=text,
+        generate_token_len=0 if finish_reason != 'length' else 1,
+        input_token_len=input_token_len,
+        finish_reason=finish_reason,
+        token_ids=token_ids or [],
+        last_hidden_state=last_hidden_state,
+        index=0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_embeddings_single_string():
+    hidden = torch.tensor([0.1, 0.2, 0.3])
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    engine.generate = MagicMock(return_value=_async_gen([
+        _mock_response(finish_reason=None, last_hidden_state=None),
+        _mock_response(finish_reason='stop', last_hidden_state=hidden, input_token_len=5),
+    ]))
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(EmbeddingsRequest(input='hello world', model='test-model'))
+
+    assert resp.model == 'test-model'
+    assert resp.usage.prompt_tokens == 5
+    assert resp.usage.total_tokens == 5
+    assert resp.usage.completion_tokens == 0
+    assert len(resp.data) == 1
+    assert resp.data[0]['index'] == 0
+    assert resp.data[0]['object'] == 'embedding'
+    assert len(resp.data[0]['embedding']) == 3
+
+
+@pytest.mark.asyncio
+async def test_embeddings_list_input():
+    hidden1 = torch.tensor([0.1, 0.2])
+    hidden2 = torch.tensor([0.3, 0.4])
+    call_count = 0
+
+    def mock_generate(**kwargs):
+        nonlocal call_count
+        call_count += 1
+        hidden = hidden1 if call_count == 1 else hidden2
+        return _async_gen([
+            _mock_response(finish_reason='stop', last_hidden_state=hidden, input_token_len=3),
+        ])
+
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    engine.generate = mock_generate
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(EmbeddingsRequest(input=['first text', 'second text']))
+
+    assert len(resp.data) == 2
+    assert resp.data[0]['embedding'] == pytest.approx([0.1, 0.2])
+    assert resp.data[1]['embedding'] == pytest.approx([0.3, 0.4])
+    assert resp.usage.prompt_tokens == 6  # 3 + 3
+
+
+@pytest.mark.asyncio
+async def test_embeddings_base64_format():
+    hidden = torch.tensor([0.1, 0.2, 0.3])
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    engine.generate = MagicMock(return_value=_async_gen([
+        _mock_response(finish_reason='stop', last_hidden_state=hidden, input_token_len=5),
+    ]))
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(
+            EmbeddingsRequest(input='hello', encoding_format='base64'))
+
+    embedding = resp.data[0]['embedding']
+    assert isinstance(embedding, str)
+    decoded = struct.unpack('<3f', base64.b64decode(embedding.encode('utf-8'), validate=True))
+    assert decoded == pytest.approx((0.1, 0.2, 0.3))
+
+
+@pytest.mark.asyncio
+async def test_embeddings_empty_input():
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    with patch.object(VariableInterface, 'async_engine', engine):
+        resp = await create_embeddings(EmbeddingsRequest(input=''))
+    assert resp.status_code == HTTPStatus.BAD_REQUEST
+
+
+@pytest.mark.asyncio
+async def test_embeddings_empty_list():
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    with patch.object(VariableInterface, 'async_engine', engine):
+        resp = await create_embeddings(EmbeddingsRequest(input=[]))
+    assert resp.status_code == HTTPStatus.BAD_REQUEST
+
+
+@pytest.mark.asyncio
+async def test_embeddings_empty_text_in_list():
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    engine.generate = MagicMock(return_value=_async_gen([
+        _mock_response(finish_reason='stop', last_hidden_state=torch.tensor([0.1]), input_token_len=5),
+    ]))
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(EmbeddingsRequest(input=['valid', '']))
+    assert resp.status_code == HTTPStatus.BAD_REQUEST
+
+
+@pytest.mark.asyncio
+async def test_embeddings_no_hidden_states():
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    engine.generate = MagicMock(return_value=_async_gen([
+        _mock_response(finish_reason='stop', last_hidden_state=None, input_token_len=5),
+    ]))
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(EmbeddingsRequest(input='hello'))
+    assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+@pytest.mark.asyncio
+async def test_embeddings_error_finish_reason():
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    engine.generate = MagicMock(return_value=_async_gen([
+        _mock_response(text='prefix caching conflict',
+                       finish_reason='error', input_token_len=0),
+    ]))
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(EmbeddingsRequest(input='hello'))
+    assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+@pytest.mark.asyncio
+async def test_embeddings_model_default_when_none():
+    hidden = torch.tensor([0.1, 0.2, 0.3])
+    engine = MagicMock()
+    engine.model_name = 'default-model-name'
+    engine.generate = MagicMock(return_value=_async_gen([
+        _mock_response(finish_reason='stop', last_hidden_state=hidden, input_token_len=5),
+    ]))
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(EmbeddingsRequest(input='hello'))
+    assert resp.model == 'default-model-name'
+
+
+@pytest.mark.asyncio
+async def test_embeddings_prompt_tokens_summed():
+    hidden = torch.tensor([0.1])
+    call_count = 0
+
+    def mock_generate(**kwargs):
+        nonlocal call_count
+        call_count += 1
+        return _async_gen([
+            _mock_response(finish_reason='stop', last_hidden_state=hidden, input_token_len=call_count * 10),
+        ])
+
+    engine = MagicMock()
+    engine.model_name = 'test-model'
+    engine.generate = mock_generate
+
+    with patch.object(VariableInterface, 'async_engine', engine), \
+         patch.object(VariableInterface, 'create_session', return_value=MagicMock(session_id=0)):
+        resp = await create_embeddings(EmbeddingsRequest(input=['a', 'b', 'c']))
+
+    # prompt_tokens = 10 + 20 + 30 = 60
+    assert resp.usage.prompt_tokens == 60
+    assert resp.usage.total_tokens == 60