diff --git a/docs/en/index.rst b/docs/en/index.rst index 7063c84dde..c596173958 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -71,6 +71,7 @@ Documentation multi_modal/vl_pipeline.md multi_modal/api_server_vl.md + multi_modal/multimodal_inputs.md multi_modal/index.rst .. _quantization: diff --git a/docs/en/multi_modal/index.rst b/docs/en/multi_modal/index.rst index ac0e649244..a041172edb 100644 --- a/docs/en/multi_modal/index.rst +++ b/docs/en/multi_modal/index.rst @@ -1,12 +1,6 @@ Vision-Language Models ================================= -.. toctree:: - :maxdepth: 2 - :caption: Guides - - multimodal_inputs.md - .. toctree:: :maxdepth: 2 :caption: Examples diff --git a/docs/en/multi_modal/multimodal_inputs.md b/docs/en/multi_modal/multimodal_inputs.md index 0222075dc8..31588b8583 100644 --- a/docs/en/multi_modal/multimodal_inputs.md +++ b/docs/en/multi_modal/multimodal_inputs.md @@ -9,6 +9,7 @@ LMDeploy uses the OpenAI message format for all modalities. Each content item in | Text | `text` | — | | Image | `image_url` | `image_url.url` | | Video | `video_url` | `video_url.url` | +| Audio | `audio_url` | `audio_url.url` | | Time Series | `time_series_url` | `time_series_url.url` | All examples below target the lmdeploy OpenAI-compatible API server. Start it with: @@ -133,7 +134,7 @@ ______________________________________________________________________ ## Single Video -> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, and **InternS1-Pro** models only. +> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, **Qwen3-Omni**, and **InternS1-Pro** models only.
Complete example @@ -176,7 +177,7 @@ ______________________________________________________________________ ## Multiple Videos -> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, and **InternS1-Pro** models only. +> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, **Qwen3-Omni**, and **InternS1-Pro** models only.
Complete example @@ -222,7 +223,7 @@ ______________________________________________________________________ ## Mixed Image and Video -> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, and **InternS1-Pro** models only. +> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, **Qwen3-Omni**, and **InternS1-Pro** models only.
Complete example @@ -266,6 +267,85 @@ print(response.choices[0].message.content) ______________________________________________________________________ +## Single Audio + +> **Note:** Audio input is currently supported for **Qwen3-Omni** models only. + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'audio_url', + 'audio_url': { + 'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav', + }, + }, + { + 'type': 'text', + 'text': 'Describe this audio.', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
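+
+Besides HTTP(S) URLs, the `AudioMediaIO` added in this change also implements `load_base64` and `load_file`, so an inline base64 data URI should be accepted the same way it is for images. The sketch below assumes that behavior; `sample.wav` is a placeholder local file.
+
+```python
+import base64
+from pathlib import Path
+
+from openai import OpenAI
+
+client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1')
+model_name = client.models.list().data[0].id
+
+# Encode a local WAV file (placeholder path) as a base64 data URI.
+audio_b64 = base64.b64encode(Path('sample.wav').read_bytes()).decode('utf-8')
+
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{
+        'role': 'user',
+        'content': [
+            {
+                'type': 'audio_url',
+                'audio_url': {'url': f'data:audio/wav;base64,{audio_b64}'},
+            },
+            {'type': 'text', 'text': 'Describe this audio.'},
+        ],
+    }],
+    temperature=0.8,
+    top_p=0.8,
+)
+print(response.choices[0].message.content)
+```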
+ +______________________________________________________________________ + +## Multiple Audios + +> **Note:** Audio input is currently supported for **Qwen3-Omni** models only. + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +audio_url_1 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' +audio_url_2 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'audio_url', 'audio_url': {'url': audio_url_1}}, + {'type': 'audio_url', 'audio_url': {'url': audio_url_2}}, + { + 'type': 'text', + 'text': 'Compare these two audios. What are the similarities and differences?', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
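+
+Audio can also be combined with images in a single Qwen3-Omni request: the server-side changes in this PR re-order multimodal features by their prompt offsets, so mixed-modality order should be preserved. Treat the following as a sketch; the image URL is a placeholder.
+
+```python
+from openai import OpenAI
+
+client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1')
+model_name = client.models.list().data[0].id
+
+# Placeholder image URL; the audio URL reuses the sample from the examples above.
+image_url = 'https://example.com/your_image.jpg'
+audio_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav'
+
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{
+        'role': 'user',
+        'content': [
+            {'type': 'image_url', 'image_url': {'url': image_url}},
+            {'type': 'audio_url', 'audio_url': {'url': audio_url}},
+            {'type': 'text', 'text': 'Does the sound match what is shown in the image?'},
+        ],
+    }],
+    temperature=0.8,
+    top_p=0.8,
+)
+print(response.choices[0].message.content)
+```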
+ +______________________________________________________________________ + ## Time Series > **Note:** Time series input is currently supported for the **InternS1-Pro** model only. diff --git a/docs/en/multi_modal/vl_pipeline.md b/docs/en/multi_modal/vl_pipeline.md index d016358d39..4972ba91d5 100644 --- a/docs/en/multi_modal/vl_pipeline.md +++ b/docs/en/multi_modal/vl_pipeline.md @@ -10,7 +10,7 @@ Moreover, we will provide practical inference examples tailored to scenarios wit Using the pipeline interface to infer other VLM models is similar, with the main difference being the configuration and installation dependencies of the models. You can read [here](https://lmdeploy.readthedocs.io/en/latest/multi_modal/index.html) for environment installation and configuration methods for different models. -> **See also:** [Multi-Modal Inputs](multimodal_inputs.md) — message format reference for all modalities (image, video, time series) with OpenAI-style examples. +> **See also:** [Multi-Modal Inputs](multimodal_inputs.md) — message format reference for all modalities (image, video, audio, time series) with OpenAI-style examples. ## A 'Hello, world' example diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index c5b63b616d..74e02c2761 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -71,6 +71,7 @@ LMDeploy 工具箱提供以下核心功能: multi_modal/vl_pipeline.md multi_modal/api_server_vl.md + multi_modal/multimodal_inputs.md multi_modal/index.rst diff --git a/docs/zh_cn/multi_modal/index.rst b/docs/zh_cn/multi_modal/index.rst index ed33bba8d3..9a61f6efdb 100644 --- a/docs/zh_cn/multi_modal/index.rst +++ b/docs/zh_cn/multi_modal/index.rst @@ -1,12 +1,6 @@ 视觉语言模型 ================================= -.. toctree:: - :maxdepth: 2 - :caption: 指南 - - multimodal_inputs.md - .. toctree:: :maxdepth: 2 :caption: 示例 diff --git a/docs/zh_cn/multi_modal/multimodal_inputs.md b/docs/zh_cn/multi_modal/multimodal_inputs.md index 218279caae..c658183a48 100644 --- a/docs/zh_cn/multi_modal/multimodal_inputs.md +++ b/docs/zh_cn/multi_modal/multimodal_inputs.md @@ -9,6 +9,7 @@ LMDeploy 使用 OpenAI 消息格式处理所有模态。消息中的每个内容 | 文本 | `text` | — | | 图像 | `image_url` | `image_url.url` | | 视频 | `video_url` | `video_url.url` | +| 音频 | `audio_url` | `audio_url.url` | | 时序数据 | `time_series_url` | `time_series_url.url` | 以下示例均面向 lmdeploy 兼容 OpenAI 的 API 服务。启动服务: @@ -133,7 +134,7 @@ ______________________________________________________________________ ## 单个视频 -> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5** 和 **InternS1-Pro** 模型。 +> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5**、**Qwen3-Omni** 和 **InternS1-Pro** 模型。
完整示例 @@ -176,7 +177,7 @@ ______________________________________________________________________ ## 多个视频 -> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5** 和 **InternS1-Pro** 模型。 +> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5**、**Qwen3-Omni** 和 **InternS1-Pro** 模型。
完整示例 @@ -222,7 +223,7 @@ ______________________________________________________________________ ## 图像与视频混合 -> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5** 和 **InternS1-Pro** 模型。 +> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5**、**Qwen3-Omni** 和 **InternS1-Pro** 模型。
完整示例 @@ -266,6 +267,85 @@ print(response.choices[0].message.content) ______________________________________________________________________ +## 单个音频 + +> **注意:** 音频输入目前仅支持 **Qwen3-Omni** 模型。 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'audio_url', + 'audio_url': { + 'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav', + }, + }, + { + 'type': 'text', + 'text': '描述这段音频。', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
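+
+除了 HTTP(S) 链接,本次新增的 `AudioMediaIO` 还实现了 `load_base64` 与 `load_file`,因此内联的 base64 data URI 应当可以像图像一样传入。以下示例基于该假设,其中 `sample.wav` 为占位的本地文件路径。
+
+```python
+import base64
+from pathlib import Path
+
+from openai import OpenAI
+
+client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1')
+model_name = client.models.list().data[0].id
+
+# 将本地 WAV 文件(占位路径)编码为 base64 data URI
+audio_b64 = base64.b64encode(Path('sample.wav').read_bytes()).decode('utf-8')
+
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{
+        'role': 'user',
+        'content': [
+            {
+                'type': 'audio_url',
+                'audio_url': {'url': f'data:audio/wav;base64,{audio_b64}'},
+            },
+            {'type': 'text', 'text': '描述这段音频。'},
+        ],
+    }],
+    temperature=0.8,
+    top_p=0.8,
+)
+print(response.choices[0].message.content)
+```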
+ +______________________________________________________________________ + +## 多个音频 + +> **注意:** 音频输入目前仅支持 **Qwen3-Omni** 模型。 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +audio_url_1 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' +audio_url_2 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'audio_url', 'audio_url': {'url': audio_url_1}}, + {'type': 'audio_url', 'audio_url': {'url': audio_url_2}}, + { + 'type': 'text', + 'text': '比较这两段音频,有哪些相似点和不同点?', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
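+
+音频也可以与图像在同一条 Qwen3-Omni 请求中混合使用:服务端会按 prompt 中的偏移重新排序多模态特征,以保持混合模态的顺序。以下示例仅作参考,图像 URL 为占位符。
+
+```python
+from openai import OpenAI
+
+client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1')
+model_name = client.models.list().data[0].id
+
+# 图像 URL 为占位符;音频 URL 复用上文示例
+image_url = 'https://example.com/your_image.jpg'
+audio_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav'
+
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{
+        'role': 'user',
+        'content': [
+            {'type': 'image_url', 'image_url': {'url': image_url}},
+            {'type': 'audio_url', 'audio_url': {'url': audio_url}},
+            {'type': 'text', 'text': '这段声音与图片中的内容相符吗?'},
+        ],
+    }],
+    temperature=0.8,
+    top_p=0.8,
+)
+print(response.choices[0].message.content)
+```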
+ +______________________________________________________________________ + ## 时序数据 > **注意:** 时序数据输入目前仅支持 **InternS1-Pro** 模型。 diff --git a/docs/zh_cn/multi_modal/vl_pipeline.md b/docs/zh_cn/multi_modal/vl_pipeline.md index 63f6fda77f..9662bcc569 100644 --- a/docs/zh_cn/multi_modal/vl_pipeline.md +++ b/docs/zh_cn/multi_modal/vl_pipeline.md @@ -10,7 +10,7 @@ LMDeploy 把视觉-语言模型(VLM)复杂的推理过程,抽象为简单 使用 pipeline 接口推理其他 VLM 模型,大同小异,主要区别在于模型依赖的配置和安装。你可以阅读[此处](https://lmdeploy.readthedocs.io/zh-cn/latest/multi_modal/),查看不同模型的环境安装和配置方式 -> **另请参阅:** [多模态输入](multimodal_inputs.md) — 涵盖所有模态(图像、视频、时序数据)的消息格式参考,包含 OpenAI 风格示例。 +> **另请参阅:** [多模态输入](multimodal_inputs.md) — 涵盖所有模态(图像、视频、音频、时序数据)的消息格式参考,包含 OpenAI 风格示例。 ## "Hello, world" 示例 diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index ea577f6372..c865310bdf 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -111,9 +111,9 @@ def check_vl_llm(backend: str, config: dict) -> bool: 'InternVLChatModel', 'MiniCPMV', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'Qwen3_5ForConditionalGeneration', - 'Qwen3_5MoeForConditionalGeneration', 'MllamaForConditionalGeneration', 'MolmoForCausalLM', - 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration', - 'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration', + 'Qwen3_5MoeForConditionalGeneration', 'Qwen3OmniMoeForConditionalGeneration', 'MllamaForConditionalGeneration', + 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', + 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration', 'InternS1_1_ForConditionalGeneration', 'Glm4vForConditionalGeneration' ]) if arch == 'QWenLMHeadModel' and 'visual' in config: diff --git a/lmdeploy/model.py b/lmdeploy/model.py index ac934c7374..d2394fec4c 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -687,8 +687,19 @@ class HFChatTemplate(BaseChatTemplate): def __init__(self, model_path: str = '', trust_remote_code: bool = False, **kwargs): self.model_path = model_path try: - from transformers import AutoTokenizer + from transformers import AutoProcessor, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=trust_remote_code) + + # Some tokenizers do not have chat_template, in this case try to get chat_template from processor + # If this still does not work, fallback to BaseChatTemplate. + if getattr(self.tokenizer, 'chat_template', None) is None: + try: + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=trust_remote_code) + self.tokenizer.chat_template = getattr(processor, 'chat_template', None) + except Exception as e: + logger.warning(f'Failed to load processor from {model_path} for chat template. ' + f'Fallback to tokenizer only. Error: {e}') + # Verify if the model can perform apply_chat_template with different roles. self.user_start, self.user_end, _, _ = self._user_instruction() self.assistant_start, self.assistant_end, _, _ = self._assistant_instruction() diff --git a/lmdeploy/pytorch/configurations/qwen3_omni.py b/lmdeploy/pytorch/configurations/qwen3_omni.py new file mode 100644 index 0000000000..9d1e3ec61c --- /dev/null +++ b/lmdeploy/pytorch/configurations/qwen3_omni.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
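+"""Config builder for Qwen3-Omni (MoE).
+
+The HF config nests the language model under ``thinker_config.text_config``;
+the builder below unwraps it so the default text-model builder can be reused,
+keeps the original HF config on the result, and enables multimodal RoPE.
+"""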
+from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder + + +class Qwen3OmniModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.model_type == 'qwen3_omni_moe' + + @classmethod + def build(cls, hf_config, model_path: str = None, **kwargs): + """build.""" + cfg = DefaultModelConfigBuilder.build(hf_config.thinker_config.text_config, model_path, **kwargs) + cfg.hf_config = hf_config + cfg.use_mrope = True + return cfg diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index f2f21eb93b..847f74d6cf 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -190,6 +190,13 @@ 'Qwen3_5MTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_mtp.Qwen3_5MTPModel', }) +# qwen3 omni moe thinker +# only support thinker module, so map to Qwen3OmniMoeThinkerForConditionalGeneration +MODULE_MAP.update({ + 'Qwen3OmniMoeForConditionalGeneration': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_omni_moe_thinker.Qwen3OmniMoeThinkerForConditionalGeneration', +}) + # starcoder2 MODULE_MAP.update({ 'Starcoder2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM', diff --git a/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py new file mode 100644 index 0000000000..9fe4a72b7b --- /dev/null +++ b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py @@ -0,0 +1,924 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import math +from collections.abc import Iterable, Sequence +from functools import lru_cache +from typing import Any + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalData +from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, LayerNorm +from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight +from lmdeploy.vl.constants import Modality + +from .qwen3_vl import Qwen3VLVisionBlock, Qwen3VLVisionPatchEmbed, Qwen3VLVisionRotaryEmbedding +from .qwen3_vl_moe import Qwen3VLMoeTextModel +from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixinV1, vlm_model + + +def _get_feat_extract_output_lengths(input_lengths): + """Computes the output length of the convolutional layers and the output + length of the audio encoder.""" + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + +class Qwen3OmniMoeAudioAttention(nn.Module): + """Vision attention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + dim = config.d_model + num_heads = config.encoder_attention_heads + head_dim = dim // num_heads + self.head_dim = head_dim + + # packed qkv + self.qkv_proj = build_qkv_proj( + dim, + num_q_heads=num_heads, + 
num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attention = FlashAttention( + num_heads, + head_dim, + causal=False, + ) + + # o_proj + self.out_proj = build_rowwise_linear(dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + q, k, v = self.qkv_proj.split_qkv(qkv_states) + + attn_output = self.attention( + q, + k, + v, + q_start_loc=cu_seqlens[:-1], + q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1], + ) + + attn_output = attn_output.reshape(seq_length, -1) + + # o proj + attn_output = self.out_proj(attn_output) + return attn_output + + +class Qwen3OmniMoeAudioEncoderLayer(nn.Module): + """Qwen3OmniMoeAudioEncoderLayer.""" + + def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None: + super().__init__() + self.embed_dim = config.d_model + self.self_attn = Qwen3OmniMoeAudioAttention(config, dtype=dtype, device=device) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=1e-5, dtype=dtype, device=device) + + self.activation_fn = ACT2FN[config.activation_function] + self.fc1 = build_colwise_linear( + self.embed_dim, + config.encoder_ffn_dim, + bias=True, + dtype=dtype, + device=device, + ) + self.fc2 = build_rowwise_linear( + config.encoder_ffn_dim, + self.embed_dim, + bias=True, + dtype=dtype, + device=device, + ) + self.final_layer_norm = LayerNorm(self.embed_dim, eps=1e-5, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, ) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError('SinusoidsPositionEmbedding needs even channels input') + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.register_buffer( + 'positional_embedding', + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class Qwen3OmniMoeAudioEncoder(nn.Module): + """Qwen3OmniMoeAudioEncoder.""" + + def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None: 
+ super().__init__() + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) + self.layers = nn.ModuleList( + [Qwen3OmniMoeAudioEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.encoder_layers)]) + self.ln_post = LayerNorm(config.d_model, eps=1e-5, dtype=dtype, device=device) + self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1, dtype=dtype, device=device) + self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + dtype=dtype, + device=device) + self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + dtype=dtype, + device=device) + conv_out_dim = config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2) + self.conv_out = nn.Linear( + conv_out_dim, + config.d_model, + bias=False, + dtype=dtype, + device=device, + ) + self.proj1 = nn.Linear(config.d_model, config.d_model, dtype=dtype, device=device) + self.act = ACT2FN[config.activation_function] + self.proj2 = nn.Linear(config.d_model, config.output_dim, dtype=dtype, device=device) + self.n_window_infer = config.n_window_infer + self.conv_chunksize = config.conv_chunksize + + def forward( + self, + input_features: torch.Tensor, + feature_lens: torch.Tensor, + aftercnn_lens=None, + ): + r"""feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + + mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length after cnn + """ + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = nn.utils.rnn.pad_sequence( + [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], + batch_first=True, + ) + padded_feature = padded_feature.unsqueeze(1) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) + + positional_embedding = ( + self.positional_embedding.positional_embedding[:padded_embed.shape[1], :].unsqueeze(0).to( + padded_embed.dtype)) + padded_embed = padded_embed + positional_embedding + hidden_states = padded_embed[padded_mask_after_cnn] + cu_chunk_lens = [0] + window_aftercnn = padded_mask_after_cnn.shape[-1] * 
(self.n_window_infer // (self.n_window * 2)) + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.ln_post(hidden_states) + hidden_states = self.proj1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.proj2(hidden_states) + return hidden_states + + +class Qwen3OmniMoeVisionPatchMerger(nn.Module): + """Vision patch merger. + + Different namings with qwen3vl, but actual calculations are the same. + """ + + def __init__(self, + config: PretrainedConfig, + use_postshuffle_norm=False, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, + eps=1e-6, + dtype=dtype, + device=device) + self.mlp = nn.ModuleList([ + build_colwise_linear( + self.hidden_size, + self.hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ), + nn.GELU(), + build_rowwise_linear( + self.hidden_size, + config.out_hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + for layer in self.mlp: + x = layer(x) + return x + + +@vlm_model +class Qwen3OmniMoeVisionEncoder(nn.Module): + """Vision transformer.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed(config=config, dtype=dtype, device=device) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2, device=device) + + self.blocks = nn.ModuleList( + [Qwen3VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)]) + self.merger = Qwen3OmniMoeVisionPatchMerger(config=config, + use_postshuffle_norm=False, + dtype=dtype, + device=device) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.merger_list = nn.ModuleList([ + Qwen3OmniMoeVisionPatchMerger(config=config, use_postshuffle_norm=True, dtype=dtype, device=device) + for _ in range(len(config.deepstack_visual_indexes)) + ]) + + @staticmethod + @lru_cache(maxsize=1024) + def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + h_div = h // spatial_merge_size + w_div = w // spatial_merge_size + + hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) + hpos_ids = hpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + hpos_ids = hpos_ids.transpose(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) + wpos_ids = 
wpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + wpos_ids = wpos_ids.transpose(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + + return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + """Rotary position embedding.""" + pos_ids = [] + + for t, h, w in grid_thw: + base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size) + pos_ids.append(base if t == 1 else base.repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb + + # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474 + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + device = self.pos_embed.weight.device + + outputs = [] + for t, h, w in grid_thw: + h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij') + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij') + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - w01 + + h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * num_grid_per_side + + indices = (h_grid_idx + w_grid).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device) + + embeds = self.pos_embed(indices) + embeds *= weights + combined = embeds.sum(dim=0) + + combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim) + combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, + pos_embeds: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states + pos_embeds + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + if layer_num in self.deepstack_visual_indexes: + deepstack_merge_idx = self.deepstack_visual_indexes.index(layer_num) + 
deepstack_feature = self.merger_list[deepstack_merge_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +class Qwen3OmniMoeThinkerForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + thinker_config = config.thinker_config + + # build preprocessor + self.input_processor = Qwen3OmniInputProcessor(self.config) + + # build audio encoder + self.audio_tower = Qwen3OmniMoeAudioEncoder( + thinker_config.audio_config, + dtype=dtype, + device=device, + ) + + # build vision encoder + self.visual = Qwen3OmniMoeVisionEncoder( + thinker_config.vision_config, + dtype=dtype, + device=device, + ) + + # build text model + self.language_model = Qwen3VLMoeTextModel(thinker_config.text_config, dtype=dtype, device=device) + + # build lm_head + self.lm_head = build_rowwise_linear(thinker_config.text_config.hidden_size, + thinker_config.text_config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: list[list[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + mrope_position_ids: torch.Tensor = None, + pixel_values: torch.Tensor = None, + vis_cu_seqlens: torch.Tensor = None, + vis_pos_emb: torch.Tensor = None, + image_mask: torch.Tensor = None, + pos_embeds: torch.Tensor = None, + grid_thw: torch.Tensor = None, + audio_values: torch.Tensor = None, + audio_mask: torch.Tensor = None, + audio_feature_lengths: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + + visual_pos_masks = None + deepstack_visual_embeds = None + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype)) + + # get image embeds and deepstack visual embeds + image_embeds, deepstack_visual_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb, + pos_embeds=pos_embeds) + + # split image embeds per sample + split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) + + # mask and scatter to create final input embeddings + expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) + + visual_pos_masks = expanded_image_mask + + if audio_values is not None: + dtype = inputs_embeds.dtype + audio_values = audio_values.to(dtype) + audio_embeds = self.audio_tower( + input_features=audio_values, + feature_lens=audio_feature_lengths, + ) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask.unsqueeze(-1), audio_embeds) + + hidden_states = self.language_model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + 
inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + # args for deepstack + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.language_model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: list[list[torch.Tensor]], + inputs_embeds: torch.Tensor | None = None, + context: StepContext = None, + ): + """Prepare input.""" + + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + pixel_values = None + vis_cu_seqlens = None + vis_pos_emb = None + image_mask = None + grid_thw = None + pos_embeds = None + audio_values = None + audio_mask = None + audio_feature_lengths = None + if context.input_multimodals is not None: + mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals] + # flatten batch + mm_inputs = [item for sublist in mm_inputs for item in sublist] + + if len(mm_inputs) > 0: + audio_inputs = [inp for inp in mm_inputs if inp.modality == Modality.AUDIO] + visual_inputs = [inp for inp in mm_inputs if inp.modality in [Modality.IMAGE, Modality.VIDEO]] + + if audio_inputs: + audio_chunks = [] + for inp in audio_inputs: + audio_data = inp.data + if audio_data.dim() == 3 and audio_data.shape[0] == 1: + audio_data = audio_data.squeeze(0) + audio_chunks.append(audio_data) + audio_values = torch.cat(audio_chunks, dim=-1) + audio_mask = self.get_multimodal_mask(input_ids, audio_inputs) + audio_feature_lengths = torch.cat([inp.meta['audio_feature_lengths'] for inp in audio_inputs]) + + if visual_inputs: + pixel_values = torch.cat([inp.data for inp in visual_inputs]) + image_mask = self.get_multimodal_mask(input_ids, visual_inputs) + + grid_thw = torch.stack([data.meta['grid_thw'] for data in visual_inputs]).cpu() + vis_pos_emb = self.visual.rot_pos_emb(grid_thw) + pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw) + vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).to(pixel_values.device) + vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32) + vis_pos_emb = vis_pos_emb.repeat(1, 2) + vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin()) + + mrope_position_ids = getattr(context, 'mrope_position_ids', None) + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + pixel_values=pixel_values, + vis_cu_seqlens=vis_cu_seqlens, + vis_pos_emb=vis_pos_emb, + image_mask=image_mask, + grid_thw=grid_thw, + pos_embeds=pos_embeds, + audio_values=audio_values, + audio_mask=audio_mask, + audio_feature_lengths=audio_feature_lengths, + ) + + def rename_weight(self, name: str) -> str: + """Rename weight.""" + if 
name.startswith('thinker.model.'): + return 'language_model.' + name[len('thinker.model.'):] + elif name.startswith('thinker.visual.'): + return 'visual.' + name[len('thinker.visual.'):] + elif name.startswith('thinker.audio_tower.'): + return 'audio_tower.' + name[len('thinker.audio_tower.'):] + # thinker_config.text_config tie_word_embeddings = False + elif name.startswith('thinker.lm_head.'): + return 'lm_head.' + name[len('thinker.lm_head.'):] + return name + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter], + expert_params_mapping: list): + """Load weight experts.""" + + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + # expert mapping + num_experts = self.config.thinker_config.text_config.num_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + # (param_name, weight_name, expert_id, shard_id) + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + # skip talker and code2wav weights + if ('talker.' in name or 'code2wav.' in name): + continue + + name = name.replace('.block_sparse_moe.', '.mlp.') + if '.experts' in name: + self._load_weight_experts(name, + loaded_weight, + params_dict, + expert_params_mapping=expert_params_mapping) + else: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.qkv.' 
in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def get_input_processor(self) -> BaseModelInputProcessor: + """Get input processor.""" + return self.input_processor + + +class Qwen3OmniInputProcessor(BaseModelInputProcessor): + """Qwen3 Omni input processor.""" + + def __init__(self, config: PretrainedConfig) -> None: + self.config = config + thinker_config = config.thinker_config + self.spatial_merge_size = thinker_config.vision_config.spatial_merge_size + self.position_id_per_seconds = thinker_config.position_id_per_seconds + + @staticmethod + def _get_multimodal_pos_ids(grid_thw: Sequence[int], + spatial_merge_size: int = 2, + second_per_grid: float | None = None, + position_id_per_seconds: int = 25) -> np.ndarray: + """Get mrope ids.""" + t, h, w = grid_thw + h = h // spatial_merge_size + w = w // spatial_merge_size + stride = np.array([h * w, w, 1])[None] + size = np.array([t, h, w])[None] + pos_ids = np.arange(t * h * w)[:, None].repeat(3, axis=1) + pos_ids = pos_ids // stride % size + if second_per_grid is not None: + # HF Qwen3-Omni spaces video temporal ids by elapsed seconds. + scale = second_per_grid * position_id_per_seconds + pos_ids[:, 0] = np.rint(pos_ids[:, 0] * scale) + return pos_ids + + def make_mrope(self, grid_thw: torch.Tensor, second_per_grid: float | None = None): + grid_thw = grid_thw.tolist() if grid_thw.dim() == 1 else grid_thw[0].tolist() + img_pos_ids = self._get_multimodal_pos_ids( + grid_thw, + spatial_merge_size=self.spatial_merge_size, + second_per_grid=second_per_grid, + position_id_per_seconds=self.position_id_per_seconds, + ) + return img_pos_ids + + def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: + """Make image MultiModalData.""" + pixel_values = input_mm['pixel_values'] + image_grid_thw = input_mm['image_grid_thw'] + offset = input_mm['offset'] + image_token_id = input_mm['image_token_id'] + + mrope_pos_ids = self.make_mrope(image_grid_thw) + + mm_data = MultiModalData(modality=Modality.IMAGE, + data=pixel_values, + start=offset[0], + end=offset[1], + mrope_pos_ids=mrope_pos_ids, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + return mm_data + + def _make_video_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: + """Make video MultiModalData.""" + pixel_values_videos = input_mm['pixel_values_videos'] + video_grid_thw = input_mm['video_grid_thw'] + offset = input_mm['offset'] + video_token_id = input_mm['video_token_id'] + + mrope_pos_ids = self.make_mrope(video_grid_thw, second_per_grid=input_mm.get('second_per_grid')) + + mm_data = MultiModalData(modality=Modality.VIDEO, + data=pixel_values_videos, + start=offset[0], + end=offset[1], + mrope_pos_ids=mrope_pos_ids, + meta=dict( + grid_thw=video_grid_thw, + video_token_id=video_token_id, + second_per_grid=input_mm.get('second_per_grid'), + )) + return mm_data + + def _make_audio_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: + """Make audio MultiModalData.""" + input_features = input_mm['input_features'] + offset = input_mm['offset'] + audio_token_id = input_mm['audio_token_id'] + feature_attention_mask = input_mm.get('feature_attention_mask') + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 
1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = input_mm.get('audio_feature_lengths') + if audio_feature_lengths is None: + audio_feature_lengths = torch.full((input_features.shape[0], ), + input_features.shape[-1], + dtype=torch.long) + audio_len = offset[1] - offset[0] + mrope_pos_ids = np.arange(audio_len, dtype=np.int64)[:, None].repeat(3, axis=1) + + mm_data = MultiModalData(modality=Modality.AUDIO, + data=input_features, + start=offset[0], + end=offset[1], + mrope_pos_ids=mrope_pos_ids, + meta=dict( + audio_token_id=audio_token_id, + audio_feature_lengths=audio_feature_lengths, + )) + return mm_data + + def preprocess_input(self, + input_ids: list[int], + input_multimodals: list[dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """Prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_mm_data = [] + for input_mm in input_multimodals: + modality = input_mm.get('modality') + + if modality == Modality.IMAGE: + mm_data = self._make_image_mm_data(input_mm) + elif modality == Modality.VIDEO: + mm_data = self._make_video_mm_data(input_mm) + elif modality == Modality.AUDIO: + mm_data = self._make_audio_mm_data(input_mm) + + input_mm_data.append(mm_data) + + result = PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(mm_data=input_mm_data)) + + return result diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py index 4d6fc5e635..4b140fa692 100644 --- a/lmdeploy/pytorch/models/utils/model.py +++ b/lmdeploy/pytorch/models/utils/model.py @@ -153,20 +153,23 @@ def build_lm_head(self, return lm_head def get_multimodal_mask(self, input_ids: torch.Tensor, mm_inputs: list[MultiModalData]) -> torch.Tensor: - """Get position masks for vision tokens.""" + """Get position masks for multimodal tokens.""" image_token_id = next((m.meta.get('image_token_id') for m in mm_inputs if m.modality == Modality.IMAGE), None) video_token_id = next((m.meta.get('video_token_id') for m in mm_inputs if m.modality == Modality.VIDEO), None) + audio_token_id = next((m.meta.get('audio_token_id') for m in mm_inputs if m.modality == Modality.AUDIO), None) ts_token_id = next((m.meta.get('ts_token_id') for m in mm_inputs if m.modality == Modality.TIME_SERIES), None) - image_mask, video_mask, ts_mask = None, None, None + image_mask, video_mask, audio_mask, ts_mask = None, None, None, None if image_token_id is not None: image_mask = (input_ids == image_token_id) if video_token_id is not None: video_mask = (input_ids == video_token_id) + if audio_token_id is not None: + audio_mask = (input_ids == audio_token_id) if ts_token_id is not None: ts_mask = (input_ids == ts_token_id) - masks = [m for m in (image_mask, video_mask, ts_mask) if m is not None] + masks = [m for m in (image_mask, video_mask, audio_mask, ts_mask) if m is not None] multimodal_mask = None if not masks else torch.stack(masks, dim=0).any(dim=0) return multimodal_mask diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py index 6e4d847bf9..e5292cc5c5 100644 --- a/lmdeploy/pytorch/multimodal/data_type.py +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -16,12 +16,11 @@ class MultiModalData: start: int end: int | None = None meta: dict[str, Any] | None = None + modality: Modality = Modality.IMAGE # for qwen-vl mrope_pos_ids: np.ndarray | None = None - modality: Modality = Modality.IMAGE - def __post_init__(self): if self.end is None: self.end = 
self.start diff --git a/lmdeploy/serve/processors/multimodal.py b/lmdeploy/serve/processors/multimodal.py index bd9ff183df..4c0d5a32d9 100644 --- a/lmdeploy/serve/processors/multimodal.py +++ b/lmdeploy/serve/processors/multimodal.py @@ -8,6 +8,7 @@ from lmdeploy.tokenizer import Tokenizer from lmdeploy.utils import get_logger from lmdeploy.vl.constants import Modality +from lmdeploy.vl.media.audio import AudioMediaIO from lmdeploy.vl.media.connection import load_from_url from lmdeploy.vl.media.image import ImageMediaIO from lmdeploy.vl.media.time_series import TimeSeriesMediaIO @@ -148,6 +149,9 @@ def _require_data_src(): data, metadata = load_from_url( _require_data_src(), VideoMediaIO(image_io=ImageMediaIO(), **media_io_kwargs.get('video', {}))) item_params['video_metadata'] = metadata + elif item_type in ('audio_url', 'audio'): + modality = Modality.AUDIO + data = load_from_url(_require_data_src(), AudioMediaIO(**media_io_kwargs.get('audio', {}))) elif item_type in ('time_series_url', 'time_series'): modality = Modality.TIME_SERIES data = load_from_url(_require_data_src(), TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {}))) @@ -327,8 +331,11 @@ def _re_format_prompt_images_pair(prompt: tuple) -> dict: def _has_multimodal_input(self, messages: list[dict]) -> bool: """Check if messages contain multimodal input such as images, videos, - or time series.""" - multimodal_types = ['image_url', 'image_data', 'image', 'video_url', 'video', 'time_series_url', 'time_series'] + audios, or time series.""" + multimodal_types = [ + 'image_url', 'image_data', 'image', 'video_url', 'video', 'audio_url', 'audio', 'time_series_url', + 'time_series' + ] return any( isinstance(message.get('content'), list) and any( item.get('type') in multimodal_types for item in message['content']) for message in messages) diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 7f9c891e14..fb0e6f069a 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -324,6 +324,10 @@ def _get_and_verify_max_len( for key in llm_keys: hf_config = getattr(hf_config, key, hf_config) + # for qwen3-omni thinker + if hasattr(hf_config, 'thinker_config'): + hf_config = hf_config.thinker_config.text_config + logger = get_logger('lmdeploy') derived_max_model_len = float('inf') possible_keys = [ diff --git a/lmdeploy/vl/media/audio.py b/lmdeploy/vl/media/audio.py new file mode 100644 index 0000000000..ba67f8753c --- /dev/null +++ b/lmdeploy/vl/media/audio.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/audio.py + +import base64 +from io import BytesIO +from pathlib import Path + +import numpy.typing as npt + +from .base import MediaIO + + +class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): + + def __init__(self, **kwargs) -> None: + super().__init__() + + # lazy import to avoid dependency issues for users who don't use audio features + try: + import librosa + self._librosa = librosa + except ImportError: + raise ImportError('Please install librosa via `pip install librosa`.') + + try: + import soundfile + self._soundfile = soundfile + except ImportError: + raise ImportError('Please install soundfile via `pip install soundfile`.') + + # Qwen3-Omni's feature extractor expects 16 kHz audio; allow explicit + # media-io overrides but resample to that rate by default. 
+ self.sampling_rate = kwargs.get('sampling_rate', kwargs.get('sample_rate', 16000)) + + # for potential custom arguments from --media-io-kwargs + self.kwargs = kwargs + + def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: + return self._librosa.load(BytesIO(data), sr=self.sampling_rate) + + def load_base64( + self, + media_type: str, + data: str, + ) -> tuple[npt.NDArray, float]: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: + return self._librosa.load(filepath, sr=self.sampling_rate) + + def encode_base64( + self, + media: tuple[npt.NDArray, int], + *, + audio_format: str = 'WAV', + ) -> str: + audio, sr = media + + with BytesIO() as buffer: + self._soundfile.write(buffer, audio, sr, format=audio_format) + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index 02d29ac4c1..41f61cfa9a 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -32,6 +32,10 @@ class VisionModel(ABC): # video-related attributes 'pixel_values_videos': Modality.VIDEO, 'video_grid_thw': Modality.VIDEO, + 'video_second_per_grid': Modality.VIDEO, + # audio-related attributes + 'input_features': Modality.AUDIO, + 'feature_attention_mask': Modality.AUDIO, # time series-related attributes 'ts_values': Modality.TIME_SERIES, 'ts_sr': Modality.TIME_SERIES, @@ -42,6 +46,7 @@ class VisionModel(ABC): FEATURE_NAMES = [ 'pixel_values', 'pixel_values_videos', + 'input_features', 'ts_values', ] @@ -104,6 +109,7 @@ def preprocess(self, mm_items = self.collect_multimodal_items(messages) raw_images, raw_videos, video_metadatas = [], [], [] + raw_audios = [] raw_time_series, sampling_rates = [], [] for modality, data, params in mm_items: if modality == Modality.IMAGE: @@ -111,6 +117,8 @@ def preprocess(self, elif modality == Modality.VIDEO: raw_videos.append(data) video_metadatas.append(params.get('video_metadata', None)) + elif modality == Modality.AUDIO: + raw_audios.append(data[0] if isinstance(data, tuple) else data) elif modality == Modality.TIME_SERIES: raw_time_series.append(data) sampling_rates.append(params.get('sampling_rate', None)) @@ -121,6 +129,7 @@ def preprocess(self, kwargs = {} images_kwargs = {} videos_kwargs = {} + audio_kwargs = {} mm_processor_kwargs = mm_processor_kwargs or {} if raw_images: kwargs['images'] = raw_images @@ -140,15 +149,24 @@ def preprocess(self, modality='video') if video_size is not None: videos_kwargs['size'] = video_size + if raw_audios: + kwargs['audio'] = raw_audios + audio_kwargs = dict(mm_processor_kwargs.get('audio') or {}) + feature_extractor = getattr(self.processor, 'feature_extractor', None) + sampling_rate = getattr(feature_extractor, 'sampling_rate', None) + if sampling_rate is not None: + audio_kwargs.setdefault('sampling_rate', sampling_rate) if images_kwargs: kwargs['images_kwargs'] = images_kwargs if videos_kwargs: kwargs['videos_kwargs'] = videos_kwargs + if audio_kwargs: + kwargs['audio_kwargs'] = audio_kwargs if raw_time_series: assert hasattr(self, 'time_series_processor'), \ 'time series processor is not defined for time series input' - assert not raw_images and not raw_videos, \ - 'time series is not compatible with image/video input' + assert not raw_images and not raw_videos and not raw_audios, \ + 'time series is not compatible with image/video/audio input' self.tokenizer = self.processor.tokenizer time_series_processor = self.time_series_processor kwargs['time_series'] = 
raw_time_series diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py index 7890d013c0..5b91fd911f 100644 --- a/lmdeploy/vl/model/builder.py +++ b/lmdeploy/vl/model/builder.py @@ -30,6 +30,7 @@ from .qwen2 import Qwen2VLModel # noqa F401 from .qwen3 import Qwen3VLModel # noqa F401 from .qwen3_5 import Qwen3_5Model # noqa F401 +from .qwen3_omni import Qwen3OmniModel # noqa F401 from .xcomposer2 import Xcomposer2VisionModel # noqa F401 from .yi import YiVisionModel # noqa F401 diff --git a/lmdeploy/vl/model/preprocess_utils.py b/lmdeploy/vl/model/preprocess_utils.py index bb591116d9..4869a67ec8 100644 --- a/lmdeploy/vl/model/preprocess_utils.py +++ b/lmdeploy/vl/model/preprocess_utils.py @@ -94,6 +94,30 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken offset=item['offset'][0], image_token_id=token_id, )) + elif modality == Modality.VIDEO: + second_per_grid = item.get('video_second_per_grid') + if second_per_grid is not None: + second_per_grid = second_per_grid[0] + if isinstance(second_per_grid, torch.Tensor) and second_per_grid.numel() == 1: + second_per_grid = second_per_grid.item() + expanded_mm_items.append( + dict( + modality=modality, + pixel_values_videos=item['feature'], + video_grid_thw=item['video_grid_thw'][0], + offset=item['offset'][0], + second_per_grid=second_per_grid, + video_token_id=token_id, + )) + elif modality == Modality.AUDIO: + expanded_mm_items.append( + dict( + modality=modality, + input_features=item['feature'], + feature_attention_mask=item.get('feature_attention_mask'), + offset=item['offset'][0], + audio_token_id=token_id, + )) elif modality == Modality.TIME_SERIES: expanded_mm_items.append( dict( @@ -143,10 +167,6 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken frames_per_video.append(T) total_frames += T - if num_items != total_frames: - expanded_mm_items.append(item) - continue - patches_per_video = [] for i in range(num_videos): grid = video_grid_thw[i] @@ -157,6 +177,25 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken cumulative = torch.cumsum(torch.tensor(patches_per_video, dtype=torch.long), dim=0) slice_indices = [0] + cumulative.tolist() + if num_items != total_frames: + for video_idx in range(num_videos): + start, end = slice_indices[video_idx], slice_indices[video_idx + 1] + second_per_grid = item.get('video_second_per_grid') + if second_per_grid is not None: + second_per_grid = second_per_grid[video_idx] + if isinstance(second_per_grid, torch.Tensor) and second_per_grid.numel() == 1: + second_per_grid = second_per_grid.item() + expanded_mm_items.append( + dict( + modality=modality, + pixel_values_videos=item['feature'][start:end], + video_grid_thw=video_grid_thw[video_idx], + offset=item['offset'][video_idx], + second_per_grid=second_per_grid, + video_token_id=token_id, + )) + continue + frame_start_indices = [0] for i in range(num_videos): frame_start_indices.append(frame_start_indices[-1] + frames_per_video[i]) @@ -164,19 +203,41 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken for video_idx in range(num_videos): start, end = slice_indices[video_idx], slice_indices[video_idx + 1] frame_start, frame_end = frame_start_indices[video_idx], frame_start_indices[video_idx + 1] + second_per_grid = item.get('video_second_per_grid') + if second_per_grid is not None: + second_per_grid = second_per_grid[video_idx] + if isinstance(second_per_grid, torch.Tensor) and second_per_grid.numel() == 1: + 
+                        second_per_grid = second_per_grid.item()
                 # TODO: zhouxinyu, not sure per-frame split is good or not
                 # TODO: zhouxinyu, grid_thw [1, h, w] is only for qwen3vl
                 t, h, w = video_grid_thw[video_idx].tolist()
                 for frame_idx in range(t):
                     video_feature = item['feature'][start:end]
+                    offset = item['offset'][frame_start:frame_end][frame_idx]
                     expanded_mm_items.append(
                         dict(
                             modality=modality,
                             pixel_values_videos=video_feature[frame_idx * h * w:(frame_idx + 1) * h * w],
                             video_grid_thw=torch.tensor([1, h, w]),
-                            offset=item['offset'][frame_start:frame_end][frame_idx],
+                            offset=offset,
+                            second_per_grid=second_per_grid,
                             video_token_id=token_id,
                         ))
+        elif modality == Modality.AUDIO:
+            for i in range(num_items):
+                feature_attention_mask = item.get('feature_attention_mask')
+                if feature_attention_mask is not None:
+                    feature_attention_mask = feature_attention_mask[i:i + 1]
+                expanded_mm_items.append(
+                    dict(
+                        modality=modality,
+                        input_features=item['feature'][i:i + 1],
+                        feature_attention_mask=feature_attention_mask,
+                        offset=item['offset'][i],
+                        audio_token_id=token_id,
+                    ))
+
+    # HF processors return features grouped by modality; offsets restore prompt order for mixed inputs.
     expanded_mm_items.sort(key=lambda item: item['offset'][0])
     return expanded_mm_items
diff --git a/lmdeploy/vl/model/qwen3_omni.py b/lmdeploy/vl/model/qwen3_omni.py
new file mode 100644
index 0000000000..c654b375d0
--- /dev/null
+++ b/lmdeploy/vl/model/qwen3_omni.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from transformers import AutoProcessor
+
+from lmdeploy.vl.model.base import VISION_MODELS, MultimodalSpecialTokens, VisionModel
+
+
+def check_transformers():
+    try:
+        from transformers import Qwen3OmniMoeForConditionalGeneration  # noqa: F401
+    except ImportError:
+        raise ImportError('please install latest transformers by '
+                          'pip install git+https://github.com/huggingface/transformers.git')
+
+
+@VISION_MODELS.register_module()
+class Qwen3OmniModel(VisionModel):
+    """Qwen3Omni model."""
+
+    _arch = ['Qwen3OmniMoeForConditionalGeneration']
+
+    def build_preprocessor(self, trust_remote_code: bool = False):
+        check_transformers()
+        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=trust_remote_code)
+        tokenizer = self.processor.tokenizer
+
+        # image tokens
+        self.image_token = self.processor.image_token
+        self.image_token_id = tokenizer.encode(self.image_token)[-1]
+
+        # video tokens
+        self.video_token = self.processor.video_token
+        self.video_token_id = tokenizer.encode(self.video_token)[-1]
+
+        # audio tokens
+        self.audio_token = self.processor.audio_token
+        self.audio_token_id = tokenizer.encode(self.audio_token)[-1]
+
+        # special tokens
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self.image_token,
+            video_token=self.video_token,
+            audio_token=self.audio_token,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            audio_token_id=self.audio_token_id,
+        )
diff --git a/tests/test_lmdeploy/test_content_merge.py b/tests/test_lmdeploy/test_content_merge.py
index e49de5c39b..d494a8fcbd 100644
--- a/tests/test_lmdeploy/test_content_merge.py
+++ b/tests/test_lmdeploy/test_content_merge.py
@@ -237,6 +237,11 @@ def __init__(self, image_io=None, **kwargs):
             self.image_io = image_io
             self.kwargs = kwargs
 
+    class FakeAudioMediaIO:
+
+        def __init__(self, **kwargs):
+            self.kwargs = kwargs
+
     def fake_load_from_url(data_src, media_io):
         load_calls.append((data_src, type(media_io).__name__))
         if isinstance(media_io, FakeVideoMediaIO):
@@ -244,6 +249,7 @@ def fake_load_from_url(data_src, media_io):
         return f'loaded:{data_src}'
 
     monkeypatch.setattr(multimodal_module, 'VideoMediaIO', FakeVideoMediaIO)
+    monkeypatch.setattr(multimodal_module, 'AudioMediaIO', FakeAudioMediaIO)
     monkeypatch.setattr(multimodal_module, 'load_from_url', fake_load_from_url)
 
     messages = [{
@@ -283,6 +289,11 @@ def fake_load_from_url(data_src, media_io):
             'video': 'file:///tmp/a.mp4',
             'fps': 1
         },
+        {
+            'type': 'audio',
+            'audio': 'file:///tmp/a.wav',
+            'sample_rate': 16000
+        },
         {
             'type': 'time_series',
             'time_series': {
@@ -311,11 +322,13 @@ def fake_load_from_url(data_src, media_io):
             'duration': 2
         }
     }
-    assert content[6] == {'type': Modality.TIME_SERIES, 'data': 'loaded:file:///tmp/a.npy', 'sr': 16000}
+    assert content[6] == {'type': Modality.AUDIO, 'data': 'loaded:file:///tmp/a.wav', 'sample_rate': 16000}
+    assert content[7] == {'type': Modality.TIME_SERIES, 'data': 'loaded:file:///tmp/a.npy', 'sr': 16000}
     assert load_calls == [
         ('http://example.com/a.png', 'ImageMediaIO'),
         ('/tmp/b.png', 'ImageMediaIO'),
         ('file:///tmp/a.mp4', 'FakeVideoMediaIO'),
+        ('file:///tmp/a.wav', 'FakeAudioMediaIO'),
         ('file:///tmp/a.npy', 'TimeSeriesMediaIO'),
     ]
 
@@ -332,9 +345,9 @@ def test_async_parse_multimodal_item_rejects_missing_payload(item):
 def test_async_parse_multimodal_item_rejects_unknown_type():
     """Test unknown multimodal item types still fail explicitly."""
-    messages = [{'role': 'user', 'content': [{'type': 'audio', 'audio': 'file:///tmp/a.wav'}]}]
+    messages = [{'role': 'user', 'content': [{'type': 'unknown_media', 'unknown_media': 'file:///tmp/a.bin'}]}]
 
-    with pytest.raises(NotImplementedError, match='unknown type: audio'):
+    with pytest.raises(NotImplementedError, match='unknown type: unknown_media'):
         asyncio.run(MultimodalProcessor.async_parse_multimodal_item(messages))
 
 
@@ -342,7 +355,10 @@ def test_has_multimodal_input_detects_all_supported_types():
     """Test multimodal detection includes every supported item type."""
     processor = MultimodalProcessor(tokenizer=None, chat_template=None)
 
-    for item_type in ['image_url', 'image_data', 'image', 'video_url', 'video', 'time_series_url', 'time_series']:
+    for item_type in [
+            'image_url', 'image_data', 'image', 'video_url', 'video', 'audio_url', 'audio', 'time_series_url',
+            'time_series'
+    ]:
         assert processor._has_multimodal_input([{'role': 'user', 'content': [{'type': item_type}]}])
 
     assert not processor._has_multimodal_input([{'role': 'user', 'content': [{'type': 'text', 'text': 'hello'}]}])
diff --git a/tests/test_lmdeploy/test_vl/test_qwen3_omni_processor.py b/tests/test_lmdeploy/test_vl/test_qwen3_omni_processor.py
new file mode 100644
index 0000000000..03453c5820
--- /dev/null
+++ b/tests/test_lmdeploy/test_vl/test_qwen3_omni_processor.py
@@ -0,0 +1,282 @@
+from types import SimpleNamespace
+
+import torch
+
+from lmdeploy.pytorch.models.qwen3_omni_moe_thinker import Qwen3OmniInputProcessor
+from lmdeploy.pytorch.models.utils.model import DeployModelMixinV1
+from lmdeploy.pytorch.multimodal.data_type import MultiModalData
+from lmdeploy.vl.constants import Modality
+from lmdeploy.vl.model.base import MultimodalSpecialTokens
+from lmdeploy.vl.model.preprocess_utils import get_expanded_mm_items
+from lmdeploy.vl.model.qwen3_omni import Qwen3OmniModel
+
+
+class FakeQwen3OmniProcessor:
+
+    image_token = ''
+    audio_token = '