12 changes: 12 additions & 0 deletions lmdeploy/cli/lite.py
@@ -39,6 +39,18 @@ def add_parser_auto_awq():
type=int,
default=128,
help='Group size for weight quantization statistics')
parser.add_argument('--no-calib-ds-req',
dest='calib_ds_req',
action='store_false',
default=True,
help='Require calibration dataset before quantizing weights. '
'Default to True. Set to False to skip calibration and directly quantize weights')
Collaborator
It's unnecessary to define another option.
We can use "calib_samples=0" to indicate data-free quantization.
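A rough sketch of how that could look, using placeholder function names rather than the real lmdeploy internals; the only point is that the existing calib_samples argument can carry the data-free signal:

# Illustrative only: branch on calib_samples == 0 instead of adding a new flag.
# run_calibration / quantize_* are stand-ins, not actual lmdeploy APIs.
def quantize(model_path: str, calib_samples: int = 128) -> None:
    if calib_samples > 0:
        stats = run_calibration(model_path, calib_samples)  # calibration-based AWQ
        quantize_with_stats(model_path, stats)
    else:
        quantize_data_free(model_path)  # calib_samples == 0 -> skip calibration


def run_calibration(model_path: str, n_samples: int) -> dict:
    print(f'calibrating {model_path} on {n_samples} samples')
    return {'absmax': {}}


def quantize_with_stats(model_path: str, stats: dict) -> None:
    print(f'quantizing {model_path} with stats keys {list(stats)}')


def quantize_data_free(model_path: str) -> None:
    print(f'quantizing {model_path} without a calibration dataset')


quantize('some/model', calib_samples=0)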

parser.add_argument('--mod-skip-quant',
Collaborator
--mod-skip-quant is awkward. dest is redundant.

parser.add_argument('--exclude-modules',
                    nargs='+',
                    metavar='PATTERN',
                    default=None,
                    help='One or more module name patterns (glob-style) to exclude from quantization. '
                         'Example: --exclude-modules "*.lm_head" "transformer.layers.*.ffn"')

dest='mod_skip_quant',
nargs='+',
metavar='PATTERN',
default=None,
help='Module name patterns to skip during quantization')

@staticmethod
def add_parser_auto_gptq():
109 changes: 84 additions & 25 deletions lmdeploy/lite/apis/auto_awq.py
@@ -10,10 +10,46 @@

from lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, calibrate
from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.lite.utils import collect_target_modules, convert_moe_parameters
from lmdeploy.utils import try_import_deeplink


def load_model(model: str, dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', work_dir: str = './work_dir'):
    from pathlib import Path

    from transformers import AutoTokenizer

    from lmdeploy.lite.utils import load_hf_from_pretrained
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True)
    vl_model = None
    work_dir = Path(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    return vl_model, model, tokenizer, work_dir


Copilot AI Apr 28, 2026
The no-calibration path (load_model) always uses load_hf_from_pretrained, which loads via AutoModelForCausalLM and always returns vl_model=None. This bypasses the VLM loading logic used in calibrate() (load_vl_model, language_model/llm extraction, etc.), so --no-calib-ds-req is likely to break for VLM/conditional-generation architectures such as Qwen3_5* (and will also skip save_vl_model). Suggest reusing get_task(...) + the same VLM/LLM loading branch as calibrate() (but skipping dataset calibration), rather than a separate loader.

Suggested change
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True)
    vl_model = None
    work_dir = Path(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    return vl_model, model, tokenizer, work_dir
    work_dir = Path(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    vl_model = None
    tokenizer = None
    llm_model = None
    # Reuse the task-aware VLM loading path when available so the no-calib
    # flow can handle conditional-generation / multimodal architectures.
    try:
        from lmdeploy.archs import get_task
        from lmdeploy.vl.model.builder import load_vl_model
        _, pipeline_class = get_task(model)
        is_vl_task = pipeline_class is not None and hasattr(pipeline_class, 'is_vl') and pipeline_class.is_vl
        if is_vl_task:
            vl_model = load_vl_model(model, backend='huggingface')
            if hasattr(vl_model, 'language_model'):
                llm_model = vl_model.language_model
            elif hasattr(vl_model, 'llm'):
                llm_model = vl_model.llm
            else:
                raise AttributeError('Cannot find language model in loaded VLM.')
            if hasattr(vl_model, 'tokenizer') and vl_model.tokenizer is not None:
                tokenizer = vl_model.tokenizer
            elif hasattr(vl_model, 'processor') and hasattr(vl_model.processor, 'tokenizer'):
                tokenizer = vl_model.processor.tokenizer
    except Exception:
        # Fall back to the original text-only loading path if task-aware VLM
        # loading is unavailable or the model is not a VLM.
        vl_model = None
        llm_model = None
        tokenizer = None
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    if llm_model is None:
        llm_model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True)
    return vl_model, llm_model, tokenizer, work_dir

def config_contains_keyword(config, keyword: str = 'experts') -> bool:
"""Recursively check whether any config key or string value contains the
given keyword."""

keyword = keyword.lower()

if hasattr(config, 'to_dict'):
config = config.to_dict()

def search(obj) -> bool:
if isinstance(obj, dict):
for key, value in obj.items():
if keyword in str(key).lower():
return True
if search(value):
return True

return False

Copilot AI Apr 28, 2026
config_contains_keyword claims to recursively search config keys or string values, but the implementation only recurses into dict values and ignores lists/tuples/strings entirely. For HF configs, to_dict() commonly contains nested lists/dicts, so this can incorrectly return False and prevent MoE detection/conversion. Update search() to handle dict, list/tuple, and str (and optionally other primitive types via str(obj)), consistent with the docstring.

Suggested change
        return False
            return False
        if isinstance(obj, (list, tuple)):
            for item in obj:
                if search(item):
                    return True
            return False
        if isinstance(obj, str):
            return keyword in obj.lower()
        if obj is None:
            return False
        return keyword in str(obj).lower()

    return search(config)


def save_vl_model(vl_model, model_path, dst_path):
    vl_model.save_pretrained(dst_path, safe_serialization=True)
    candidate = [
@@ -50,7 +86,9 @@ def auto_awq(model: str,
device: str = 'cuda',
revision: str = None,
dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
download_dir: str = None):
download_dir: str = None,
mod_skip_quant: list[str] | None = None,
calib_ds_req: bool = True):
"""Perform weight quantization using AWQ algorithm.

Args:
@@ -75,6 +113,8 @@
dtype (str): Data type for loading model weights and calib infer.
download_dir (str): Directory to download and load the weights,
default to the default cache directory of huggingface.
mod_skip_quant (list[str] | None): Module name substrings to skip during quantization.
calib_ds_req (bool): Whether the calibration dataset is required. Default to True.
"""
try_import_deeplink(device)
if not osp.exists(model):
@@ -83,47 +123,66 @@
from lmdeploy.utils import get_model
model = get_model(model, revision=revision, download_dir=download_dir)
model_path = model
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
calib_samples,
calib_seqlen,
work_dir,
device,
w_bits=w_bits,
w_group_size=w_group_size,
search_scale=search_scale,
dtype=dtype,
batch_size=batch_size)
if calib_ds_req:
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
calib_samples,
calib_seqlen,
work_dir,
device,
w_bits=w_bits,
w_group_size=w_group_size,
search_scale=search_scale,
dtype=dtype,
batch_size=batch_size)
input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True)
Copilot AI Apr 28, 2026
input_stats is loaded here but never used (it gets reloaded inside the later if calib_ds_req: block). This is unnecessary IO and can noticeably slow down quantization for large stats files; consider removing this load or using the already-loaded input_stats later.

Suggested change
input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True)

else:
vl_model, model, tokenizer, work_dir = load_model(model, dtype, work_dir)

layer_type = LAYER_TYPE_MAP[type(model).__name__]
Copilot AI Apr 28, 2026
When calib_ds_req=False, this code bypasses calibrate()'s supported-model validation and then does layer_type = LAYER_TYPE_MAP[type(model).__name__], which will raise a raw KeyError for unsupported/renamed model classes. Consider adding the same explicit check and user-facing RuntimeError message that calibrate() uses (or reusing calibrate()'s model-type validation) so failures are actionable.

Suggested change
    layer_type = LAYER_TYPE_MAP[type(model).__name__]
    model_type = type(model).__name__
    if model_type not in LAYER_TYPE_MAP:
        supported_model_types = ', '.join(sorted(LAYER_TYPE_MAP.keys()))
        raise RuntimeError(
            f'Unsupported model type: {model_type}. '
            f'Supported model types are: {supported_model_types}.')
    layer_type = LAYER_TYPE_MAP[model_type]

fc2fcs = FC_FCS_MAP[layer_type]
norm2fcs = NORM_FCS_MAP[layer_type]
input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True)
layers = collect_target_modules(model, layer_type)
fcs = {}
is_moe = (
'moe' in model.config.model_type.lower() or
config_contains_keyword(model.config, 'experts')
)
for l_name, layer in layers.items():
if is_moe:
convert_moe_parameters(model_path, layer)
name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)
Copilot AI Apr 28, 2026
convert_moe_parameters(model_path, layer) is called once per decoder layer, but convert_moe_parameters recomputes the registered model name from model_path each time (which calls get_model_arch(...) and reads config). This becomes O(num_layers) config parsing overhead. Consider computing model_name once in auto_awq (or once in convert_moe_parameters via caching) and passing it down so conversion stays cheap.
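A small sketch of the caching idea, with placeholder names rather than the actual lmdeploy helpers, so the config parsing happens only once per model path:

# Hypothetical: memoize the registered-model lookup so per-layer conversion stays cheap.
from functools import lru_cache


def _read_config_arch(model_path: str) -> str:
    # stand-in for the expensive get_model_arch(...) / config read
    print(f'parsing config for {model_path}')
    return 'qwen3-moe'


@lru_cache(maxsize=None)
def registered_moe_name(model_path: str) -> str:
    return _read_config_arch(model_path)


def convert_moe_parameters(model_path: str, layer: str) -> None:
    name = registered_moe_name(model_path)  # config parsed only on the first call
    print(f'converting {layer} with registered name {name!r}')


for layer in ['layer.0', 'layer.1', 'layer.2']:
    convert_moe_parameters('some/model', layer)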

fcs.update(name2fc)

if search_scale:
awq_ratios = input_stats['ratios']
act_scales = input_stats['absmean']
awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios, w_group_size, device)
else:
act_scales = input_stats['absmax']
smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device)
quant_weights(model, fcs, w_bits, w_sym, w_group_size, device)
if calib_ds_req:
fc2fcs = FC_FCS_MAP[layer_type]
norm2fcs = NORM_FCS_MAP[layer_type]
input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True)
if search_scale:
awq_ratios = input_stats['ratios']
act_scales = input_stats['absmean']
awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios, w_group_size, device)
else:
act_scales = input_stats['absmax']
smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device)

quant_weights(model, fcs, w_bits, w_sym, w_group_size, device, mod_skip_quant=mod_skip_quant)
quantization_config = dict(quant_method='awq',
version='gemm',
bits=w_bits,
group_size=w_group_size,
zero_point=not w_sym)
if mod_skip_quant:
quantization_config['modules_to_not_convert'] = list(mod_skip_quant)
model.config.update(dict(quantization_config=quantization_config))

if vl_model:
save_vl_model(vl_model, model_path, work_dir)
else:
model.save_pretrained(work_dir, safe_serialization=True)
# model.save_pretrained(work_dir, safe_serialization=True)
Collaborator
May remove the unused code

        model.save_pretrained(
            work_dir,
            safe_serialization=True,
            max_shard_size='4GB',
        )
    tokenizer.save_pretrained(work_dir)


9 changes: 9 additions & 0 deletions lmdeploy/lite/apis/calibrate.py
@@ -18,6 +18,9 @@
'QWenLMHeadModel': 'QWenBlock',
'Qwen2ForCausalLM': 'Qwen2DecoderLayer',
'Qwen3ForCausalLM': 'Qwen3DecoderLayer',
'Qwen3MoeForCausalLM': 'Qwen3MoeDecoderLayer',
'Qwen3_5ForCausalLM': 'Qwen3_5DecoderLayer',
'Qwen3_5MoeForCausalLM': 'Qwen3_5MoeDecoderLayer',
Copilot AI Apr 28, 2026
These new entries use Qwen3_5ForCausalLM / Qwen3_5MoeForCausalLM, but elsewhere in the repo Qwen3.5 is represented as Qwen3_5ForConditionalGeneration / Qwen3_5MoeForConditionalGeneration (e.g. lmdeploy/vl/model/qwen3_5.py, turbomind/supported_models.py). With the current keys, type(model).__name__ will likely never match and calibration/quantization will fail. Please align the map keys with the actual architecture class names (and update the corresponding NORM_TYPE_MAP / HEAD_NAME_MAP entries too).
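If the architecture classes are indeed the ConditionalGeneration variants, the aligned entries might look roughly like this; the decoder-layer and norm class names on the right are assumptions that would need to be checked against the modeling code:

# Hypothetical aligned keys; right-hand class names are guesses to verify.
LAYER_TYPE_MAP_FIX = {
    'Qwen3_5ForConditionalGeneration': 'Qwen3_5DecoderLayer',
    'Qwen3_5MoeForConditionalGeneration': 'Qwen3_5MoeDecoderLayer',
}
NORM_TYPE_MAP_FIX = {
    'Qwen3_5ForConditionalGeneration': 'Qwen3_5RMSNorm',
    'Qwen3_5MoeForConditionalGeneration': 'Qwen3_5MoeRMSNorm',
}
HEAD_NAME_MAP_FIX = {
    'Qwen3_5ForConditionalGeneration': 'lm_head',
    'Qwen3_5MoeForConditionalGeneration': 'lm_head',
}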

'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B
'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaDecoderLayer',
@@ -39,6 +42,9 @@
'QWenLMHeadModel': 'RMSNorm',
'Qwen2ForCausalLM': 'Qwen2RMSNorm',
'Qwen3ForCausalLM': 'Qwen3RMSNorm',
'Qwen3MoeForCausalLM': 'Qwen3MoeRMSNorm',
'Qwen3_5ForCausalLM': 'Qwen3_5RMSNorm',
'Qwen3_5MoeForCausalLM': 'Qwen3_5MoeRMSNorm',
'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B
'BaichuanForCausalLM': 'RMSNorm', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaRMSNorm',
@@ -60,6 +66,9 @@
'QWenLMHeadModel': 'lm_head',
'Qwen2ForCausalLM': 'lm_head',
'Qwen3ForCausalLM': 'lm_head',
'Qwen3MoeForCausalLM': 'lm_head',
'Qwen3_5ForCausalLM': 'lm_head',
'Qwen3_5MoeForCausalLM': 'lm_head',
'BaiChuanForCausalLM': 'lm_head', # Baichuan 7B
'BaichuanForCausalLM': 'lm_head', # Baichuan2 7B
'LlamaForCausalLM': 'lm_head',
3 changes: 3 additions & 0 deletions lmdeploy/lite/mlp_moe_modules/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mixtral import MixtralMoeMLP # noqa: F401
Collaborator
"mlp_moe_modules" -> "moe_mlp_modules"

from .qwen import QwenMoeMLP # noqa: F401
5 changes: 5 additions & 0 deletions lmdeploy/lite/mlp_moe_modules/base.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmengine import Registry

CONVERT_MOE_MODELS = Registry('mlp moe module', locations=['lmdeploy.lite.mlp_moe_modules.base'])
Copilot AI Apr 28, 2026
CONVERT_MOE_MODELS is initialized with locations=['lmdeploy.lite.mlp_moe_modules.base'], but the actual registrations are in lmdeploy/lite/mlp_moe_modules/qwen.py and mixtral.py. Since base.py doesn't import those modules, the registry will stay empty unless callers import lmdeploy.lite.mlp_moe_modules elsewhere, causing CONVERT_MOE_MODELS.get(...) to return None and MoE conversion to silently never run. Consider changing locations to ['lmdeploy.lite.mlp_moe_modules'] (package) or importing the concrete modules in base.py/package init so the registrations are guaranteed to execute.

Suggested change
CONVERT_MOE_MODELS = Registry('mlp moe module', locations=['lmdeploy.lite.mlp_moe_modules.base'])
CONVERT_MOE_MODELS = Registry('mlp moe module', locations=['lmdeploy.lite.mlp_moe_modules'])

23 changes: 23 additions & 0 deletions lmdeploy/lite/mlp_moe_modules/mixtral.py
@@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn

from .base import CONVERT_MOE_MODELS


@CONVERT_MOE_MODELS.register_module(name='mixtral')
class MixtralMoeMLP(nn.Module):
"""Use unfused MoE expert MLP after splitting fused expert weights."""

def __init__(self, hidden_size, intermediate_size, dtype=None, device=None):
super().__init__()
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)

def load_weight(self, w1_weight: torch.Tensor, w2_weight: torch.Tensor, w3_weight: torch.Tensor):
"""Load weights for the MoE expert MLP."""
self.w1.weight = nn.Parameter(w1_weight.detach(), requires_grad=False)
self.w2.weight = nn.Parameter(w2_weight.detach(), requires_grad=False)
self.w3.weight = nn.Parameter(w3_weight.detach(), requires_grad=False)
24 changes: 24 additions & 0 deletions lmdeploy/lite/mlp_moe_modules/qwen.py
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn

from .base import CONVERT_MOE_MODELS


@CONVERT_MOE_MODELS.register_module(name='qwen3-moe')
@CONVERT_MOE_MODELS.register_module(name='qwen3_5-moe')
class QwenMoeMLP(nn.Module):
"""Use unfused MoE expert MLP after splitting fused expert weights."""

def __init__(self, hidden_size, intermediate_size, dtype=None, device=None):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)

def load_weight(self, gate_proj_weight: torch.Tensor, down_proj_weight: torch.Tensor, up_proj_weight: torch.Tensor):
"""Load weights for the MoE expert MLP."""
self.gate_proj.weight = nn.Parameter(gate_proj_weight.detach(), requires_grad=False)
self.up_proj.weight = nn.Parameter(up_proj_weight.detach(), requires_grad=False)
self.down_proj.weight = nn.Parameter(down_proj_weight.detach(), requires_grad=False)
32 changes: 24 additions & 8 deletions lmdeploy/lite/quantization/awq.py
@@ -32,6 +32,10 @@
'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
},
'Qwen3MoeDecoderLayer': {
'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
},
Comment on lines +37 to +40
Copilot AI Apr 28, 2026
NORM_FCS_MAP adds support for Qwen3MoeDecoderLayer, but FC_FCS_MAP has no corresponding Qwen3MoeDecoderLayer entry. In auto_awq with calib_ds_req=True (default), it does fc2fcs = FC_FCS_MAP[layer_type], so Qwen3-MoE will currently hit a KeyError during smoothing/AWQ. Similarly, the newly added Qwen3.5 layer types in calibrate.py need entries in both maps to work with calibration-based AWQ.
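A possible shape for the missing entry, mirroring the existing dense-Qwen3 pattern; the attention part follows the entries already in the map, while the expert/MLP paths are assumptions that depend on how the converted MoE modules are named:

# Hypothetical FC_FCS_MAP addition; verify the projection names against the
# converted Qwen3-MoE decoder layer before relying on this.
FC_FCS_MAP_FIX = {
    'Qwen3MoeDecoderLayer': {
        'self_attn.v_proj': ['self_attn.o_proj'],
        # per-expert MLP projections would likely need their own keys, e.g.
        # 'mlp.experts.0.up_proj': ['mlp.experts.0.down_proj'],
    },
}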

'DecoderLayer': {
'input_layernorm': ['self_attn.W_pack'],
'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
@@ -121,14 +125,26 @@
}
}

SKIPPED_MODULE = ['lora', 'block_sparse_moe.gate']
SKIPPED_MODULE = ['lora', 'block_sparse_moe.gate', 'mlp.gate']

def match_builtin_skkiped_pattern(name: str, pattern: str):
    if  pattern == 'lora':
        return pattern in name
    return name == pattern or name.endswith(f'.{pattern}') or f'.{pattern}.' in name
Copilot AI Apr 28, 2026
Typo in helper name: match_builtin_skkiped_pattern has an extra 'k' ("skkiped") and also contains a double space in if pattern == 'lora':. Renaming to match_builtin_skipped_pattern (and updating call sites) would improve readability and avoid propagating the typo into future usages.


def skipped_module(name: str):
"""Whether the module should be skipped from quantization."""
for m in SKIPPED_MODULE:
if m in name:
return True
def skipped_module(name: str, extra_patterns=None):
"""Whether the module should be skipped from quantization.

Args:
name: The fully-qualified module name.
extra_patterns: Optional iterable of additional substring patterns
(e.g. user-provided ``mod_skip_quant``). Merged with the
built-in ``SKIPPED_MODULE`` list.
"""
if any(match_builtin_skkiped_pattern(name, pattern) for pattern in SKIPPED_MODULE):
return True
if extra_patterns and any(pattern in name for pattern in extra_patterns):
return True
return False


@@ -294,7 +310,7 @@ def check_awq_supported(layer_type):
raise NotImplementedError


def quant_weights(model, fcs, bits, symmetry, group_size=-1, device='cuda'):
def quant_weights(model, fcs, bits, symmetry, group_size=-1, device='cuda', mod_skip_quant=None):
"""Quantize the weights of the target model's linear layers."""
from lmdeploy.lite.quantization import WeightQuantizer
from lmdeploy.lite.quantization.modules import WeightOnlyQLinear
@@ -304,7 +320,7 @@ def quant_weights(model, fcs, bits, symmetry, group_size=-1, device='cuda'):
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
pack_or_skip = 'packed'
if skipped_module(name):
if skipped_module(name, mod_skip_quant):
q_linear = fc
pack_or_skip = 'skipped'
else:
4 changes: 3 additions & 1 deletion lmdeploy/lite/utils/__init__.py
@@ -13,12 +13,14 @@
)
from .calib_dataloader import get_calib_loaders
from .collect import bimap_name_mod, collect_target_modules, collect_target_weights
from .convert_moe_params import convert_moe_parameters
from .global_avail import GlobalAvailMixin
from .load import load_hf_from_pretrained

__all__ = [
'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax', 'cal_qparams_per_group_absmax',
'cal_qparams_per_group_minmax', 'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax', 'QParams',
'get_calib_loaders', 'collect_target_modules', 'precise_round', 'collect_target_weights', 'GlobalAvailMixin',
'split_decoder_layer_inputs', 'bimap_name_mod', 'concat_decoder_layer_outputs', 'load_hf_from_pretrained'
'split_decoder_layer_inputs', 'bimap_name_mod', 'concat_decoder_layer_outputs', 'load_hf_from_pretrained',
'convert_moe_parameters'
]
5 changes: 4 additions & 1 deletion lmdeploy/lite/utils/batch_split.py
@@ -47,7 +47,10 @@ def split_decoder_layer_inputs(batch_size, *args: torch.Tensor | Any,
elif isinstance(val, torch.Tensor) and len(val.shape) > 1 and val.size(1) == bs: # qwen2-vl
new_kwargs[name] = val[:, i:i + batch_size]
elif name == 'position_embeddings' and isinstance(val, tuple) and len(
val[0].shape) > 1 and val[0].size(1) == bs: # qwen2-vl
val[0].shape) < 4 and val[0].size(0) == bs: # qwen3_5
new_kwargs[name] = (val[0][i:i + batch_size], val[1][i:i + batch_size])
elif name == 'position_embeddings' and isinstance(val, tuple) and len(
val[0].shape) >= 4 and val[0].size(1) == bs: # qwen2-vl
new_kwargs[name] = (val[0][:, i:i + batch_size], val[1][:, i:i + batch_size])
else:
new_kwargs[name] = val