From 4c6bc990acc2c1dd55525cd5263ab5466f4bbca2 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 19 Mar 2026 15:13:16 +0800 Subject: [PATCH 01/17] support qwen3-omni --- lmdeploy/archs.py | 6 +- lmdeploy/model.py | 13 +- lmdeploy/pytorch/configurations/qwen3_omni.py | 18 + lmdeploy/pytorch/models/module_map.py | 7 + .../pytorch/models/qwen3_omni_moe_thinker.py | 1083 +++++++++++++++++ lmdeploy/pytorch/models/qwen3_vl.py | 4 +- lmdeploy/pytorch/multimodal/data_type.py | 1 + lmdeploy/serve/processors/multimodal.py | 7 +- lmdeploy/utils.py | 4 + lmdeploy/vl/media/audio.py | 60 + lmdeploy/vl/model/base.py | 44 +- lmdeploy/vl/model/builder.py | 1 + lmdeploy/vl/model/qwen3.py | 4 +- lmdeploy/vl/model/qwen3_omni.py | 220 ++++ 14 files changed, 1437 insertions(+), 35 deletions(-) create mode 100644 lmdeploy/pytorch/configurations/qwen3_omni.py create mode 100644 lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py create mode 100644 lmdeploy/vl/media/audio.py create mode 100644 lmdeploy/vl/model/qwen3_omni.py diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index a4fe0d2333..9db3a73e07 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -110,9 +110,9 @@ def check_vl_llm(config: dict) -> bool: 'InternVLChatModel', 'MiniCPMV', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'Qwen3_5ForConditionalGeneration', - 'Qwen3_5MoeForConditionalGeneration', 'MllamaForConditionalGeneration', 'MolmoForCausalLM', - 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration', - 'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration', + 'Qwen3_5MoeForConditionalGeneration', 'Qwen3OmniMoeForConditionalGeneration', 'MllamaForConditionalGeneration', + 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', + 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration', 'InternS1_1_ForConditionalGeneration', 'Glm4vForConditionalGeneration' ]) if arch == 'QWenLMHeadModel' and 'visual' in config: diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 981e4b80b1..86506ee619 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -685,8 +685,19 @@ class HFChatTemplate(BaseChatTemplate): def __init__(self, model_path: str = '', **kwargs): self.model_path = model_path try: - from transformers import AutoTokenizer + from transformers import AutoProcessor, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Some tokenizers do not have chat_template, in this case try to get chat_template from processor + # If this still does not work, fallback to BaseChatTemplate. + if getattr(self.tokenizer, 'chat_template', None) is None: + try: + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + self.tokenizer.chat_template = getattr(processor, 'chat_template', None) + except Exception as e: + logger.warning(f'Failed to load processor from {model_path} for chat template. ' + f'Fallback to tokenizer only. Error: {e}') + # Verify if the model can perform apply_chat_template with different roles. 
self.user_start, self.user_end, _, _ = self._user_instruction() self.assistant_start, self.assistant_end, _, _ = self._assistant_instruction() diff --git a/lmdeploy/pytorch/configurations/qwen3_omni.py b/lmdeploy/pytorch/configurations/qwen3_omni.py new file mode 100644 index 0000000000..6001f374fc --- /dev/null +++ b/lmdeploy/pytorch/configurations/qwen3_omni.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder + + +class Qwen3OmniModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.model_type == 'qwen3_omni_moe' + + @classmethod + def build(cls, hf_config, model_path: str = None, **kwargs): + """build.""" + cfg = DefaultModelConfigBuilder.build(hf_config.thinker_config.text_config, model_path, **kwargs) + cfg.hf_config = hf_config + return cfg diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 514b4cf842..9b10e4500d 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -186,6 +186,13 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_moe.Qwen3_5MoeForConditionalGeneration', }) +# qwen3 omni moe thinker +# only support thinker module, so map to Qwen3OmniMoeThinkerForConditionalGeneration +MODULE_MAP.update({ + 'Qwen3OmniMoeForConditionalGeneration': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_omni_moe_thinker.Qwen3OmniMoeThinkerForConditionalGeneration', +}) + # starcoder2 MODULE_MAP.update({ 'Starcoder2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM', diff --git a/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py new file mode 100644 index 0000000000..09eb0782bc --- /dev/null +++ b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py @@ -0,0 +1,1083 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
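+# Qwen3-Omni MoE "thinker" support for the pytorch engine. This module implements
+# the audio encoder and the vision encoder (with deepstack feature taps) and reuses
+# the Qwen3-VL MoE text decoder; only the thinker branch is served, so talker and
+# code2wav weights are skipped during loading.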
+ +import math +from functools import lru_cache +from typing import Any, Dict, Iterable, List, Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalData +from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, LayerNorm +from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight +from lmdeploy.vl.constants import Modality + +from .qwen3_vl import Qwen3VLVisionBlock, Qwen3VLVisionPatchEmbed, Qwen3VLVisionRotaryEmbedding +from .qwen3_vl_moe import Qwen3VLMoeTextModel +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.model import DeployModelMixin, vlm_model + + +def _get_feat_extract_output_lengths(input_lengths): + """Computes the output length of the convolutional layers and the output + length of the audio encoder.""" + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + +class Qwen3OmniMoeAudioAttention(nn.Module): + """Vision attention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + dim = config.d_model + num_heads = config.encoder_attention_heads + head_dim = dim // num_heads + self.head_dim = head_dim + + # packed qkv + self.qkv_proj = build_qkv_proj( + dim, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attention = FlashAttention( + num_heads, + head_dim, + causal=False, + ) + + # o_proj + self.out_proj = build_rowwise_linear(dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + q, k, v = self.qkv_proj.split_qkv(qkv_states) + + attn_output = self.attention( + q, + k, + v, + q_start_loc=cu_seqlens[:-1], + q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1], + ) + + attn_output = attn_output.reshape(seq_length, -1) + + # o proj + attn_output = self.out_proj(attn_output) + return attn_output + + +class Qwen3OmniMoeAudioEncoderLayer(nn.Module): + """Qwen3OmniMoeAudioEncoderLayer.""" + + def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None: + super().__init__() + self.embed_dim = config.d_model + self.self_attn = Qwen3OmniMoeAudioAttention(config, dtype=dtype, device=device) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=1e-5, dtype=dtype, device=device) + + self.activation_fn = ACT2FN[config.activation_function] + self.fc1 = build_colwise_linear( + self.embed_dim, + config.encoder_ffn_dim, 
+ bias=True, + dtype=dtype, + device=device, + ) + self.fc2 = build_rowwise_linear( + config.encoder_ffn_dim, + self.embed_dim, + bias=True, + dtype=dtype, + device=device, + ) + self.final_layer_norm = LayerNorm(self.embed_dim, eps=1e-5, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, ) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError('SinusoidsPositionEmbedding needs even channels input') + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.register_buffer( + 'positional_embedding', + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class Qwen3OmniMoeAudioEncoder(nn.Module): + """Qwen3OmniMoeAudioEncoder.""" + + def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None: + super().__init__() + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) + self.layers = nn.ModuleList( + [Qwen3OmniMoeAudioEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.encoder_layers)]) + self.ln_post = LayerNorm(config.d_model, eps=1e-5, dtype=dtype, device=device) + self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1, dtype=dtype, device=device) + self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + dtype=dtype, + device=device) + self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + dtype=dtype, + device=device) + conv_out_dim = config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2) + self.conv_out = nn.Linear( + conv_out_dim, + config.d_model, + bias=False, + dtype=dtype, + device=device, + ) + self.proj1 = nn.Linear(config.d_model, config.d_model, dtype=dtype, device=device) + self.act = ACT2FN[config.activation_function] + self.proj2 = nn.Linear(config.d_model, config.output_dim, dtype=dtype, device=device) + self.n_window_infer = config.n_window_infer + self.conv_chunksize = config.conv_chunksize + + def forward( + self, + input_features: torch.Tensor, + feature_lens: 
torch.Tensor, + aftercnn_lens=None, + ): + r"""feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + + mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length after cnn + """ + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = nn.utils.rnn.pad_sequence( + [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], + batch_first=True, + ) + padded_feature = padded_feature.unsqueeze(1) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) + + positional_embedding = ( + self.positional_embedding.positional_embedding[:padded_embed.shape[1], :].unsqueeze(0).to( + padded_embed.dtype)) + padded_embed = padded_embed + positional_embedding + hidden_states = padded_embed[padded_mask_after_cnn] + cu_chunk_lens = [0] + window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.ln_post(hidden_states) + hidden_states = self.proj1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.proj2(hidden_states) + return hidden_states + + +class Qwen3OmniMoeVisionPatchMerger(nn.Module): + """Vision patch merger. + + Different namings with qwen3vl, but actual calculations are the same. 
+ """ + + def __init__(self, + config: PretrainedConfig, + use_postshuffle_norm=False, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, + eps=1e-6, + dtype=dtype, + device=device) + self.mlp = nn.ModuleList([ + build_colwise_linear( + self.hidden_size, + self.hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ), + nn.GELU(), + build_rowwise_linear( + self.hidden_size, + config.out_hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + for layer in self.mlp: + x = layer(x) + return x + + +@vlm_model +class Qwen3OmniMoeVisionEncoder(nn.Module): + """Vision transformer.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed(config=config, dtype=dtype, device=device) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2, device=device) + + self.blocks = nn.ModuleList( + [Qwen3VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)]) + self.merger = Qwen3OmniMoeVisionPatchMerger(config=config, + use_postshuffle_norm=False, + dtype=dtype, + device=device) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.merger_list = nn.ModuleList([ + Qwen3OmniMoeVisionPatchMerger(config=config, use_postshuffle_norm=True, dtype=dtype, device=device) + for _ in range(len(config.deepstack_visual_indexes)) + ]) + + @staticmethod + @lru_cache(maxsize=1024) + def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + h_div = h // spatial_merge_size + w_div = w // spatial_merge_size + + hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) + hpos_ids = hpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + hpos_ids = hpos_ids.transpose(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) + wpos_ids = wpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + wpos_ids = wpos_ids.transpose(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + + return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + """Rotary position embedding.""" + pos_ids = [] + + for t, h, w in grid_thw: + base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size) + pos_ids.append(base if t == 1 else base.repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb + + # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474 + def 
fast_pos_embed_interpolate(self, grid_thw: List[List[int]]) -> torch.Tensor: + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + device = self.pos_embed.weight.device + + outputs = [] + for t, h, w in grid_thw: + h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij') + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij') + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - w01 + + h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * num_grid_per_side + + indices = (h_grid_idx + w_grid).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device) + + embeds = self.pos_embed(indices) + embeds *= weights + combined = embeds.sum(dim=0) + + combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim) + combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, + pos_embeds: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states + pos_embeds + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + if layer_num in self.deepstack_visual_indexes: + deepstack_merge_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.merger_list[deepstack_merge_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +class Qwen3OmniMoeThinkerForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + thinker_config = config.thinker_config + + # build preprocessor + self.input_processor = Qwen3OmniInputProcessor(self.config) + + # 
build audio encoder + self.audio_tower = Qwen3OmniMoeAudioEncoder( + thinker_config.audio_config, + dtype=dtype, + device=device, + ) + + # build vision encoder + self.visual = Qwen3OmniMoeVisionEncoder( + thinker_config.vision_config, + dtype=dtype, + device=device, + ) + + # build text model + self.language_model = Qwen3VLMoeTextModel(thinker_config.text_config, dtype=dtype, device=device) + + # build lm_head + self.lm_head = build_rowwise_linear(thinker_config.text_config.hidden_size, + thinker_config.text_config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + mrope_position_ids: torch.Tensor = None, + pixel_values: torch.Tensor = None, + vis_cu_seqlens: torch.Tensor = None, + vis_pos_emb: torch.Tensor = None, + image_mask: torch.Tensor = None, + pos_embeds: torch.Tensor = None, + grid_thw: torch.Tensor = None, + audio_values: torch.Tensor = None, + audio_mask: torch.Tensor = None, + audio_feature_lengths: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + + visual_pos_masks = None + deepstack_visual_embeds = None + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype)) + + # get image embeds and deepstack visual embeds + image_embeds, deepstack_visual_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb, + pos_embeds=pos_embeds) + + # split image embeds per sample + split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) + + # mask and scatter to create final input embeddings + expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) + + visual_pos_masks = expanded_image_mask + + if audio_values is not None: + dtype = inputs_embeds.dtype + audio_values = audio_values.to(dtype) + audio_embeds = self.audio_tower( + input_features=audio_values, + feature_lens=audio_feature_lengths, + ) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask.unsqueeze(-1), audio_embeds) + + hidden_states = self.language_model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + # args for deepstack + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.language_model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor | None = None, + context: StepContext = None, + ): + """Prepare input.""" + + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + pixel_values = None + 
vis_cu_seqlens = None + vis_pos_emb = None + image_mask = None + grid_thw = None + pos_embeds = None + audio_values = None + audio_mask = None + audio_feature_lengths = None + if context.input_multimodals is not None: + mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals] + # flatten batch + mm_inputs = [item for sublist in mm_inputs for item in sublist] + + if len(mm_inputs) > 0: + modality = mm_inputs[0].modality + + image_token_id = mm_inputs[0].meta.get('image_token_id') + video_token_id = mm_inputs[0].meta.get('video_token_id') + audio_token_id = mm_inputs[0].meta.get('audio_token_id') + + if modality == Modality.AUDIO: + audio_values = torch.cat([inp.data for inp in mm_inputs]) + # FIXME: zhouxinyu, batch ? + audio_values = audio_values.squeeze(0) + audio_mask = (input_ids == audio_token_id) + # FIXME: zhouxinyu, list ? + audio_feature_lengths = mm_inputs[0].meta['audio_feature_lengths'] + elif modality in [Modality.IMAGE, Modality.VIDEO]: + pixel_values = torch.cat([inp.data for inp in mm_inputs]) + + mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id + image_mask = (input_ids == mm_token_id) + + grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu() + vis_pos_emb = self.visual.rot_pos_emb(grid_thw) + pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw) + vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).to(pixel_values.device) + vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32) + vis_pos_emb = vis_pos_emb.repeat(1, 2) + vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin()) + + mrope_position_ids = getattr(context, 'mrope_position_ids', None) + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + pixel_values=pixel_values, + vis_cu_seqlens=vis_cu_seqlens, + vis_pos_emb=vis_pos_emb, + image_mask=image_mask, + grid_thw=grid_thw, + pos_embeds=pos_embeds, + audio_values=audio_values, + audio_mask=audio_mask, + audio_feature_lengths=audio_feature_lengths, + ) + + def rename_weight(self, name: str) -> str: + """Rename weight.""" + if name.startswith('thinker.model.'): + return 'language_model.' + name[len('thinker.model.'):] + elif name.startswith('thinker.visual.'): + return 'visual.' + name[len('thinker.visual.'):] + elif name.startswith('thinker.audio_tower.'): + return 'audio_tower.' + name[len('thinker.audio_tower.'):] + # thinker_config.text_config tie_word_embeddings = False + elif name.startswith('thinker.lm_head.'): + return 'lm_head.' 
+ name[len('thinker.lm_head.'):] + return name + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """Load weight experts.""" + + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + # modify from vllm qwen3vlmoe fused expert loading + def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + fused_expert_params_mapping: List): + """Load weight of fused expert weights.""" + num_experts = self.config.text_config.num_experts + + for (param_name, weight_name) in fused_expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + + loaded_weight = loaded_weight.transpose(-1, -2) # no bias + if 'gate_up' in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + w1 = loaded_weight[0] + w3 = loaded_weight[1] + for expert_id in range(num_experts): + load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate') + load_weight(param, w3[expert_id], expert_id=expert_id, shard_id='up') + elif 'down' in name: + w2 = loaded_weight + for expert_id in range(num_experts): + load_weight(param, w2[expert_id], expert_id=expert_id, shard_id='down') + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + # expert mapping + num_experts = self.config.thinker_config.text_config.num_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + # (param_name, weight_name, expert_id, shard_id) + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + # fused expert mapping + fused_expert_params_mapping = [ + # (param_name, weight_name) + ('.experts.gate_up.weight', '.experts.gate_up_proj'), + ('.experts.down.weight', '.experts.down_proj'), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + # skip talker and code2wav weights + if ('talker.' in name or 'code2wav.' 
in name): + continue + + name = name.replace('.block_sparse_moe.', '.mlp.') + if '.experts' in name: + is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name) + if is_fused_expert: + self._load_weight_fused_experts(name, + loaded_weight, + params_dict, + fused_expert_params_mapping=fused_expert_params_mapping) + else: + self._load_weight_experts(name, + loaded_weight, + params_dict, + expert_params_mapping=expert_params_mapping) + else: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.qkv.' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Make cudagraph buffers from forward inputs.""" + max_tokens = graph_meta.max_tokens + + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + mrope_position_ids = kwargs.get('mrope_position_ids', None) + if mrope_position_ids is not None: + input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Fill cudagraph buffers from forward inputs.""" + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + + input_ids = kwargs.get('input_ids') + num_tokens = input_ids.size(-1) + new_batch_size = graph_meta.max_batchs + + is_decoding = graph_meta.is_decoding + input_buffers = graph_meta.input_buffers + mrope_position_ids = kwargs.get('mrope_position_ids', None) + if mrope_position_ids is not None: + input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids + if is_decoding: + new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] + else: + new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] + + return new_inputs + + def _get_model_metas(self, context: StepContext): + """Get model metas.""" + model_metas = context.model_metas + if model_metas is None: + batch_size = context.q_seqlens.numel() + return [dict(mrope_delta=0)] * batch_size + return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] + + def _update_model_meta_decoding(self, context: StepContext): + """Update model meta for decoding.""" + model_metas = self._get_model_metas(context) + position_ids = context.position_ids + + mrope_deltas = [meta['mrope_delta'] for meta in model_metas] + mrope_deltas = position_ids.new_tensor(mrope_deltas) + mrope_position_ids = position_ids + mrope_deltas[None] + mrope_position_ids = mrope_position_ids.expand(3, -1) + + context.mrope_position_ids = mrope_position_ids + return model_metas + + def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): + """Get mrope ids.""" + t, h, w = grid_thw + h //= 2 + w //= 2 + stride = torch.tensor([h * w, w, 1], device=device)[:, None] + size = torch.tensor([t, h, w], device=device)[:, None] + pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) + pos_ids = pos_ids // stride % size + return pos_ids + + def _update_model_meta_prefilling(self, context: StepContext): + """Update model meta for 
prefilling.""" + model_metas = self._get_model_metas(context) + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_multimodals = [None] * len(model_metas) + position_ids = context.position_ids + batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) + mrope_position_ids = [] + new_model_metas = [] + for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): + mm_data_list = [] + if input_mm is not None: + mm_data_list.extend(input_mm.get('mm_data', [])) + + if model_meta is None or 'mrope_delta' not in model_meta: + mrope_delta = 0 + else: + mrope_delta = model_meta['mrope_delta'] + + pos_start = pos_ids[0].item() + mrope_pos_ids = pos_ids + mrope_delta + mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() + + for mm_data in mm_data_list: + if mm_data.modality == Modality.IMAGE: + grid_thw = mm_data.meta['grid_thw'][0].tolist() + _, h, w = grid_thw + h //= 2 + w //= 2 + num_pad = mm_data.end - mm_data.start - max(h, w) + mrope_delta -= num_pad + fill_start = mm_data.start - pos_start + fill_end = mm_data.end - pos_start + img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) + img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] + mrope_pos_ids[:, fill_end:] -= num_pad + mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids + elif mm_data.modality == Modality.VIDEO: + second_per_grid = mm_data.meta.get('second_per_grid', 2.0) + position_id_per_seconds = self.config.thinker_config.position_id_per_seconds + + grid_thw = mm_data.meta['grid_thw'][0].tolist() + t, h, w = grid_thw + llm_h = h // 2 # spatial_merge_size = 2 + llm_w = w // 2 + + device = pos_ids.device + # Temporal indices as real timestamps (float, e.g. 0, 1.083, 2.167 for fps=24) + t_index = torch.arange(t, device=device).float() * (second_per_grid * position_id_per_seconds) + h_index = torch.arange(llm_h, device=device).float() + w_index = torch.arange(llm_w, device=device).float() + + # Build [3, T*llm_h*llm_w] pos ids + t_expanded = t_index.view(-1, 1).expand(-1, llm_h * llm_w).flatten() + h_expanded = h_index.view(1, -1, 1).expand(t, -1, llm_w).flatten() + w_expanded = w_index.view(1, 1, -1).expand(t, llm_h, -1).flatten() + video_pos_ids = torch.stack([t_expanded, h_expanded, w_expanded]) # [3, T*llm_h*llm_w] + + max_video_pos = max( + float((t - 1) * second_per_grid * position_id_per_seconds) if t > 1 else 0.0, + float(llm_h - 1), + float(llm_w - 1), + ) + video_num_tokens = t * llm_h * llm_w + num_pad = video_num_tokens - max_video_pos - 1 + mrope_delta -= num_pad + + fill_start = mm_data.start - pos_start + fill_end = mm_data.end - pos_start + + # Convert to float to hold non-integer temporal positions + mrope_pos_ids = mrope_pos_ids.float() + offset = mrope_pos_ids[0, fill_start].item() + mrope_pos_ids[:, fill_start:fill_end] = video_pos_ids + offset + mrope_pos_ids[:, fill_end:] -= num_pad + + mrope_position_ids.append(mrope_pos_ids) + new_model_metas.append(dict(mrope_delta=mrope_delta)) + + mrope_position_ids = torch.cat(mrope_position_ids, dim=1) + context.mrope_position_ids = mrope_position_ids + + return new_model_metas + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor | None = None, + context: StepContext = None): + """Update model meta.""" + if context.is_decoding: + return self._update_model_meta_decoding(context) + else: + return self._update_model_meta_prefilling(context) + + def get_input_processor(self) -> BaseModelInputProcessor: + """Get input 
processor.""" + return self.input_processor + + +class Qwen3OmniInputProcessor(BaseModelInputProcessor): + """Qwen3 Omni input processor.""" + + def __init__(self, config: PretrainedConfig) -> None: + self.config = config + + def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: + """Make image MultiModalData.""" + pixel_values = input_mm['pixel_values'] + image_grid_thw = input_mm['image_grid_thw'] + offset = input_mm['offset'] + start = offset + image_token_id = input_mm['image_token_id'] + num_pad = input_mm['mm_token_num'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalData(modality=Modality.IMAGE, + data=pixel_values, + start=start, + end=start + num_pad, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + return mm_data + + def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: + """Make video MultiModalData.""" + pixel_values_videos = input_mm['pixel_values_videos'] + video_grid_thw = input_mm['video_grid_thw'] + offset = input_mm['offset'] + start = offset + video_token_id = input_mm['video_token_id'] + num_pad = input_mm['mm_token_num'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalData(modality=Modality.VIDEO, + data=pixel_values_videos, + start=start, + end=start + num_pad, + meta=dict( + grid_thw=video_grid_thw, + video_token_id=video_token_id, + second_per_grid=input_mm.get('second_per_grid'), + )) + return mm_data + + def _make_audio_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: + """Make audio MultiModalData.""" + input_features = input_mm['input_features'] + offset = input_mm['offset'] + start = offset + audio_token_id = input_mm['audio_token_id'] + num_pad = input_mm['mm_token_num'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalData(modality=Modality.AUDIO, + data=input_features, + start=start, + end=start + num_pad, + meta=dict( + audio_token_id=audio_token_id, + audio_feature_lengths=input_mm.get('audio_feature_lengths'), + )) + return mm_data + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """Prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_mm_data = [] + for input_mm in input_multimodals: + modality = input_mm.get('modality') + + if modality == Modality.IMAGE: + mm_data = self._make_image_mm_data(input_mm) + elif modality == Modality.VIDEO: + mm_data = self._make_video_mm_data(input_mm) + elif modality == Modality.AUDIO: + mm_data = self._make_audio_mm_data(input_mm) + + input_mm_data.append(mm_data) + + result = PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(mm_data=input_mm_data)) + + return result diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 0d0434a58f..6d98613b75 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -899,7 +899,7 @@ def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: offset = input_mm['offset'] start = offset image_token_id = input_mm['image_token_id'] - num_pad = input_mm['image_tokens'] + num_pad = input_mm['mm_token_num'] if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() @@ -917,7 +917,7 @@ def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: offset = input_mm['offset'] start = offset 
video_token_id = input_mm['video_token_id'] - num_pad = input_mm['video_tokens'] + num_pad = input_mm['mm_token_num'] if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py index 4ff71fdcbc..e5c555a717 100644 --- a/lmdeploy/pytorch/multimodal/data_type.py +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -15,6 +15,7 @@ class MultiModalData: start: int end: int = None meta: Dict[str, Any] = None + modality: Modality = Modality.IMAGE modality: Modality = Modality.IMAGE diff --git a/lmdeploy/serve/processors/multimodal.py b/lmdeploy/serve/processors/multimodal.py index 8632747936..eb2048f019 100644 --- a/lmdeploy/serve/processors/multimodal.py +++ b/lmdeploy/serve/processors/multimodal.py @@ -8,6 +8,7 @@ from lmdeploy.tokenizer import Tokenizer from lmdeploy.utils import get_logger from lmdeploy.vl.constants import Modality +from lmdeploy.vl.media.audio import AudioMediaIO from lmdeploy.vl.media.connection import load_from_url from lmdeploy.vl.media.image import ImageMediaIO from lmdeploy.vl.media.time_series import TimeSeriesMediaIO @@ -124,6 +125,10 @@ def _parse_multimodal_item(i: int, in_messages: List[Dict], out_messages: List[D vid_io = VideoMediaIO(image_io=ImageMediaIO(), **media_io_kwargs.get('video', {})) data, metadata = load_from_url(data_src, vid_io) item_params['video_metadata'] = metadata + elif item_type == 'audio_url': + modality = Modality.AUDIO + audio_io = AudioMediaIO(**media_io_kwargs.get('audio', {})) + data = load_from_url(data_src, audio_io) elif item_type == 'time_series_url': modality = Modality.TIME_SERIES ts_io = TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {})) @@ -304,7 +309,7 @@ def _re_format_prompt_images_pair(prompt: Tuple) -> Dict: def _has_multimodal_input(self, messages: List[Dict]) -> bool: """Check if messages contain multimodal input (images).""" - multimodal_types = ['image_url', 'image_data', 'video_url', 'time_series_url'] + multimodal_types = ['image_url', 'image_data', 'video_url', 'audio_url', 'time_series_url'] return any( isinstance(message.get('content'), list) and any( item.get('type') in multimodal_types for item in message['content']) for message in messages) diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 5e06ab5ae9..de93483206 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -305,6 +305,10 @@ def _get_and_verify_max_len( for key in llm_keys: hf_config = getattr(hf_config, key, hf_config) + # for qwen3-omni thinker + if hasattr(hf_config, 'thinker_config'): + hf_config = hf_config.thinker_config.text_config + logger = get_logger('lmdeploy') derived_max_model_len = float('inf') possible_keys = [ diff --git a/lmdeploy/vl/media/audio.py b/lmdeploy/vl/media/audio.py new file mode 100644 index 0000000000..19d39bfff9 --- /dev/null +++ b/lmdeploy/vl/media/audio.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/audio.py + +import base64 +from io import BytesIO +from pathlib import Path + +import numpy.typing as npt + +from .base import MediaIO + + +class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): + + def __init__(self, **kwargs) -> None: + super().__init__() + + # lazy import to avoid dependency issues for users who don't use audio features + try: + import librosa + self._librosa = librosa + except ImportError: + raise ImportError('Please install librosa via `pip install librosa`.') + + try: + import soundfile + self._soundfile = soundfile + except ImportError: + raise ImportError('Please install soundfile via `pip install soundfile`.') + + # for potential custom arguments from --media-io-kwargs + self.kwargs = kwargs + + def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: + # sr = None, preserves the original sampling rate of the audio file + return self._librosa.load(BytesIO(data), sr=None) + + def load_base64( + self, + media_type: str, + data: str, + ) -> tuple[npt.NDArray, float]: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: + return self._librosa.load(filepath, sr=None) + + def encode_base64( + self, + media: tuple[npt.NDArray, int], + *, + audio_format: str = 'WAV', + ) -> str: + audio, sr = media + + with BytesIO() as buffer: + self._soundfile.write(buffer, audio, sr, format=audio_format) + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index f282e44e67..d1bf97b608 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -252,40 +252,32 @@ def to_pytorch_with_input_ids(self, messages): return dict(prompt=None, input_ids=input_ids, multimodal=preps) - def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start): + def to_pytorch_aux(self, messages, prompt, mm_placeholder, tokenizer, sequence_start): """Auxiliary function to pack the preprocessing results in a format - compatible with what is required by pytorch engine. 
- - Args: - messages(List[Dict]): the output of `preprocess` - prompt(str): the prompt after applying chat template - IMAGE_TOKEN(str): a placeholder where image tokens will be - inserted - tokenzer: the tokenizer model - sequence_start: starting flag of a sequence - """ - # collect all preprocessing result from messages - preps = [x['content'] for x in messages if x['role'] == 'preprocess'] - assert len(preps) == 1 - preps = preps[0] + compatible with what is required by pytorch engine.""" + # collect all multi-modal preprocessing result from messages, keyed by 'preprocess' + mm_items = [x['content'] for x in messages if x['role'] == 'preprocess'] + assert len(mm_items) == 1 + mm_items = mm_items[0] # split prompt into segments and validate data - segs = prompt.split(IMAGE_TOKEN) - assert len(segs) == len(preps) + 1, (f'the number of {IMAGE_TOKEN} is not equal ' - f'to input images, {len(segs) - 1} vs {len(preps)}') + prompt_segments = prompt.split(mm_placeholder) + assert len(prompt_segments) == len(mm_items) + 1, ( + f'the number of {mm_placeholder} is not equal ' + f'to input multi modal items, {len(mm_items) - 1} vs {len(prompt_segments)}') - # calculate the image token offset for each image + # calculate the token offset for each multi modal item input_ids = [] - for i, seg in enumerate(segs): - if i > 0 and i <= len(preps): - preps[i - 1].update(offset=len(input_ids)) - image_tokens = preps[i - 1]['image_tokens'] - assert self.image_token_id == preps[i - 1]['image_token_id'] - input_ids.extend([self.image_token_id] * image_tokens) + mm_placeholder_id = tokenizer.encode(mm_placeholder, add_special_tokens=False)[-1] + for i, seg in enumerate(prompt_segments): + if i > 0 and i <= len(mm_items): + mm_items[i - 1].update(offset=len(input_ids)) + mm_token_num = mm_items[i - 1]['mm_token_num'] + input_ids.extend([mm_placeholder_id] * mm_token_num) token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) input_ids.extend(token_ids) - return dict(prompt=prompt, input_ids=input_ids, multimodal=preps) + return dict(prompt=prompt, input_ids=input_ids, multimodal=mm_items) def to_turbomind_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start): """Auxiliary function to pack the forwarding results in a format diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py index 11262483e5..2a14e5e241 100644 --- a/lmdeploy/vl/model/builder.py +++ b/lmdeploy/vl/model/builder.py @@ -31,6 +31,7 @@ from .qwen2 import Qwen2VLModel # noqa F401 from .qwen3 import Qwen3VLModel # noqa F401 from .qwen3_5 import Qwen3_5Model # noqa F401 +from .qwen3_omni import Qwen3OmniModel # noqa F401 from .xcomposer2 import Xcomposer2VisionModel # noqa F401 from .yi import YiVisionModel # noqa F401 diff --git a/lmdeploy/vl/model/qwen3.py b/lmdeploy/vl/model/qwen3.py index 404f84d781..4531627dc2 100644 --- a/lmdeploy/vl/model/qwen3.py +++ b/lmdeploy/vl/model/qwen3.py @@ -93,7 +93,7 @@ def _preprocess_image(self, return_tensors='pt') merge_length = self.processor.image_processor.merge_size**2 image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length - result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id)) + result.update(dict(image_size=image.size, mm_token_num=image_tokens, image_token_id=self.image_token_id)) return result def _preprocess_video(self, @@ -206,7 +206,7 @@ def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequenc video_token_ids = tokenizer.encode(video_placeholder) 
input_ids.extend(video_token_ids) - preps[i - 1].update(video_tokens=len(video_token_ids)) + preps[i - 1].update(mm_token_num=len(video_token_ids)) token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) input_ids.extend(token_ids) diff --git a/lmdeploy/vl/model/qwen3_omni.py b/lmdeploy/vl/model/qwen3_omni.py new file mode 100644 index 0000000000..78bd49790f --- /dev/null +++ b/lmdeploy/vl/model/qwen3_omni.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Optional + +import torch +from transformers import AutoProcessor + +from lmdeploy.utils import get_logger +from lmdeploy.vl.constants import Modality +from lmdeploy.vl.model.base import VISION_MODELS, VisionModel + +logger = get_logger('lmdeploy') + + +def check_transformers(): + try: + from transformers import Qwen3OmniMoeForConditionalGeneration # noqa: F401 + except ImportError: + raise ImportError('please install latest transformers by ' + 'pip install git+https://github.com/huggingface/transformers.git') + + +@VISION_MODELS.register_module() +class Qwen3OmniModel(VisionModel): + """Qwen3Omni model.""" + + _arch = ['Qwen3OmniMoeForConditionalGeneration'] + + def build_preprocessor(self): + check_transformers() + self.processor = AutoProcessor.from_pretrained(self.model_path) + tokenizer = self.processor.tokenizer + + # image tokens + self.image_token = self.processor.image_token + self.image_token_id = tokenizer.encode(self.image_token)[-1] + + # video tokens + self.video_token = self.processor.video_token + self.video_token_id = tokenizer.encode(self.video_token)[-1] + + # audio tokens + self.audio_token = self.processor.audio_token + self.audio_token_id = tokenizer.encode(self.audio_token)[-1] + + def get_processor_args(self, mm_processor_kwargs: Optional[Dict[str, Any]] = None): + min_pixels = self.processor.image_processor.size['shortest_edge'] + max_pixels = self.processor.image_processor.size['longest_edge'] + + if mm_processor_kwargs is None: + return min_pixels, max_pixels + + input_min_pixels = mm_processor_kwargs.get('min_pixels', None) + input_max_pixels = mm_processor_kwargs.get('max_pixels', None) + + # boundary check for min_pixels and max_pixels + if input_min_pixels is None: + if input_max_pixels is not None: + # only max_pixels is given in the input + if input_max_pixels < min_pixels: + logger.warning( + f'input max_pixels {input_max_pixels} < default min_pixels {min_pixels}, fall back to default.') + return min_pixels, max_pixels + max_pixels = input_max_pixels + else: + if input_max_pixels is None: + # only min_pixels is given in the input + if input_min_pixels > max_pixels: + logger.warning( + f'input min_pixels {input_min_pixels} > default max_pixels {max_pixels}, fall back to default.') + return min_pixels, max_pixels + else: + if input_min_pixels > input_max_pixels: + logger.warning( + f'input min_pixels {input_min_pixels} > max_pixels {input_max_pixels}, fall back to default.') + return min_pixels, max_pixels + max_pixels = input_max_pixels + min_pixels = input_min_pixels + + return min_pixels, max_pixels + + def _preprocess_image(self, + data: List[Any], + params: Dict[str, Any], + mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: + + image = data.convert('RGB') + min_pixels, max_pixels = self.get_processor_args(mm_processor_kwargs) + + result = self.processor.image_processor(images=image, + size={ + 'shortest_edge': min_pixels, + 'longest_edge': max_pixels + }, + return_tensors='pt') + merge_length = 
self.processor.image_processor.merge_size**2 + image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length + result.update(dict(image_size=image.size, mm_token_num=image_tokens, image_token_id=self.image_token_id)) + return result + + def _preprocess_video(self, + data: List[Any], + params: Dict[str, Any], + mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: + + # TODO: zhouxinyu, apply transformers smart_resize using per-request kwargs + metadata = params['video_metadata'] + video_kwargs = dict(return_metadata=True, + do_resize=True, + do_sample_frames=False, + video_metadata=metadata, + return_tensors='pt') + + # TODO: update from mm_processor_kwargs when needed + video_kwargs.update(size={ + 'shortest_edge': 128 * 32 * 32, + 'longest_edge': 768 * 32 * 32, + }) + result = self.processor.video_processor(videos=data, **video_kwargs) + video_grid_thw = result['video_grid_thw'] + + merge_length = self.processor.video_processor.merge_size**2 + if metadata.get('fps') is None: + logger.warning_once('Qwen3VL: fps not found, defaulting to 24.') + metadata['fps'] = metadata['fps'] or 24 + + # TODO: update fps from video kwargs, refer to transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py + second_per_grid = self.processor.video_processor.temporal_patch_size / video_kwargs.get('fps', 1.0) + + frame_seqlen = video_grid_thw[0][1:].prod() // merge_length + video_tokens = video_grid_thw[0].prod() // merge_length # T*H*W / merge^2 + result.update(frame_seqlen=frame_seqlen, + mm_token_num=video_tokens, + second_per_grid=second_per_grid, + video_token_id=self.video_token_id) + return result + + def _get_feat_extract_output_lengths(self, input_lengths): + """Computes the output length of the convolutional layers and the + output length of the audio encoder.""" + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + def _preprocess_audio(self, + data: List[Any], + params: Dict[str, Any], + mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: + audio, original_sr = data + # NOTE: WhisperFeatureExtractor was trained using a fixed sampling rate of 16000 + # TODO: zhouxinyu, get truncation from mm_processor_kwargs when needed + sr = self.processor.feature_extractor.sampling_rate + audio_kwargs = { + 'sampling_rate': sr, + 'padding': True, + 'truncation': False, + 'return_attention_mask': True, + 'return_tensors': 'pt' + } + result = self.processor.feature_extractor(audio, **audio_kwargs) + feature_attention_mask = result.get('attention_mask') + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + audio_output_length = self._get_feat_extract_output_lengths(audio_feature_lengths) + audio_tokens = audio_output_length + + result.update( + dict(mm_token_num=audio_tokens, + audio_feature_lengths=audio_feature_lengths, + audio_token_id=self.audio_token_id)) + return result + + def preprocess(self, messages: List[Dict], mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: + """Refer to `super().preprocess()` for spec.""" + outputs = [] + self.contains_video_input = False + self.contains_audio_input = False + + mm_items = self.collect_multimodal_items(messages) + for modality, data, params in mm_items: + result = {} + if modality == Modality.IMAGE: + result = self._preprocess_image(data, params, mm_processor_kwargs) + elif modality == Modality.VIDEO: + self.contains_video_input = True + 
result = self._preprocess_video(data, params, mm_processor_kwargs) + elif modality == Modality.AUDIO: + self.contains_audio_input = True + result = self._preprocess_audio(data, params, mm_processor_kwargs) + + result.update(modality=modality) + outputs.append(result) + + messages.append(dict(role='preprocess', content=outputs)) + return messages + + def proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None): + """Apply chat template to get the prompt.""" + chat_template_kwargs = chat_template_kwargs or {} + messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']] + prompt = chat_template.messages2prompt(messages, sequence_start, **chat_template_kwargs) + + mm_placeholder = self.image_token + if self.contains_video_input: + mm_placeholder = self.video_token + elif self.contains_audio_input: + mm_placeholder = self.audio_token + + return prompt, mm_placeholder + + def to_pytorch(self, + messages, + chat_template, + tokenizer, + sequence_start, + chat_template_kwargs: Dict | None = None, + **kwargs): + """Return to the information needed by pytorch engine.""" + prompt, mm_placeholder = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs) + return self.to_pytorch_aux(messages, prompt, mm_placeholder, tokenizer, sequence_start) From eb33e33ff102b795479f886357411acd7a5983e7 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 2 Apr 2026 18:44:48 +0800 Subject: [PATCH 02/17] minor --- lmdeploy/vl/model/qwen3_omni.py | 146 +++++++++++++------------------- 1 file changed, 57 insertions(+), 89 deletions(-) diff --git a/lmdeploy/vl/model/qwen3_omni.py b/lmdeploy/vl/model/qwen3_omni.py index 78bd49790f..b62405a94f 100644 --- a/lmdeploy/vl/model/qwen3_omni.py +++ b/lmdeploy/vl/model/qwen3_omni.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Any, Dict, List, Optional +from typing import Any import torch from transformers import AutoProcessor @@ -42,90 +42,58 @@ def build_preprocessor(self): self.audio_token = self.processor.audio_token self.audio_token_id = tokenizer.encode(self.audio_token)[-1] - def get_processor_args(self, mm_processor_kwargs: Optional[Dict[str, Any]] = None): - min_pixels = self.processor.image_processor.size['shortest_edge'] - max_pixels = self.processor.image_processor.size['longest_edge'] - - if mm_processor_kwargs is None: - return min_pixels, max_pixels - - input_min_pixels = mm_processor_kwargs.get('min_pixels', None) - input_max_pixels = mm_processor_kwargs.get('max_pixels', None) - - # boundary check for min_pixels and max_pixels - if input_min_pixels is None: - if input_max_pixels is not None: - # only max_pixels is given in the input - if input_max_pixels < min_pixels: - logger.warning( - f'input max_pixels {input_max_pixels} < default min_pixels {min_pixels}, fall back to default.') - return min_pixels, max_pixels - max_pixels = input_max_pixels - else: - if input_max_pixels is None: - # only min_pixels is given in the input - if input_min_pixels > max_pixels: - logger.warning( - f'input min_pixels {input_min_pixels} > default max_pixels {max_pixels}, fall back to default.') - return min_pixels, max_pixels - else: - if input_min_pixels > input_max_pixels: - logger.warning( - f'input min_pixels {input_min_pixels} > max_pixels {input_max_pixels}, fall back to default.') - return min_pixels, max_pixels - max_pixels = input_max_pixels - min_pixels = input_min_pixels - - return min_pixels, max_pixels + def resolve_size_params(self, processor, mm_processor_kwargs: dict[str, Any] | None = None): + default_min = processor.size['shortest_edge'] + default_max = processor.size['longest_edge'] + + if not mm_processor_kwargs: + return {'shortest_edge': default_min, 'longest_edge': default_max} + + min_pixels = mm_processor_kwargs.get('min_pixels', default_min) + max_pixels = mm_processor_kwargs.get('max_pixels', default_max) + + if min_pixels > max_pixels: + logger.warning(f'min_pixels {min_pixels} > max_pixels {max_pixels}, falling back to defaults.') + return {'shortest_edge': default_min, 'longest_edge': default_max} + + return {'shortest_edge': min_pixels, 'longest_edge': max_pixels} def _preprocess_image(self, - data: List[Any], - params: Dict[str, Any], - mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: - - image = data.convert('RGB') - min_pixels, max_pixels = self.get_processor_args(mm_processor_kwargs) - - result = self.processor.image_processor(images=image, - size={ - 'shortest_edge': min_pixels, - 'longest_edge': max_pixels - }, - return_tensors='pt') + data: list[Any], + params: dict[str, Any], + mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: + + size = self.resolve_size_params(self.processor.image_processor, mm_processor_kwargs) + result = self.processor.image_processor(images=data, size=size, return_tensors='pt') merge_length = self.processor.image_processor.merge_size**2 image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length - result.update(dict(image_size=image.size, mm_token_num=image_tokens, image_token_id=self.image_token_id)) + result.update(dict(image_size=data.size, mm_token_num=image_tokens, image_token_id=self.image_token_id)) return result def _preprocess_video(self, - data: List[Any], - params: Dict[str, Any], - mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: + data: list[Any], + params: 
dict[str, Any], + mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: - # TODO: zhouxinyu, apply transformers smart_resize using per-request kwargs metadata = params['video_metadata'] - video_kwargs = dict(return_metadata=True, - do_resize=True, - do_sample_frames=False, - video_metadata=metadata, - return_tensors='pt') - - # TODO: update from mm_processor_kwargs when needed - video_kwargs.update(size={ - 'shortest_edge': 128 * 32 * 32, - 'longest_edge': 768 * 32 * 32, - }) - result = self.processor.video_processor(videos=data, **video_kwargs) - video_grid_thw = result['video_grid_thw'] + if metadata.get('fps') is None or metadata['fps'] <= 0: + logger.warning('Qwen3Omni: fps not found or invalid, fallback to 24.') + metadata['fps'] = 24 + size = self.resolve_size_params(self.processor.video_processor, mm_processor_kwargs) + + # do_resize = True, we leave resize to hf processor + # do_sample_frames = False, we already sample frames in video loader, avoid duplicates in hf processor + result = self.processor.video_processor(videos=data, + size=size, + return_metadata=True, + do_resize=True, + do_sample_frames=False, + video_metadata=metadata, + return_tensors='pt') merge_length = self.processor.video_processor.merge_size**2 - if metadata.get('fps') is None: - logger.warning_once('Qwen3VL: fps not found, defaulting to 24.') - metadata['fps'] = metadata['fps'] or 24 - - # TODO: update fps from video kwargs, refer to transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py - second_per_grid = self.processor.video_processor.temporal_patch_size / video_kwargs.get('fps', 1.0) - + video_grid_thw = result['video_grid_thw'] + second_per_grid = self.processor.video_processor.temporal_patch_size / metadata.get('fps', 1.0) frame_seqlen = video_grid_thw[0][1:].prod() // merge_length video_tokens = video_grid_thw[0].prod() // merge_length # T*H*W / merge^2 result.update(frame_seqlen=frame_seqlen, @@ -144,21 +112,21 @@ def _get_feat_extract_output_lengths(self, input_lengths): return output_lengths def _preprocess_audio(self, - data: List[Any], - params: Dict[str, Any], - mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: + data: list[Any], + params: dict[str, Any], + mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: audio, original_sr = data - # NOTE: WhisperFeatureExtractor was trained using a fixed sampling rate of 16000 - # TODO: zhouxinyu, get truncation from mm_processor_kwargs when needed + # WhisperFeatureExtractor was trained using a fixed sampling rate of 16000 sr = self.processor.feature_extractor.sampling_rate - audio_kwargs = { - 'sampling_rate': sr, - 'padding': True, - 'truncation': False, - 'return_attention_mask': True, - 'return_tensors': 'pt' - } - result = self.processor.feature_extractor(audio, **audio_kwargs) + truncation = mm_processor_kwargs.get('truncation', False) if mm_processor_kwargs else False + + result = self.processor.feature_extractor(audio, + sampling_rate=sr, + padding=True, + truncation=truncation, + return_attention_mask=True, + return_tensors='pt') + feature_attention_mask = result.get('attention_mask') audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) audio_output_length = self._get_feat_extract_output_lengths(audio_feature_lengths) @@ -170,7 +138,7 @@ def _preprocess_audio(self, audio_token_id=self.audio_token_id)) return result - def preprocess(self, messages: List[Dict], mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]: + def preprocess(self, messages: list[dict], 
mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: """Refer to `super().preprocess()` for spec.""" outputs = [] self.contains_video_input = False @@ -213,7 +181,7 @@ def to_pytorch(self, chat_template, tokenizer, sequence_start, - chat_template_kwargs: Dict | None = None, + chat_template_kwargs: dict | None = None, **kwargs): """Return to the information needed by pytorch engine.""" prompt, mm_placeholder = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs) From 4f6c57db32747304133700b16752973c069803d3 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 2 Apr 2026 20:19:56 +0800 Subject: [PATCH 03/17] use builtin mrope --- .../pytorch/models/qwen3_omni_moe_thinker.py | 256 +++--------------- 1 file changed, 41 insertions(+), 215 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py index 09eb0782bc..4d5d34091a 100644 --- a/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py +++ b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import math +from collections.abc import Iterable, Sequence from functools import lru_cache -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any import numpy as np import torch @@ -21,7 +22,7 @@ from .qwen3_vl import Qwen3VLVisionBlock, Qwen3VLVisionPatchEmbed, Qwen3VLVisionRotaryEmbedding from .qwen3_vl_moe import Qwen3VLMoeTextModel -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixin, vlm_model @@ -413,7 +414,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: return rotary_pos_emb # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474 - def fast_pos_embed_interpolate(self, grid_thw: List[List[int]]) -> torch.Tensor: + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side m_size = self.spatial_merge_size hidden_dim = self.pos_embed.embedding_dim @@ -544,7 +545,7 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - past_key_values: List[List[torch.Tensor]], + past_key_values: list[list[torch.Tensor]], attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, mrope_position_ids: torch.Tensor = None, @@ -620,7 +621,7 @@ def get_input_embeddings(self): def prepare_inputs_for_generation( self, - past_key_values: List[List[torch.Tensor]], + past_key_values: list[list[torch.Tensor]], inputs_embeds: torch.Tensor | None = None, context: StepContext = None, ): @@ -716,8 +717,8 @@ def rename_weight(self, name: str) -> str: return 'lm_head.' 
+ name[len('thinker.lm_head.'):] return name - def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], - expert_params_mapping: List): + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter], + expert_params_mapping: list): """Load weight experts.""" for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: @@ -731,32 +732,7 @@ def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_di param = params_dict[name] load_weight(param, loaded_weight) - # modify from vllm qwen3vlmoe fused expert loading - def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], - fused_expert_params_mapping: List): - """Load weight of fused expert weights.""" - num_experts = self.config.text_config.num_experts - - for (param_name, weight_name) in fused_expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - - loaded_weight = loaded_weight.transpose(-1, -2) # no bias - if 'gate_up' in name: - loaded_weight = loaded_weight.chunk(2, dim=-2) - w1 = loaded_weight[0] - w3 = loaded_weight[1] - for expert_id in range(num_experts): - load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate') - load_weight(param, w3[expert_id], expert_id=expert_id, shard_id='up') - elif 'down' in name: - w2 = loaded_weight - for expert_id in range(num_experts): - load_weight(param, w2[expert_id], expert_id=expert_id, shard_id='down') - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): """Load weights.""" # modify from vllm stacked_params_mapping = [ @@ -778,13 +754,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') expert_params_mapping += [gate_param, up_param, down_param] - # fused expert mapping - fused_expert_params_mapping = [ - # (param_name, weight_name) - ('.experts.gate_up.weight', '.experts.gate_up_proj'), - ('.experts.down.weight', '.experts.down_proj'), - ] - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'rotary_emb.inv_freq' in name: @@ -797,17 +766,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace('.block_sparse_moe.', '.mlp.') if '.experts' in name: - is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name) - if is_fused_expert: - self._load_weight_fused_experts(name, - loaded_weight, - params_dict, - fused_expert_params_mapping=fused_expert_params_mapping) - else: - self._load_weight_experts(name, - loaded_weight, - params_dict, - expert_params_mapping=expert_params_mapping) + self._load_weight_experts(name, + loaded_weight, + params_dict, + expert_params_mapping=expert_params_mapping) else: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -827,165 +789,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', 
None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - mm_data_list = [] - if input_mm is not None: - mm_data_list.extend(input_mm.get('mm_data', [])) - - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - - for mm_data in mm_data_list: - if mm_data.modality == Modality.IMAGE: - grid_thw = mm_data.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = mm_data.end - mm_data.start - max(h, w) - mrope_delta -= num_pad - fill_start = mm_data.start - pos_start - fill_end = mm_data.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, 
fill_start:fill_end] = img_pos_ids - elif mm_data.modality == Modality.VIDEO: - second_per_grid = mm_data.meta.get('second_per_grid', 2.0) - position_id_per_seconds = self.config.thinker_config.position_id_per_seconds - - grid_thw = mm_data.meta['grid_thw'][0].tolist() - t, h, w = grid_thw - llm_h = h // 2 # spatial_merge_size = 2 - llm_w = w // 2 - - device = pos_ids.device - # Temporal indices as real timestamps (float, e.g. 0, 1.083, 2.167 for fps=24) - t_index = torch.arange(t, device=device).float() * (second_per_grid * position_id_per_seconds) - h_index = torch.arange(llm_h, device=device).float() - w_index = torch.arange(llm_w, device=device).float() - - # Build [3, T*llm_h*llm_w] pos ids - t_expanded = t_index.view(-1, 1).expand(-1, llm_h * llm_w).flatten() - h_expanded = h_index.view(1, -1, 1).expand(t, -1, llm_w).flatten() - w_expanded = w_index.view(1, 1, -1).expand(t, llm_h, -1).flatten() - video_pos_ids = torch.stack([t_expanded, h_expanded, w_expanded]) # [3, T*llm_h*llm_w] - - max_video_pos = max( - float((t - 1) * second_per_grid * position_id_per_seconds) if t > 1 else 0.0, - float(llm_h - 1), - float(llm_w - 1), - ) - video_num_tokens = t * llm_h * llm_w - num_pad = video_num_tokens - max_video_pos - 1 - mrope_delta -= num_pad - - fill_start = mm_data.start - pos_start - fill_end = mm_data.end - pos_start - - # Convert to float to hold non-integer temporal positions - mrope_pos_ids = mrope_pos_ids.float() - offset = mrope_pos_ids[0, fill_start].item() - mrope_pos_ids[:, fill_start:fill_end] = video_pos_ids + offset - mrope_pos_ids[:, fill_end:] -= num_pad - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: torch.Tensor | None = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor @@ -997,7 +800,24 @@ class Qwen3OmniInputProcessor(BaseModelInputProcessor): def __init__(self, config: PretrainedConfig) -> None: self.config = config - def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: + @classmethod + def _get_multimodal_pos_ids(cls, grid_thw: Sequence[int]) -> np.ndarray: + """Get mrope ids.""" + t, h, w = grid_thw + h = h // 2 + w = w // 2 + stride = np.array([h * w, w, 1])[None] + size = np.array([t, h, w])[None] + pos_ids = np.arange(t * h * w)[:, None].repeat(3, axis=1) + pos_ids = pos_ids // stride % size + return pos_ids + + @classmethod + def make_mrope(cls, grid_thw: torch.Tensor): + img_pos_ids = cls._get_multimodal_pos_ids(grid_thw[0].tolist()) + return img_pos_ids + + def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: """Make image MultiModalData.""" pixel_values = input_mm['pixel_values'] image_grid_thw = input_mm['image_grid_thw'] @@ -1008,14 +828,17 @@ def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() + mrope_pos_ids = self.make_mrope(image_grid_thw) + mm_data = MultiModalData(modality=Modality.IMAGE, data=pixel_values, start=start, end=start + num_pad, + mrope_pos_ids=mrope_pos_ids, 
meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) return mm_data - def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: + def _make_video_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: """Make video MultiModalData.""" pixel_values_videos = input_mm['pixel_values_videos'] video_grid_thw = input_mm['video_grid_thw'] @@ -1026,10 +849,13 @@ def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() + mrope_pos_ids = self.make_mrope(video_grid_thw) + mm_data = MultiModalData(modality=Modality.VIDEO, data=pixel_values_videos, start=start, end=start + num_pad, + mrope_pos_ids=mrope_pos_ids, meta=dict( grid_thw=video_grid_thw, video_token_id=video_token_id, @@ -1037,7 +863,7 @@ def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: )) return mm_data - def _make_audio_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: + def _make_audio_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: """Make audio MultiModalData.""" input_features = input_mm['input_features'] offset = input_mm['offset'] @@ -1058,8 +884,8 @@ def _make_audio_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData: return mm_data def preprocess_input(self, - input_ids: List[int], - input_multimodals: List[Dict[str, Any]] = None, + input_ids: list[int], + input_multimodals: list[dict[str, Any]] = None, **kwargs) -> PreprocessInputResult: """Prepare multimodal input.""" if input_multimodals is None or len(input_multimodals) == 0: From b81d76b6da3e7e49283d8811dd38f1dafa0ed451 Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 7 Apr 2026 17:21:18 +0800 Subject: [PATCH 04/17] minor fix --- lmdeploy/pytorch/configurations/qwen3_omni.py | 1 + lmdeploy/vl/model/qwen3_omni.py | 49 +++++++++++-------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/lmdeploy/pytorch/configurations/qwen3_omni.py b/lmdeploy/pytorch/configurations/qwen3_omni.py index 6001f374fc..9d1e3ec61c 100644 --- a/lmdeploy/pytorch/configurations/qwen3_omni.py +++ b/lmdeploy/pytorch/configurations/qwen3_omni.py @@ -15,4 +15,5 @@ def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" cfg = DefaultModelConfigBuilder.build(hf_config.thinker_config.text_config, model_path, **kwargs) cfg.hf_config = hf_config + cfg.use_mrope = True return cfg diff --git a/lmdeploy/vl/model/qwen3_omni.py b/lmdeploy/vl/model/qwen3_omni.py index b62405a94f..1b5a4d1774 100644 --- a/lmdeploy/vl/model/qwen3_omni.py +++ b/lmdeploy/vl/model/qwen3_omni.py @@ -3,6 +3,8 @@ import torch from transformers import AutoProcessor +from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import Qwen3OmniMoeProcessorKwargs +from transformers.models.whisper import WhisperFeatureExtractor from lmdeploy.utils import get_logger from lmdeploy.vl.constants import Modality @@ -42,9 +44,12 @@ def build_preprocessor(self): self.audio_token = self.processor.audio_token self.audio_token_id = tokenizer.encode(self.audio_token)[-1] - def resolve_size_params(self, processor, mm_processor_kwargs: dict[str, Any] | None = None): - default_min = processor.size['shortest_edge'] - default_max = processor.size['longest_edge'] + # default kwargs for hf processor + self.default_mm_processor_kwargs = Qwen3OmniMoeProcessorKwargs._defaults + + def resolve_size_params(self, default_size, mm_processor_kwargs: dict[str, Any] | None = None): + default_min = default_size['shortest_edge'] + default_max = default_size['longest_edge'] if not 
mm_processor_kwargs:
            return {'shortest_edge': default_min, 'longest_edge': default_max}
 
@@ -63,7 +68,8 @@ def _preprocess_image(self,
                           params: dict[str, Any],
                           mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
 
-        size = self.resolve_size_params(self.processor.image_processor, mm_processor_kwargs)
+        default_image_size = self.processor.image_processor.size
+        size = self.resolve_size_params(default_image_size, mm_processor_kwargs)
         result = self.processor.image_processor(images=data, size=size, return_tensors='pt')
         merge_length = self.processor.image_processor.merge_size**2
         image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
@@ -79,7 +85,9 @@ def _preprocess_video(self,
         if metadata.get('fps') is None or metadata['fps'] <= 0:
             logger.warning('Qwen3Omni: fps not found or invalid, fallback to 24.')
             metadata['fps'] = 24
-        size = self.resolve_size_params(self.processor.video_processor, mm_processor_kwargs)
+
+        default_video_kwargs = self.default_mm_processor_kwargs['videos_kwargs']
+        size = self.resolve_size_params(default_video_kwargs['size'], mm_processor_kwargs)
 
         # do_resize = True, we leave resize to hf processor
         # do_sample_frames = False, we already sample frames in video loader, avoid duplicates in hf processor
@@ -93,11 +101,11 @@ def _preprocess_video(self,
 
         merge_length = self.processor.video_processor.merge_size**2
         video_grid_thw = result['video_grid_thw']
-        second_per_grid = self.processor.video_processor.temporal_patch_size / metadata.get('fps', 1.0)
-        frame_seqlen = video_grid_thw[0][1:].prod() // merge_length
-        video_tokens = video_grid_thw[0].prod() // merge_length  # T*H*W / merge^2
-        result.update(frame_seqlen=frame_seqlen,
-                      mm_token_num=video_tokens,
+        # TODO: custom fps
+        second_per_grid = self.processor.video_processor.temporal_patch_size / default_video_kwargs.get('fps', 1.0)
+        video_tokens = video_grid_thw[0].prod() // merge_length  # T * H * W / merge_size^2
+
+        result.update(mm_token_num=video_tokens,
                       second_per_grid=second_per_grid,
                       video_token_id=self.video_token_id)
         return result
@@ -116,16 +124,17 @@ def _preprocess_audio(self,
                           params: dict[str, Any],
                           mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
         audio, original_sr = data
-        # WhisperFeatureExtractor was trained using a fixed sampling rate of 16000
-        sr = self.processor.feature_extractor.sampling_rate
-        truncation = mm_processor_kwargs.get('truncation', False) if mm_processor_kwargs else False
-
-        result = self.processor.feature_extractor(audio,
-                                                  sampling_rate=sr,
-                                                  padding=True,
-                                                  truncation=truncation,
-                                                  return_attention_mask=True,
-                                                  return_tensors='pt')
+        default_audio_kwargs = self.default_mm_processor_kwargs['audios_kwargs']
+        feature_extractor = self.processor.feature_extractor
+        assert isinstance(feature_extractor, WhisperFeatureExtractor), \
+            'Qwen3Omni audio processor only supports WhisperFeatureExtractor'
+
+        # truncation is explicitly set to False to avoid different hf processor behavior
+        # https://github.com/huggingface/transformers/pull/41473
+        result = feature_extractor(audio,
+                                   truncation=False,
+                                   return_tensors='pt',
+                                   **default_audio_kwargs)
 
         feature_attention_mask = result.get('attention_mask')
         audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)

From 10257fdd7aad16fdebef47b526f1e3e0e29f2ad6 Mon Sep 17 00:00:00 2001
From: zxy
Date: Thu, 7 May 2026 21:01:07 +0800
Subject: [PATCH 05/17] Add Qwen3 Omni new preprocess support

---
 .../pytorch/models/qwen3_omni_moe_thinker.py  |  84 +++---
 lmdeploy/vl/model/base.py                     |  20 +-
lmdeploy/vl/model/preprocess_utils.py | 69 ++++- lmdeploy/vl/model/qwen3_omni.py | 170 +----------- .../test_vl/test_qwen3_omni_processor.py | 258 ++++++++++++++++++ 5 files changed, 393 insertions(+), 208 deletions(-) create mode 100644 tests/test_lmdeploy/test_vl/test_qwen3_omni_processor.py diff --git a/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py index 4d5d34091a..06358566e5 100644 --- a/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py +++ b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py @@ -23,7 +23,7 @@ from .qwen3_vl import Qwen3VLVisionBlock, Qwen3VLVisionPatchEmbed, Qwen3VLVisionRotaryEmbedding from .qwen3_vl_moe import Qwen3VLMoeTextModel from .utils.cudagraph import CudaGraphMixin -from .utils.model import DeployModelMixin, vlm_model +from .utils.model import DeployModelMixinV1, vlm_model def _get_feat_extract_output_lengths(input_lengths): @@ -489,7 +489,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_ return hidden_states, deepstack_feature_lists -class Qwen3OmniMoeThinkerForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin): +class Qwen3OmniMoeThinkerForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin): """ModelForCausalLM.""" packed_modules_mapping = { @@ -647,26 +647,30 @@ def prepare_inputs_for_generation( mm_inputs = [item for sublist in mm_inputs for item in sublist] if len(mm_inputs) > 0: - modality = mm_inputs[0].modality - - image_token_id = mm_inputs[0].meta.get('image_token_id') - video_token_id = mm_inputs[0].meta.get('video_token_id') - audio_token_id = mm_inputs[0].meta.get('audio_token_id') - - if modality == Modality.AUDIO: - audio_values = torch.cat([inp.data for inp in mm_inputs]) - # FIXME: zhouxinyu, batch ? - audio_values = audio_values.squeeze(0) + audio_inputs = [inp for inp in mm_inputs if inp.modality == Modality.AUDIO] + visual_inputs = [inp for inp in mm_inputs if inp.modality in [Modality.IMAGE, Modality.VIDEO]] + + if audio_inputs: + audio_chunks = [] + for inp in audio_inputs: + audio_data = inp.data + if audio_data.dim() == 3 and audio_data.shape[0] == 1: + audio_data = audio_data.squeeze(0) + audio_chunks.append(audio_data) + audio_values = torch.cat(audio_chunks, dim=-1) + audio_token_id = audio_inputs[0].meta.get('audio_token_id') audio_mask = (input_ids == audio_token_id) - # FIXME: zhouxinyu, list ? 
- audio_feature_lengths = mm_inputs[0].meta['audio_feature_lengths'] - elif modality in [Modality.IMAGE, Modality.VIDEO]: - pixel_values = torch.cat([inp.data for inp in mm_inputs]) + audio_feature_lengths = torch.cat([ + inp.meta['audio_feature_lengths'] + if isinstance(inp.meta['audio_feature_lengths'], torch.Tensor) else + torch.tensor([inp.meta['audio_feature_lengths']], dtype=torch.long) for inp in audio_inputs + ]) - mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id - image_mask = (input_ids == mm_token_id) + if visual_inputs: + pixel_values = torch.cat([inp.data for inp in visual_inputs]) + image_mask = self.get_multimodal_mask(input_ids, visual_inputs) - grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu() + grid_thw = torch.stack([data.meta['grid_thw'] for data in visual_inputs]).cpu() vis_pos_emb = self.visual.rot_pos_emb(grid_thw) pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw) vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], @@ -814,7 +818,8 @@ def _get_multimodal_pos_ids(cls, grid_thw: Sequence[int]) -> np.ndarray: @classmethod def make_mrope(cls, grid_thw: torch.Tensor): - img_pos_ids = cls._get_multimodal_pos_ids(grid_thw[0].tolist()) + grid_thw = grid_thw.tolist() if grid_thw.dim() == 1 else grid_thw[0].tolist() + img_pos_ids = cls._get_multimodal_pos_ids(grid_thw) return img_pos_ids def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: @@ -822,18 +827,14 @@ def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: pixel_values = input_mm['pixel_values'] image_grid_thw = input_mm['image_grid_thw'] offset = input_mm['offset'] - start = offset image_token_id = input_mm['image_token_id'] - num_pad = input_mm['mm_token_num'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() mrope_pos_ids = self.make_mrope(image_grid_thw) mm_data = MultiModalData(modality=Modality.IMAGE, data=pixel_values, - start=start, - end=start + num_pad, + start=offset[0], + end=offset[1], mrope_pos_ids=mrope_pos_ids, meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) return mm_data @@ -843,18 +844,14 @@ def _make_video_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: pixel_values_videos = input_mm['pixel_values_videos'] video_grid_thw = input_mm['video_grid_thw'] offset = input_mm['offset'] - start = offset video_token_id = input_mm['video_token_id'] - num_pad = input_mm['mm_token_num'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() mrope_pos_ids = self.make_mrope(video_grid_thw) mm_data = MultiModalData(modality=Modality.VIDEO, data=pixel_values_videos, - start=start, - end=start + num_pad, + start=offset[0], + end=offset[1], mrope_pos_ids=mrope_pos_ids, meta=dict( grid_thw=video_grid_thw, @@ -867,19 +864,28 @@ def _make_audio_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: """Make audio MultiModalData.""" input_features = input_mm['input_features'] offset = input_mm['offset'] - start = offset audio_token_id = input_mm['audio_token_id'] - num_pad = input_mm['mm_token_num'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() + feature_attention_mask = input_mm.get('feature_attention_mask') + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = input_mm.get('audio_feature_lengths') + if audio_feature_lengths is 
None: + audio_feature_lengths = torch.full((input_features.shape[0], ), + input_features.shape[-1], + dtype=torch.long) + audio_len = offset[1] - offset[0] + mrope_pos_ids = np.arange(audio_len, dtype=np.int64)[:, None].repeat(3, axis=1) mm_data = MultiModalData(modality=Modality.AUDIO, data=input_features, - start=start, - end=start + num_pad, + start=offset[0], + end=offset[1], + mrope_pos_ids=mrope_pos_ids, meta=dict( audio_token_id=audio_token_id, - audio_feature_lengths=input_mm.get('audio_feature_lengths'), + audio_feature_lengths=audio_feature_lengths, )) return mm_data diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index beca33f01d..111ac1dac8 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -32,6 +32,10 @@ class VisionModel(ABC): # video-related attributes 'pixel_values_videos': Modality.VIDEO, 'video_grid_thw': Modality.VIDEO, + 'video_second_per_grid': Modality.VIDEO, + # audio-related attributes + 'input_features': Modality.AUDIO, + 'feature_attention_mask': Modality.AUDIO, # time series-related attributes 'ts_values': Modality.TIME_SERIES, 'ts_sr': Modality.TIME_SERIES, @@ -42,6 +46,7 @@ class VisionModel(ABC): FEATURE_NAMES = [ 'pixel_values', 'pixel_values_videos', + 'input_features', 'ts_values', ] @@ -103,6 +108,7 @@ def preprocess(self, mm_items = self.collect_multimodal_items(messages) raw_images, raw_videos, video_metadatas = [], [], [] + raw_audios = [] raw_time_series, sampling_rates = [], [] for modality, data, params in mm_items: if modality == Modality.IMAGE: @@ -110,6 +116,8 @@ def preprocess(self, elif modality == Modality.VIDEO: raw_videos.append(data) video_metadatas.append(params.get('video_metadata', None)) + elif modality == Modality.AUDIO: + raw_audios.append(data[0] if isinstance(data, tuple) else data) elif modality == Modality.TIME_SERIES: raw_time_series.append(data) sampling_rates.append(params.get('sampling_rate', None)) @@ -120,6 +128,7 @@ def preprocess(self, kwargs = {} images_kwargs = {} videos_kwargs = {} + audio_kwargs = {} mm_processor_kwargs = mm_processor_kwargs or {} if raw_images: kwargs['images'] = raw_images @@ -139,15 +148,20 @@ def preprocess(self, modality='video') if video_size is not None: videos_kwargs['size'] = video_size + if raw_audios: + kwargs['audio'] = raw_audios + audio_kwargs = mm_processor_kwargs.get('audio') or {} if images_kwargs: kwargs['images_kwargs'] = images_kwargs if videos_kwargs: kwargs['videos_kwargs'] = videos_kwargs + if audio_kwargs: + kwargs['audio_kwargs'] = audio_kwargs if raw_time_series: assert hasattr(self, 'time_series_processor'), \ 'time series processor is not defined for time series input' - assert not raw_images and not raw_videos, \ - 'time series is not compatible with image/video input' + assert not raw_images and not raw_videos and not raw_audios, \ + 'time series is not compatible with image/video/audio input' self.tokenizer = self.processor.tokenizer time_series_processor = self.time_series_processor kwargs['time_series'] = raw_time_series @@ -188,6 +202,8 @@ def preprocess(self, # expand bundled hf processor outputs into per-image/video entry for lmdeploy to consume expanded_mm_items = get_expanded_mm_items(collected_mm_items, self.mm_tokens) + # HF processors return features grouped by modality; offsets restore prompt order for mixed inputs. 
+ expanded_mm_items.sort(key=lambda item: item['offset'][0]) return dict(input_ids=input_ids.tolist(), multimodal=expanded_mm_items) diff --git a/lmdeploy/vl/model/preprocess_utils.py b/lmdeploy/vl/model/preprocess_utils.py index bb591116d9..f9d149f837 100644 --- a/lmdeploy/vl/model/preprocess_utils.py +++ b/lmdeploy/vl/model/preprocess_utils.py @@ -94,6 +94,30 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken offset=item['offset'][0], image_token_id=token_id, )) + elif modality == Modality.VIDEO: + second_per_grid = item.get('video_second_per_grid') + if second_per_grid is not None: + second_per_grid = second_per_grid[0] + if isinstance(second_per_grid, torch.Tensor) and second_per_grid.numel() == 1: + second_per_grid = second_per_grid.item() + expanded_mm_items.append( + dict( + modality=modality, + pixel_values_videos=item['feature'], + video_grid_thw=item['video_grid_thw'][0], + offset=item['offset'][0], + second_per_grid=second_per_grid, + video_token_id=token_id, + )) + elif modality == Modality.AUDIO: + expanded_mm_items.append( + dict( + modality=modality, + input_features=item['feature'], + feature_attention_mask=item.get('feature_attention_mask'), + offset=item['offset'][0], + audio_token_id=token_id, + )) elif modality == Modality.TIME_SERIES: expanded_mm_items.append( dict( @@ -143,10 +167,6 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken frames_per_video.append(T) total_frames += T - if num_items != total_frames: - expanded_mm_items.append(item) - continue - patches_per_video = [] for i in range(num_videos): grid = video_grid_thw[i] @@ -157,6 +177,25 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken cumulative = torch.cumsum(torch.tensor(patches_per_video, dtype=torch.long), dim=0) slice_indices = [0] + cumulative.tolist() + if num_items != total_frames: + for video_idx in range(num_videos): + start, end = slice_indices[video_idx], slice_indices[video_idx + 1] + second_per_grid = item.get('video_second_per_grid') + if second_per_grid is not None: + second_per_grid = second_per_grid[video_idx] + if isinstance(second_per_grid, torch.Tensor) and second_per_grid.numel() == 1: + second_per_grid = second_per_grid.item() + expanded_mm_items.append( + dict( + modality=modality, + pixel_values_videos=item['feature'][start:end], + video_grid_thw=video_grid_thw[video_idx], + offset=item['offset'][video_idx], + second_per_grid=second_per_grid, + video_token_id=token_id, + )) + continue + frame_start_indices = [0] for i in range(num_videos): frame_start_indices.append(frame_start_indices[-1] + frames_per_video[i]) @@ -164,19 +203,39 @@ def get_expanded_mm_items(collected_mm_items, mm_tokens: 'MultimodalSpecialToken for video_idx in range(num_videos): start, end = slice_indices[video_idx], slice_indices[video_idx + 1] frame_start, frame_end = frame_start_indices[video_idx], frame_start_indices[video_idx + 1] + second_per_grid = item.get('video_second_per_grid') + if second_per_grid is not None: + second_per_grid = second_per_grid[video_idx] + if isinstance(second_per_grid, torch.Tensor) and second_per_grid.numel() == 1: + second_per_grid = second_per_grid.item() # TODO: zhouxinyu, not sure per-frame split is good or not # TODO: zhouxinyu, grid_thw [1, h, w] is only for qwen3vl t, h, w = video_grid_thw[video_idx].tolist() for frame_idx in range(t): video_feature = item['feature'][start:end] + offset = item['offset'][frame_start:frame_end][frame_idx] expanded_mm_items.append( dict( 
modality=modality, pixel_values_videos=video_feature[frame_idx * h * w:(frame_idx + 1) * h * w], video_grid_thw=torch.tensor([1, h, w]), - offset=item['offset'][frame_start:frame_end][frame_idx], + offset=offset, + second_per_grid=second_per_grid, video_token_id=token_id, )) + elif modality == Modality.AUDIO: + for i in range(num_items): + feature_attention_mask = item.get('feature_attention_mask') + if feature_attention_mask is not None: + feature_attention_mask = feature_attention_mask[i:i + 1] + expanded_mm_items.append( + dict( + modality=modality, + input_features=item['feature'][i:i + 1], + feature_attention_mask=feature_attention_mask, + offset=item['offset'][i], + audio_token_id=token_id, + )) return expanded_mm_items diff --git a/lmdeploy/vl/model/qwen3_omni.py b/lmdeploy/vl/model/qwen3_omni.py index 1b5a4d1774..a3dbd75f55 100644 --- a/lmdeploy/vl/model/qwen3_omni.py +++ b/lmdeploy/vl/model/qwen3_omni.py @@ -1,16 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any - -import torch from transformers import AutoProcessor -from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import Qwen3OmniMoeProcessorKwargs -from transformers.models.whisper import WhisperFeatureExtractor - -from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import Modality -from lmdeploy.vl.model.base import VISION_MODELS, VisionModel -logger = get_logger('lmdeploy') +from lmdeploy.vl.model.base import VISION_MODELS, MultimodalSpecialTokens, VisionModel def check_transformers(): @@ -29,7 +20,7 @@ class Qwen3OmniModel(VisionModel): def build_preprocessor(self): check_transformers() - self.processor = AutoProcessor.from_pretrained(self.model_path) + self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) tokenizer = self.processor.tokenizer # image tokens @@ -44,154 +35,9 @@ def build_preprocessor(self): self.audio_token = self.processor.audio_token self.audio_token_id = tokenizer.encode(self.audio_token)[-1] - # default kwargs for hf processor - self.default_mm_processor_kwargs = Qwen3OmniMoeProcessorKwargs._defaults - - def resolve_size_params(self, default_size, mm_processor_kwargs: dict[str, Any] | None = None): - default_min = default_size['shortest_edge'] - default_max = default_size['longest_edge'] - - if not mm_processor_kwargs: - return {'shortest_edge': default_min, 'longest_edge': default_max} - - min_pixels = mm_processor_kwargs.get('min_pixels', default_min) - max_pixels = mm_processor_kwargs.get('max_pixels', default_max) - - if min_pixels > max_pixels: - logger.warning(f'min_pixels {min_pixels} > max_pixels {max_pixels}, falling back to defaults.') - return {'shortest_edge': default_min, 'longest_edge': default_max} - - return {'shortest_edge': min_pixels, 'longest_edge': max_pixels} - - def _preprocess_image(self, - data: list[Any], - params: dict[str, Any], - mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: - - default_image_size = self.processor.image_processor.size - size = self.resolve_size_params(default_image_size, mm_processor_kwargs) - result = self.processor.image_processor(images=data, size=size, return_tensors='pt') - merge_length = self.processor.image_processor.merge_size**2 - image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length - result.update(dict(image_size=data.size, mm_token_num=image_tokens, image_token_id=self.image_token_id)) - return result - - def _preprocess_video(self, - data: list[Any], - params: dict[str, Any], - mm_processor_kwargs: dict[str, Any] | None = 
None) -> list[dict]:
-
-        metadata = params['video_metadata']
-        if metadata.get('fps') is None or metadata['fps'] <= 0:
-            logger.warning('Qwen3Omni: fps not found or invalid, fallback to 24.')
-            metadata['fps'] = 24
-
-        default_video_kwargs = self.default_mm_processor_kwargs['videos_kwargs']
-        size = self.resolve_size_params(default_video_kwargs['size'], mm_processor_kwargs)
-
-        # do_resize = True, we leave resize to hf processor
-        # do_sample_frames = False, we already sample frames in video loader, avoid duplicates in hf processor
-        result = self.processor.video_processor(videos=data,
-                                                size=size,
-                                                return_metadata=True,
-                                                do_resize=True,
-                                                do_sample_frames=False,
-                                                video_metadata=metadata,
-                                                return_tensors='pt')
-
-        merge_length = self.processor.video_processor.merge_size**2
-        video_grid_thw = result['video_grid_thw']
-        # TODO: custom fps
-        second_per_grid = self.processor.video_processor.temporal_patch_size / default_video_kwargs.get('fps', 1.0)
-        video_tokens = video_grid_thw[0].prod() // merge_length  # T * H * W / merge_size^2
-
-        result.update(mm_token_num=video_tokens,
-                      second_per_grid=second_per_grid,
-                      video_token_id=self.video_token_id)
-        return result
-
-    def _get_feat_extract_output_lengths(self, input_lengths):
-        """Computes the output length of the convolutional layers and the
-        output length of the audio encoder."""
-
-        input_lengths_leave = input_lengths % 100
-        feat_lengths = (input_lengths_leave - 1) // 2 + 1
-        output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
-        return output_lengths
-
-    def _preprocess_audio(self,
-                          data: list[Any],
-                          params: dict[str, Any],
-                          mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
-        audio, original_sr = data
-        default_audio_kwargs = self.default_mm_processor_kwargs['audios_kwargs']
-        feature_extractor = self.processor.feature_extractor
-        assert isinstance(feature_extractor, WhisperFeatureExtractor), \
-            'Qwen3Omni audio processor only supports WhisperFeatureExtractor'
-
-        # truncation is explicitly set to False to avoid different hf processor behavior
-        # https://github.com/huggingface/transformers/pull/41473
-        result = feature_extractor(audio,
-                                   truncation=False,
-                                   return_tensors='pt',
-                                   **default_audio_kwargs)
-
-        feature_attention_mask = result.get('attention_mask')
-        audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
-        audio_output_length = self._get_feat_extract_output_lengths(audio_feature_lengths)
-        audio_tokens = audio_output_length
-
-        result.update(
-            dict(mm_token_num=audio_tokens,
-                 audio_feature_lengths=audio_feature_lengths,
-                 audio_token_id=self.audio_token_id))
-        return result
-
-    def preprocess(self, messages: list[dict], mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
-        """Refer to `super().preprocess()` for spec."""
-        outputs = []
-        self.contains_video_input = False
-        self.contains_audio_input = False
-
-        mm_items = self.collect_multimodal_items(messages)
-        for modality, data, params in mm_items:
-            result = {}
-            if modality == Modality.IMAGE:
-                result = self._preprocess_image(data, params, mm_processor_kwargs)
-            elif modality == Modality.VIDEO:
-                self.contains_video_input = True
-                result = self._preprocess_video(data, params, mm_processor_kwargs)
-            elif modality == Modality.AUDIO:
-                self.contains_audio_input = True
-                result = self._preprocess_audio(data, params, mm_processor_kwargs)
-
-            result.update(modality=modality)
-            outputs.append(result)
-
-        messages.append(dict(role='preprocess', content=outputs))
-        return messages
-
-    def 
proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None): - """Apply chat template to get the prompt.""" - chat_template_kwargs = chat_template_kwargs or {} - messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']] - prompt = chat_template.messages2prompt(messages, sequence_start, **chat_template_kwargs) - - mm_placeholder = self.image_token - if self.contains_video_input: - mm_placeholder = self.video_token - elif self.contains_audio_input: - mm_placeholder = self.audio_token - - return prompt, mm_placeholder - - def to_pytorch(self, - messages, - chat_template, - tokenizer, - sequence_start, - chat_template_kwargs: dict | None = None, - **kwargs): - """Return to the information needed by pytorch engine.""" - prompt, mm_placeholder = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs) - return self.to_pytorch_aux(messages, prompt, mm_placeholder, tokenizer, sequence_start) + self.mm_tokens = MultimodalSpecialTokens(image_token=self.image_token, + video_token=self.video_token, + audio_token=self.audio_token, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + audio_token_id=self.audio_token_id) diff --git a/tests/test_lmdeploy/test_vl/test_qwen3_omni_processor.py b/tests/test_lmdeploy/test_vl/test_qwen3_omni_processor.py new file mode 100644 index 0000000000..dd0071ee4c --- /dev/null +++ b/tests/test_lmdeploy/test_vl/test_qwen3_omni_processor.py @@ -0,0 +1,258 @@ +from types import SimpleNamespace + +import torch + +from lmdeploy.pytorch.models.qwen3_omni_moe_thinker import Qwen3OmniInputProcessor +from lmdeploy.vl.constants import Modality +from lmdeploy.vl.model.base import MultimodalSpecialTokens, VisionModel +from lmdeploy.vl.model.preprocess_utils import get_expanded_mm_items +from lmdeploy.vl.model.qwen3_omni import Qwen3OmniModel + + +class FakeQwen3OmniProcessor: + + image_token = '' + audio_token = '