[Feature] Implement /v1/embeddings endpoint for OpenAI-compatible API #4550
```diff
@@ -3,9 +3,11 @@
 from __future__ import annotations

 import asyncio
+import base64
 import json
 import os
 import re
+import struct
 import time
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
```
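The newly added `base64` and `struct` imports suggest support for the OpenAI `encoding_format='base64'` option, in which the embedding vector is packed as little-endian float32 bytes and then base64-encoded. A minimal sketch of that conversion, assuming this is indeed what the imports are used for (the helper name below is hypothetical, not from this diff):

```python
import base64
import struct


def encode_embedding_base64(embedding: list[float]) -> str:
    """Pack a float vector as little-endian float32 and base64-encode it,
    matching the OpenAI embeddings encoding_format='base64' convention."""
    packed = struct.pack(f'<{len(embedding)}f', *embedding)
    return base64.b64encode(packed).decode('ascii')


# Example: a 3-dimensional embedding
print(encode_embedding_base64([0.1, -0.2, 0.3]))
```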
```diff
@@ -64,6 +66,7 @@
     CompletionStreamResponse,
     DeltaMessage,
     EmbeddingsRequest,
+    EmbeddingsResponse,
     EncodeRequest,
     EncodeResponse,
     ErrorResponse,
```
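`EmbeddingsResponse` is the new import here; the protocol classes themselves are not shown in this diff. Following the OpenAI embeddings schema, a plausible shape would be roughly the following sketch (field names and defaults are assumptions, and `UsageInfo` is a stand-in for the class of the same name that already exists in `lmdeploy.serve.openai.protocol`):

```python
from typing import Union

from pydantic import BaseModel


class UsageInfo(BaseModel):
    # Stand-in for the existing protocol class of the same name.
    prompt_tokens: int = 0
    total_tokens: int = 0


class EmbeddingObject(BaseModel):
    object: str = 'embedding'
    index: int
    # A list of floats, or a base64 string when encoding_format='base64'.
    embedding: Union[list[float], str]


class EmbeddingsResponse(BaseModel):
    object: str = 'list'
    data: list[EmbeddingObject]
    model: str
    usage: UsageInfo
```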
```diff
@@ -987,10 +990,78 @@ async def _inner_call():
     return response


-@router.post('/v1/embeddings', tags=['unsupported'])
+@router.post('/v1/embeddings', dependencies=[Depends(validate_json_request)])
 async def create_embeddings(request: EmbeddingsRequest, raw_request: Request = None):
     """Creates embeddings for the text."""
-    return create_error_response(HTTPStatus.BAD_REQUEST, 'Unsupported by turbomind.')
+    if isinstance(request.input, str):
+        inputs = [request.input]
+    else:
+        inputs = request.input
+
+    if not inputs:
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'Input must not be empty.')
+
+    async_engine = VariableInterface.async_engine
+    embedding_data = []
+    total_prompt_tokens = 0
+    for idx, text in enumerate(inputs):
+        if not text:
+            return create_error_response(HTTPStatus.BAD_REQUEST, 'Input text must not be empty.')
+
+        session = VariableInterface.create_session(-1)
+        gen_config = GenerationConfig(
+            max_new_tokens=1,
+            output_last_hidden_state='all',
+        )
+        result_generator = async_engine.generate(
+            messages=text,
+            session_id=session.session_id,
+            gen_config=gen_config,
+            stream_response=True,
+            sequence_start=True,
+            sequence_end=True,
```
Suggested change:

```diff
             sequence_end=True,
+            do_preprocess=False,
```
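Since the endpoint is meant to be OpenAI-compatible, a client-side call against a running `api_server` would presumably look like the following (host, port, and model name are placeholders, and `pip install openai` is assumed):

```python
# Hypothetical usage against a running lmdeploy api_server exposing /v1/embeddings.
from openai import OpenAI

client = OpenAI(base_url='http://localhost:23333/v1', api_key='none')
resp = client.embeddings.create(model='my-model', input=['hello world', 'lmdeploy'])
print(len(resp.data), len(resp.data[0].embedding))
```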
Copilot AI commented on Apr 28, 2026:
The loop ignores `res.finish_reason` / error frames from `AsyncEngine.generate` (e.g. when prefix caching is enabled with `output_last_hidden_state='all'`, the generator yields a `finish_reason='error'` frame with an error message and no hidden states). Currently this falls through to a generic 500. Handle `finish_reason == 'error'` (and possibly client disconnect) and surface the actual error message/status to the caller.
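A hedged sketch of the check being requested, written against the loop in the diff above; `finish_reason == 'error'` comes from the comment itself, while the other response fields (`res.text`, `res.last_hidden_state`) and the `stop_session` call are assumptions about the engine API rather than confirmed details of this PR:

```python
# Inside create_embeddings, consuming one input's generator (sketch, not the PR's code).
last_hidden_state = None
async for res in result_generator:
    # Surface engine-side errors instead of letting them fall through to a generic 500.
    if res.finish_reason == 'error':
        return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR,
                                     getattr(res, 'text', 'engine error'))
    # Abort early if the client went away mid-request.
    if raw_request is not None and await raw_request.is_disconnected():
        await async_engine.stop_session(session.session_id)
        return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
    # Keep the most recent hidden states for pooling after the loop.
    if getattr(res, 'last_hidden_state', None) is not None:
        last_hidden_state = res.last_hidden_state
```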
Copilot AI commented on Apr 28, 2026:
`prompt_tokens` is overwritten per input, so the final usage only reflects the last item. Also, when `request.model` is omitted the response currently returns an empty string, unlike other endpoints (e.g. `/pooling`), which default to `async_engine.model_name`. Consider summing prompt tokens across all inputs and defaulting `model` to the server model name when it is not provided.
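A sketch of the two fixes inside `create_embeddings`, assuming `res.input_token_len` carries the per-input prompt length and that `EmbeddingsResponse`/`UsageInfo` accept the OpenAI-style fields shown; only the accumulation pattern and the model-name default are the point here:

```python
total_prompt_tokens = 0
for idx, text in enumerate(inputs):
    ...  # run the engine for `text` as in the diff above, producing `res` and one embedding
    total_prompt_tokens += res.input_token_len  # accumulate across inputs instead of overwriting

model_name = request.model or async_engine.model_name  # same default as /pooling
response = EmbeddingsResponse(
    object='list',
    data=embedding_data,
    model=model_name,
    usage=UsageInfo(prompt_tokens=total_prompt_tokens, total_tokens=total_prompt_tokens),
)
```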
Mean pooling is computed per forward pass using `raw_seq_length`. For long-context chunking (`inputs.is_chunk`), intermediate chunks do not emit outputs (`EngineLoop` skips non-last chunks), so the embedding computed on the last chunk reflects only that final chunk rather than the full input sequence. To make embeddings correct for chunked prefill, accumulate a weighted sum and token count across chunks (e.g. store partial sums in the sequence state) and finalize the mean on the last chunk.
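To illustrate the accumulation this comment describes, here is a minimal sketch; the `SequenceState` container and function names are hypothetical stand-ins for wherever per-sequence state lives in the engine, the point being the running weighted sum plus token count finalized into a mean on the last chunk:

```python
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class SequenceState:
    hidden_sum: torch.Tensor | None = None  # running sum of hidden states over chunks seen so far
    token_count: int = 0                    # number of tokens accumulated so far


def accumulate_chunk(state: SequenceState, chunk_hidden: torch.Tensor) -> None:
    """Add one chunk's hidden states (shape [chunk_len, hidden_dim]) to the running sum."""
    chunk_sum = chunk_hidden.sum(dim=0)
    state.hidden_sum = chunk_sum if state.hidden_sum is None else state.hidden_sum + chunk_sum
    state.token_count += chunk_hidden.shape[0]


def finalize_mean(state: SequenceState) -> torch.Tensor:
    """Mean over the full input sequence; call this only on the last chunk."""
    return state.hidden_sum / max(state.token_count, 1)
```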