3 changes: 2 additions & 1 deletion lmdeploy/cli/utils.py
@@ -407,7 +407,8 @@ def calib_samples(parser):
return parser.add_argument('--calib-samples',
type=int,
default=128,
help='The number of samples for calibration')
help='The number of samples for calibration. '
'Set it to 0 to enable data-free quantization.')

@staticmethod
def calib_seqlen(parser):
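For reference, a minimal sketch of driving the new data-free path from the Python API rather than the CLI (placeholder paths; assumes the `auto_awq` signature changed below in `lmdeploy/lite/apis/auto_awq.py`):

```python
# Sketch only: data-free AWQ quantization, i.e. no calibration forward pass.
# The checkpoint id and work_dir below are placeholders.
from lmdeploy.lite.apis.auto_awq import auto_awq

auto_awq(
    '/path/to/hf-checkpoint',   # placeholder model path or HF repo id
    calib_samples=0,            # 0 selects the data-free branch added in this PR
    work_dir='./work_dir/awq',  # output directory for the quantized weights
)
```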
66 changes: 43 additions & 23 deletions lmdeploy/lite/apis/auto_awq.py
@@ -8,9 +8,11 @@
import torch
from torch import nn

from lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, calibrate
from lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, MOE_MODEL_LIST, calibrate, load_model_tokenizer
from lmdeploy.lite.moe_mlp_modules import CONVERT_MOE_MODELS
from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.turbomind.deploy.converter import get_input_model_registered_name
from lmdeploy.utils import try_import_deeplink


@@ -59,6 +61,7 @@ def auto_awq(model: str,
calib_dataset (str): The calibration dataset name.
Defaults to 'wikitext2'.
calib_samples (int): The number of samples for calibration.
Set it to 0 to enable data-free quantization.
batch_size (int): The batch size for running the calib samples.
Low GPU mem requires small batch_size. Large batch_size
reduces the calibration time while costs more VRAM.
@@ -83,41 +86,58 @@
from lmdeploy.utils import get_model
model = get_model(model, revision=revision, download_dir=download_dir)
model_path = model
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
calib_samples,
calib_seqlen,
work_dir,
device,
w_bits=w_bits,
w_group_size=w_group_size,
search_scale=search_scale,
dtype=dtype,
batch_size=batch_size)
if calib_samples == 0:
vl_model, model, tokenizer, work_dir, _ = load_model_tokenizer(model, dtype=dtype, work_dir=work_dir)
else:
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
calib_samples,
calib_seqlen,
work_dir,
device,
w_bits=w_bits,
w_group_size=w_group_size,
search_scale=search_scale,
dtype=dtype,
batch_size=batch_size)

layer_type = LAYER_TYPE_MAP[type(model).__name__]
Copilot AI Apr 28, 2026
When calib_ds_req=False, this code bypasses calibrate()'s supported-model validation and then does layer_type = LAYER_TYPE_MAP[type(model).__name__], which will raise a raw KeyError for unsupported/renamed model classes. Consider adding the same explicit check and user-facing RuntimeError message that calibrate() uses (or reusing calibrate()'s model-type validation) so failures are actionable.

Suggested change
layer_type = LAYER_TYPE_MAP[type(model).__name__]
model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP:
supported_model_types = ', '.join(sorted(LAYER_TYPE_MAP.keys()))
raise RuntimeError(
f'Unsupported model type: {model_type}. '
f'Supported model types are: {supported_model_types}.')
layer_type = LAYER_TYPE_MAP[model_type]

fc2fcs = FC_FCS_MAP[layer_type]
norm2fcs = NORM_FCS_MAP[layer_type]
input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True)
layers = collect_target_modules(model, layer_type)
if not getattr(model.config, 'architectures', None):
Collaborator
Not a good way to identify if the model is Qwen3.5 or not.
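A more defensive version of the fallback, sketched below (this does not by itself resolve the identification concern; `LAYER_TYPE_MAP` is already imported in this file):

```python
# Sketch: only infer `architectures` from the loaded class name when that name
# is a known key, so an unexpected sub-model class fails with a clear error.
if not getattr(model.config, 'architectures', None):
    model_arch = type(model).__name__
    if model_arch not in LAYER_TYPE_MAP:
        raise RuntimeError(f'Cannot infer the architecture of {model_arch}; '
                           f"supported types: {', '.join(sorted(LAYER_TYPE_MAP))}")
    model.config.architectures = [model_arch]
```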

# Qwen3.5 TurboMind quantization works on the text sub-model only, whose
# config may not contain `architectures`. Infer it from the loaded model class
# for downstream AWQ skip/convert logic.
model_arch = type(model).__name__
model.config.architectures = [model_arch]
fcs = {}
for l_name, layer in layers.items():
if model.config.architectures[0] in MOE_MODEL_LIST:
model_name = get_input_model_registered_name(model_path, 'awq')
Collaborator
Why check the arch for each layer?
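One option, sketched below, is to resolve the conversion once before the loop, since neither the architecture check nor the registered-name lookup depends on the layer (names taken from this PR; untested):

```python
# Sketch: decide on the MoE conversion once, then apply it per decoder layer.
convert_moe = None
if model.config.architectures[0] in MOE_MODEL_LIST:
    model_name = get_input_model_registered_name(model_path, 'awq')
    convert_moe = CONVERT_MOE_MODELS.get(model_name)

fcs = {}
for l_name, layer in layers.items():
    if convert_moe is not None:
        convert_moe(layer)  # converter __init__ rewrites fused experts in place
    name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)
    fcs.update(name2fc)
```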

CONVERT_MOE_MODELS.get(model_name)(layer)
name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)
fcs.update(name2fc)

if search_scale:
awq_ratios = input_stats['ratios']
act_scales = input_stats['absmean']
awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios, w_group_size, device)
else:
act_scales = input_stats['absmax']
smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device)
quant_weights(model, fcs, w_bits, w_sym, w_group_size, device)
if calib_samples != 0:
fc2fcs = FC_FCS_MAP[layer_type]
norm2fcs = NORM_FCS_MAP[layer_type]
input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True)
if search_scale:
awq_ratios = input_stats['ratios']
act_scales = input_stats['absmean']
awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios, w_group_size, device)
else:
act_scales = input_stats['absmax']
smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device)

matched_exclude_modules = quant_weights(model, fcs, w_bits, w_sym, w_group_size, device,
arch=model.config.architectures[0])
quantization_config = dict(quant_method='awq',
version='gemm',
bits=w_bits,
group_size=w_group_size,
zero_point=not w_sym)
if matched_exclude_modules:
quantization_config['modules_to_not_convert'] = matched_exclude_modules
model.config.update(dict(quantization_config=quantization_config))

if vl_model:
102 changes: 63 additions & 39 deletions lmdeploy/lite/apis/calibrate.py
@@ -18,6 +18,9 @@
'QWenLMHeadModel': 'QWenBlock',
'Qwen2ForCausalLM': 'Qwen2DecoderLayer',
'Qwen3ForCausalLM': 'Qwen3DecoderLayer',
'Qwen3MoeForCausalLM': 'Qwen3MoeDecoderLayer',
'Qwen3_5ForCausalLM': 'Qwen3_5DecoderLayer',
'Qwen3_5MoeForCausalLM': 'Qwen3_5MoeDecoderLayer',
Copilot AI Apr 28, 2026
These new entries use Qwen3_5ForCausalLM / Qwen3_5MoeForCausalLM, but elsewhere in the repo Qwen3.5 is represented as Qwen3_5ForConditionalGeneration / Qwen3_5MoeForConditionalGeneration (e.g. lmdeploy/vl/model/qwen3_5.py, turbomind/supported_models.py). With the current keys, type(model).__name__ will likely never match and calibration/quantization will fail. Please align the map keys with the actual architecture class names (and update the corresponding NORM_TYPE_MAP / HEAD_NAME_MAP entries too).
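A quick way to confirm which class name the key has to match, sketched below (placeholder checkpoint path; note that `LAYER_TYPE_MAP` is keyed on `type(model).__name__` of the loaded text (sub-)model, which may differ from the top-level architecture):

```python
# Sketch: inspect the architecture name recorded in the checkpoint config.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained('/path/to/qwen3.5-checkpoint', trust_remote_code=True)
print(cfg.architectures)  # top-level architecture class name(s) from config.json
```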

'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B
'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaDecoderLayer',
@@ -39,6 +42,9 @@
'QWenLMHeadModel': 'RMSNorm',
'Qwen2ForCausalLM': 'Qwen2RMSNorm',
'Qwen3ForCausalLM': 'Qwen3RMSNorm',
'Qwen3MoeForCausalLM': 'Qwen3MoeRMSNorm',
'Qwen3_5ForCausalLM': 'Qwen3_5RMSNorm',
'Qwen3_5MoeForCausalLM': 'Qwen3_5MoeRMSNorm',
'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B
'BaichuanForCausalLM': 'RMSNorm', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaRMSNorm',
@@ -60,6 +66,9 @@
'QWenLMHeadModel': 'lm_head',
'Qwen2ForCausalLM': 'lm_head',
'Qwen3ForCausalLM': 'lm_head',
'Qwen3MoeForCausalLM': 'lm_head',
'Qwen3_5ForCausalLM': 'lm_head',
'Qwen3_5MoeForCausalLM': 'lm_head',
'BaiChuanForCausalLM': 'lm_head', # Baichuan 7B
'BaichuanForCausalLM': 'lm_head', # Baichuan2 7B
'LlamaForCausalLM': 'lm_head',
Expand All @@ -74,6 +83,12 @@
'MistralForCausalLM': 'lm_head',
}

MOE_MODEL_LIST = [
'Qwen3MoeForCausalLM',
'Qwen3_5MoeForCausalLM',
'MixtralForCausalLM'
]


def _prepare_for_calibrate(model: nn.Module,
layer_type: str | type,
@@ -176,7 +191,6 @@ def update_moe_mapping(model, model_type):
if '{i}' in k:
break
num_experts = len(m.get_submodule(k.split('.{i}')[0]))
break

# update FC_FCS_MAP
updated_fc2fcs = dict()
@@ -195,6 +209,53 @@
NORM_FCS_MAP[LAYER_TYPE_MAP[model_type]] = updated_norm2fcs


def load_model_tokenizer(model: str,
dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
work_dir: str = './work_dir'):
"""Load model and tokenizer."""
model_type, _ = get_task(backend='turbomind', model_path=model)
make_compatible_internvl_config(model)

# Load tokenizer and configuration
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

if model_type == 'llm':
model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True)
vl_model = None
elif model_type == 'vlm':
_, original_config = get_model_arch(model)
vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
model = vl_model
if hasattr(vl_model, 'language_model'): # deepseek-vl, ...
model = vl_model.language_model
if hasattr(vl_model, 'llm'): # MiniCPMV, ...
model = vl_model.llm
model.config.use_cache = False
if hasattr(model.config, 'text_config'):
model.config.text_config.use_cache = False
elif hasattr(model.config, 'llm_config'):
model.config.llm_config.use_cache = False
if dtype == 'float16' or (dtype == 'auto' and original_config.torch_dtype == torch.float16):
model.half()
elif dtype == 'bfloat16' or (dtype == 'auto' and original_config.torch_dtype == torch.bfloat16):
assert torch.cuda.is_bf16_supported(
), 'your device does not support bfloat16 please set --dtype float16' # noqa
model.to(torch.bfloat16)
model.eval()

model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
raise RuntimeError(f'Currently, quantification and calibration of {model_type} are '
f'not supported. The supported model types are '
f"{', '.join(LAYER_TYPE_MAP.keys())}.")

# Create work directory if not exists
work_dir = Path(work_dir)
work_dir.mkdir(parents=True, exist_ok=True)

return vl_model, model, tokenizer, work_dir, model_type
Collaborator
@lvhan028 lvhan028 May 6, 2026


It's unnecessary to return "work_dir" since "work_dir" is passed as an argument and its value is not changed in the function
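A sketch of the simplified signature (callers keep using the `work_dir` they passed in; the body is abbreviated with `...`):

```python
# Sketch: drop work_dir from the return value since the caller already owns it.
def load_model_tokenizer(model: str,
                         dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
                         work_dir: str = './work_dir'):
    ...
    Path(work_dir).mkdir(parents=True, exist_ok=True)
    return vl_model, model, tokenizer, model_type

# corresponding call site in calibrate() / auto_awq():
vl_model, model, tokenizer, model_type = load_model_tokenizer(model, dtype=dtype, work_dir=work_dir)
```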



def calibrate(model: str,
calib_dataset: str = 'wikitext2',
calib_samples: int = 128,
@@ -241,41 +302,7 @@
'Support only `wikitext2`, `c4`, `pileval`, `gsm8k`, ' \
'`neuralmagic_calibration`, `open-platypus`, `openwebtext`.'

model_type, _ = get_task(backend='turbomind', model_path=model)
make_compatible_internvl_config(model)

# Load tokenizer and configuration
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

if model_type == 'llm':
model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True)
vl_model = None
elif model_type == 'vlm':
_, original_config = get_model_arch(model)
vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
model = vl_model
if hasattr(vl_model, 'language_model'): # deepseek-vl, ...
model = vl_model.language_model
if hasattr(vl_model, 'llm'): # MiniCPMV, ...
model = vl_model.llm
model.config.use_cache = False
if hasattr(model.config, 'text_config'):
model.config.text_config.use_cache = False
elif hasattr(model.config, 'llm_config'):
model.config.llm_config.use_cache = False
if dtype == 'float16' or (dtype == 'auto' and original_config.torch_dtype == torch.float16):
model.half()
elif dtype == 'bfloat16' or (dtype == 'auto' and original_config.torch_dtype == torch.bfloat16):
assert torch.cuda.is_bf16_supported(
), 'your device does not support bfloat16 please set --dtype float16' # noqa
model.to(torch.bfloat16)
model.eval()

model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
raise RuntimeError(f'Currently, quantification and calibration of {model_type} are '
f'not supported. The supported model types are '
f"{', '.join(LAYER_TYPE_MAP.keys())}.")
vl_model, model, tokenizer, work_dir, model_type = load_model_tokenizer(model, dtype=dtype, work_dir=work_dir)

if model_type in ['MixtralForCausalLM']:
update_moe_mapping(model, model_type)
@@ -319,9 +346,6 @@
all_data = torch.cat(calib_loader).to(device)
calib_ctx.calibrate(all_data)

# Create work directory if not exists
work_dir = Path(work_dir)
work_dir.mkdir(parents=True, exist_ok=True)
calib_ctx.export(work_dir)

return vl_model, model, tokenizer, work_dir
4 changes: 2 additions & 2 deletions lmdeploy/lite/apis/smooth_quant.py
@@ -96,7 +96,7 @@ def smooth_quant(model: str,
rmsnorms = collect_target_modules(model, norm_type)

for name, linear in fcs.items():
if skipped_module(name):
if skipped_module(name, model.config.architectures[0]):
continue
linear.to(device)
q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype)
Expand All @@ -108,7 +108,7 @@ def smooth_quant(model: str,
torch.cuda.empty_cache()

for name, norm in rmsnorms.items():
if skipped_module(name):
if skipped_module(name, model.config.architectures[0]):
continue
norm.to(device)
q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype)
7 changes: 7 additions & 0 deletions lmdeploy/lite/moe_mlp_modules/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .base import CONVERT_MOE_MODELS
from .mixtral import MixtralMoeMLP
from .qwen import QwenMoeMLP

__all__ = ['CONVERT_MOE_MODELS', 'MixtralMoeMLP', 'QwenMoeMLP']
5 changes: 5 additions & 0 deletions lmdeploy/lite/moe_mlp_modules/base.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmengine import Registry

CONVERT_MOE_MODELS = Registry('moe_mlp_module', locations=['lmdeploy.lite.moe_mlp_modules'])
86 changes: 86 additions & 0 deletions lmdeploy/lite/moe_mlp_modules/mixtral.py
@@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn

from .base import CONVERT_MOE_MODELS


class MixtralMoeMLP(nn.Module):
"""Use unfused MoE expert MLP after splitting fused expert weights."""

def __init__(self, hidden_size, intermediate_size, dtype=None, device=None):
super().__init__()
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)

def load_weight(self, w1_weight: torch.Tensor, w2_weight: torch.Tensor, w3_weight: torch.Tensor):
"""Load weights for the MoE expert MLP."""
self.w1.weight = nn.Parameter(w1_weight.detach(), requires_grad=False)
self.w2.weight = nn.Parameter(w2_weight.detach(), requires_grad=False)
self.w3.weight = nn.Parameter(w3_weight.detach(), requires_grad=False)


@CONVERT_MOE_MODELS.register_module(name='mixtral')
class MixtralMoeMLPConverter:

MoEMLP = MixtralMoeMLP

def __init__(self, model: nn.Module):
self.convert_moe_parameters(model)

def convert_gate(self, gate_mod: nn.Module):
num_experts, hidden_size = gate_mod.weight.shape
dtype = gate_mod.weight.dtype
device = gate_mod.weight.device
gate = nn.Linear(hidden_size, num_experts, bias=False, dtype=dtype, device=device)
gate.weight = nn.Parameter(gate_mod.weight.data.detach(), requires_grad=False)
return gate

def convert_experts(self, experts_mod: nn.Module) -> nn.ModuleList:
"""Convert fused MoE expert weights into a ModuleList of MLP experts
without copying."""
num_experts, intermediate_size_2, hidden_size = experts_mod.gate_up_proj.shape
intermediate_size = intermediate_size_2 // 2

dtype = experts_mod.gate_up_proj.dtype

weight_gate_up = experts_mod.gate_up_proj.data
weight_down = experts_mod.down_proj.data

MoeExpert_list = nn.ModuleList()

for e in range(num_experts):
mod_mlp_instance = self.MoEMLP(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype,
device='meta'
)
mod_mlp_instance.load_weight(
weight_gate_up[e, :intermediate_size],
weight_down[e],
weight_gate_up[e, intermediate_size:]
)
MoeExpert_list.append(mod_mlp_instance)

return MoeExpert_list

def convert_moe_parameters(self, model: nn.Module):
"""Replace fused MoE experts with expert ModuleList if transformers >=
5.0."""
parent_target = next(
(mod for name, mod in model.named_modules() if name == 'mlp'),
None
)
if parent_target is None:
return

target = getattr(parent_target, 'experts', None)
if target is not None and not isinstance(target, nn.ModuleList):
parent_target.experts = self.convert_experts(target)

gate_target = getattr(parent_target, 'gate', None)
if gate_target is not None and not isinstance(gate_target, nn.Linear):
parent_target.gate = self.convert_gate(gate_target)
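For context, a sketch of how this converter is applied at quantization time, mirroring the per-layer loop in `auto_awq.py` above (the registered name is resolved via `get_input_model_registered_name`; untested):

```python
# Sketch: fetch the converter registered for this model family and apply it to
# a single decoder layer; the converter's __init__ replaces the fused experts
# (and gate, if needed) with unfused nn.Linear-based modules.
from lmdeploy.lite.moe_mlp_modules import CONVERT_MOE_MODELS

converter_cls = CONVERT_MOE_MODELS.get('mixtral')  # name as registered above
converter_cls(layer)  # `layer` is one decoder layer of the loaded model
# afterwards the layer's `mlp.experts` (when present) is an nn.ModuleList of
# MixtralMoeMLP modules whose w1/w2/w3 Linears can be quantized individually.
```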