From 70980775fabd1dee24d45765d415be442e056fbf Mon Sep 17 00:00:00 2001 From: huangyushi Date: Wed, 21 Aug 2024 06:05:39 +0800 Subject: [PATCH] Support SLM, e.g., Phi, Qwen2, Gemma2, Internlm2.5, MiniCPM, SmolLM, StableLm --- .../base_blockwise_quantization.py | 3 +- llmc/compression/quantization/module_utils.py | 39 ++++++++- llmc/compression/quantization/quarot.py | 19 ++++- llmc/models/__init__.py | 4 + llmc/models/gemma2.py | 30 +++++++ llmc/models/internlm2.py | 11 +++ llmc/models/minicpm.py | 84 +++++++++++++++++++ llmc/models/phi.py | 63 ++++++++++++++ llmc/models/smollm.py | 81 ++++++++++++++++++ llmc/models/stablelm.py | 81 ++++++++++++++++++ 10 files changed, 410 insertions(+), 5 deletions(-) create mode 100644 llmc/models/minicpm.py create mode 100644 llmc/models/phi.py create mode 100644 llmc/models/smollm.py create mode 100644 llmc/models/stablelm.py diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index d5f6b1f1..6c297394 100644 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -748,7 +748,8 @@ def fuse_ln_fcs(self, ln, fcs): fc.bias = torch.nn.Parameter( torch.zeros(fc.out_features, dtype=torch.float64) ) - fc.bias.data = fc.bias.data.double() + torch.matmul(W, ln.bias.double()) + fc.bias.data = fc.bias.data.double().to(device=W.device) \ + + torch.matmul(W, ln.bias.double()) fc.bias.data = fc.bias.data.to(fc_dtype) def remove_mean_from_embed(self): diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py index 6d543568..ea5c01f5 100644 --- a/llmc/compression/quantization/module_utils.py +++ b/llmc/compression/quantization/module_utils.py @@ -9,6 +9,17 @@ from transformers.models.mistral.modeling_mistral import MistralRMSNorm from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm + +try: + from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm +except Exception: + logger.info( + 'Gemma2RMSNorm not installed. ' + 'If you need it, please update your transformers lib.' + ) + + class Gemma2RMSNorm(nn.Module): + pass from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS try: @@ -17,7 +28,7 @@ from .hadamard_utils import matmul_hadU_cuda except Exception: logger.info( - 'fast_hadamard_transform not installed.' + 'fast_hadamard_transform not installed. ' 'If you need it, please install it firstly.' 
) @@ -122,7 +133,10 @@ def forward(self, hidden_states): @classmethod @torch.no_grad() def new(cls, module): - eps = module.variance_epsilon + if hasattr(module, 'eps'): + eps = module.eps + else: + eps = module.variance_epsilon weight = module.weight new_module = cls(weight, eps) return new_module @@ -163,6 +177,22 @@ def __repr__(self): return 'LlmcInternLM2RMSNorm()' +class LlmcGemma2RMSNorm(LlmcLlamaRMSNorm): + def __init__(self, weight, eps=1e-6): + super().__init__(weight, eps) + + def __repr__(self): + return 'LlmcGemma2RMSNorm()' + + +class LlmcMiniCPMRMSNorm(LlmcLlamaRMSNorm): + def __init__(self, weight, eps=1e-6): + super().__init__(weight, eps) + + def __repr__(self): + return 'LlmcMiniCPMRMSNorm()' + + class OriginFloatLinear(nn.Module): def __init__(self, weight, bias, ori_module): super().__init__() @@ -616,6 +646,7 @@ def __repr__(self): MixtralRMSNorm, Qwen2RMSNorm, LlamaRMSNorm, + Gemma2RMSNorm, nn.LayerNorm, ] _TRANSFORMERS_LINEAR_TYPES_ = [nn.Linear] @@ -627,6 +658,8 @@ def __repr__(self): 'Mixtral': LlmcMixtralRMSNorm, 'Interlm2': LlmcInternLM2RMSNorm, 'Qwen2': LlmcQwen2RMSNorm, + 'Gemma2': LlmcGemma2RMSNorm, + 'MiniCPM': LlmcMiniCPMRMSNorm, 'Starcoder': LlmcLayerNorm, 'Opt': LlmcLayerNorm, 'Bloom': LlmcLayerNorm, @@ -641,6 +674,8 @@ def __repr__(self): LlmcMistralRMSNorm, LlmcMixtralRMSNorm, LlmcInternLM2RMSNorm, + LlmcGemma2RMSNorm, + LlmcMiniCPMRMSNorm, ] diff --git a/llmc/compression/quantization/quarot.py b/llmc/compression/quantization/quarot.py index 386bc206..fcd35bf1 100644 --- a/llmc/compression/quantization/quarot.py +++ b/llmc/compression/quantization/quarot.py @@ -21,7 +21,9 @@ def __init__(self, model, quant_config, input, config): self.preprocess() def preprocess(self): - assert self.config['model']['type'] in ['Opt', 'Llama', 'Qwen2'] + assert self.config['model']['type'] in [ + 'Opt', 'Llama', 'Qwen2', 'InternLM2', + 'MiniCPM', 'StableLm', 'SmolLM'] # if self.config["model"]["type"] in ["Opt"]: if torch.equal( self.model.get_head_layers()[0].weight, @@ -83,6 +85,16 @@ def block_transform(self, block): logger.info(f'block:{block}') logger.info(f'End transform the {self.block_idx+1}-th block') + def bake_mean_into_linear(self, linear): + linear_dtype = linear.weight.dtype + W_ = linear.weight.data.double() + linear.weight.data = W_ - W_.mean(dim=-2, keepdim=True) + linear.weight.data = linear.weight.data.to(linear_dtype) + if linear.bias is not None: + b_ = linear.bias.data.double() + linear.bias.data = b_ - b_.mean() + linear.bias.data = linear.bias.data.to(linear_dtype) + @torch.no_grad() def subset_transform(self, block, subset): prev_op = subset['prev_op'] @@ -97,7 +109,7 @@ def subset_transform(self, block, subset): self.fuse_ln_fcs(prev_op[0], layers) self.rotate_pre_layers(layers, self.Q) else: - if self.config['model']['type'] in ['Opt']: + if self.config['model']['type'] in ['Opt', 'StableLm']: self.bake_mean_into_linear(layers[0]) if 'is_mlp' in subset and subset['is_mlp']: @@ -105,6 +117,9 @@ def subset_transform(self, block, subset): layers, self.Q, exact_had=True if self.online_rotate else False ) else: + for n, m in layers_dict.items(): + logger.info(f'layer: {n} {m.weight.shape}') + logger.info(f'{self.Q.shape}') self.rotate_post_layers(layers, self.Q, exact_had=False) if self.online_rotate: apply_exact_had_to_linear( diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py index 0fbf1d53..85aba3b1 100644 --- a/llmc/models/__init__.py +++ b/llmc/models/__init__.py @@ -4,8 +4,12 @@ from .internlm2 import InternLM2 from .llama import 
Llama from .llava import Llava +from .minicpm import MiniCPM from .mistral import Mistral from .mixtral import Mixtral from .opt import Opt +from .phi import Phi from .qwen2 import Qwen2 +from .smollm import SmolLM +from .stablelm import StableLm from .starcoder import Starcoder diff --git a/llmc/models/gemma2.py b/llmc/models/gemma2.py index b4696f92..775735b8 100644 --- a/llmc/models/gemma2.py +++ b/llmc/models/gemma2.py @@ -1,12 +1,34 @@ +from loguru import logger + from llmc.utils.registry_factory import MODEL_REGISTRY +try: + from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm +except Exception: + logger.warning('Gemma2 not found') +from types import MethodType + +import torch.nn as nn + from .base_model import BaseModel +def gemma2_rms_norm_forward(self, x): + output = self._norm(x.float()) + output = output * self.weight.float() + return output.type_as(x) + + @MODEL_REGISTRY class Gemma2(BaseModel): def __init__(self, model_path, torch_dtype): super().__init__(model_path, torch_dtype) + for m in self.model.modules(): + if isinstance(m, Gemma2RMSNorm): + w = m.weight.data + del m.weight + m.weight = nn.Parameter(w + 1.0) + m.forward = MethodType(gemma2_rms_norm_forward, m) def find_blocks(self): self.blocks = self.model.model.layers @@ -21,6 +43,12 @@ def find_block_name(self): def get_embed_layers(self): return [self.embed_tokens] + def get_head_layers(self): + return [self.model.lm_head] + + def get_pre_head_layernorm_layers(self): + return [self.model.model.norm] + def get_layers_except_blocks(self): return [self.embed_tokens, self.model.model.norm, self.model.lm_head] @@ -62,6 +90,7 @@ def get_subsets_in_block(self, block): 'input': ['mlp.gate_proj'], 'inspect': block.mlp, 'has_kwargs': False, + 'is_mlp': True, }, { 'layers': {'mlp.down_proj': block.mlp.down_proj}, @@ -69,5 +98,6 @@ def get_subsets_in_block(self, block): 'input': ['mlp.down_proj'], 'inspect': block.mlp.down_proj, 'has_kwargs': False, + 'is_mlp': True, }, ] diff --git a/llmc/models/internlm2.py b/llmc/models/internlm2.py index 5e17c0d9..4e3ed49c 100644 --- a/llmc/models/internlm2.py +++ b/llmc/models/internlm2.py @@ -1,3 +1,4 @@ +from llmc.compression.quantization.module_utils import _TRANSFORMERS_LN_TYPES_ from llmc.utils.registry_factory import MODEL_REGISTRY from .base_model import BaseModel @@ -7,6 +8,8 @@ class InternLM2(BaseModel): def __init__(self, model_path, torch_dtype): super().__init__(model_path, torch_dtype) + global _TRANSFORMERS_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)] def find_blocks(self): self.blocks = self.model.model.layers @@ -20,6 +23,12 @@ def find_block_name(self): def get_embed_layers(self): return [self.tok_embeddings] + def get_head_layers(self): + return [self.model.output] + + def get_pre_head_layernorm_layers(self): + return [self.model.model.norm] + def get_layers_except_blocks(self): return [self.tok_embeddings, self.model.model.norm, self.model.output] @@ -57,6 +66,7 @@ def get_subsets_in_block(self, block): 'input': ['feed_forward.w1'], 'inspect': block.feed_forward, 'has_kwargs': False, + 'is_mlp': True, }, { 'layers': {'feed_forward.w2': block.feed_forward.w2}, @@ -64,5 +74,6 @@ def get_subsets_in_block(self, block): 'input': ['feed_forward.w2'], 'inspect': block.feed_forward.w2, 'has_kwargs': False, + 'is_mlp': True, }, ] diff --git a/llmc/models/minicpm.py b/llmc/models/minicpm.py new file mode 100644 index 00000000..1f1334a4 --- /dev/null +++ b/llmc/models/minicpm.py @@ -0,0 +1,84 @@ +from 
llmc.compression.quantization.module_utils import _TRANSFORMERS_LN_TYPES_ +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +@MODEL_REGISTRY +class MiniCPM(BaseModel): + def __init__(self, model_path, torch_dtype): + super().__init__(model_path, torch_dtype) + global _TRANSFORMERS_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)] + + def find_blocks(self): + self.blocks = self.model.model.layers + + def find_embed_layers(self): + self.embed_tokens = self.model.model.embed_tokens + + def find_block_name(self): + self.block_name_prefix = 'model.layers' + self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'} + + def get_embed_layers(self): + return [self.embed_tokens] + + def get_head_layers(self): + return [self.model.lm_head] + + def get_pre_head_layernorm_layers(self): + return [self.model.model.norm] + + def get_layers_except_blocks(self): + return [self.embed_tokens, self.model.model.norm, self.model.lm_head] + + def has_bias(self): + return False + + def get_layernorms_in_block(self, block): + return { + 'input_layernorm': block.input_layernorm, + 'post_attention_layernorm': block.post_attention_layernorm, + } + + def get_subsets_in_block(self, block): + return [ + { + 'layers': { + 'self_attn.q_proj': block.self_attn.q_proj, + 'self_attn.k_proj': block.self_attn.k_proj, + 'self_attn.v_proj': block.self_attn.v_proj, + }, + 'prev_op': [block.input_layernorm], + 'input': ['self_attn.q_proj'], + 'inspect': block.self_attn, + 'has_kwargs': True, + }, + { + 'layers': {'self_attn.o_proj': block.self_attn.o_proj}, + 'prev_op': [block.self_attn.v_proj], + 'input': ['self_attn.o_proj'], + 'inspect': block.self_attn.o_proj, + 'has_kwargs': False, + }, + { + 'layers': { + 'mlp.gate_proj': block.mlp.gate_proj, + 'mlp.up_proj': block.mlp.up_proj, + }, + 'prev_op': [block.post_attention_layernorm], + 'input': ['mlp.gate_proj'], + 'inspect': block.mlp, + 'has_kwargs': False, + 'is_mlp': True, + }, + { + 'layers': {'mlp.down_proj': block.mlp.down_proj}, + 'prev_op': [block.mlp.up_proj], + 'input': ['mlp.down_proj'], + 'inspect': block.mlp.down_proj, + 'has_kwargs': False, + 'is_mlp': True, + }, + ] diff --git a/llmc/models/phi.py b/llmc/models/phi.py new file mode 100644 index 00000000..7a4e6eb3 --- /dev/null +++ b/llmc/models/phi.py @@ -0,0 +1,63 @@ +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +@MODEL_REGISTRY +class Phi(BaseModel): + def __init__(self, model_path, torch_dtype): + super().__init__(model_path, torch_dtype) + + def find_blocks(self): + self.blocks = self.model.model.layers + + def find_embed_layers(self): + self.embed_tokens = self.model.model.embed_tokens + + def find_block_name(self): + self.block_name_prefix = 'model.layers' + self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'} + + def get_embed_layers(self): + return [self.embed_tokens] + + def get_head_layers(self): + return [self.model.lm_head] + + def get_pre_head_layernorm_layers(self): + return [self.model.model.final_layernorm] + + def get_layers_except_blocks(self): + return [self.embed_tokens, self.model.model.final_layernorm, self.model.lm_head] + + def has_bias(self): + return False + + def get_layernorms_in_block(self, block): + return { + 'input_layernorm': block.input_layernorm, + 'post_attention_layernorm': block.input_layernorm, + } + + def get_subsets_in_block(self, block): + return [ + { + 'layers': { + 'self_attn.q_proj': block.self_attn.q_proj, + 'self_attn.k_proj': 
block.self_attn.k_proj, + 'self_attn.v_proj': block.self_attn.v_proj, + 'mlp.fc1': block.mlp.fc1, + }, + 'prev_op': [block.input_layernorm], + 'input': ['self_attn.q_proj'], + 'inspect': block, + 'has_kwargs': True, + }, + { + 'layers': {'self_attn.dense': block.self_attn.dense}, + 'prev_op': [block.self_attn.v_proj], + 'input': ['self_attn.dense'], + 'inspect': block.self_attn.dense, + 'has_kwargs': False, + }, + ] diff --git a/llmc/models/smollm.py b/llmc/models/smollm.py new file mode 100644 index 00000000..cee7aaf2 --- /dev/null +++ b/llmc/models/smollm.py @@ -0,0 +1,81 @@ +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +@MODEL_REGISTRY +class SmolLM(BaseModel): + def __init__(self, model_path, torch_dtype): + super().__init__(model_path, torch_dtype) + + def find_blocks(self): + self.blocks = self.model.model.layers + + def find_embed_layers(self): + self.embed_tokens = self.model.model.embed_tokens + + def find_block_name(self): + self.block_name_prefix = 'model.layers' + self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'} + + def get_embed_layers(self): + return [self.embed_tokens] + + def get_head_layers(self): + return [self.model.lm_head] + + def get_pre_head_layernorm_layers(self): + return [self.model.model.norm] + + def get_layers_except_blocks(self): + return [self.embed_tokens, self.model.model.norm, self.model.lm_head] + + def has_bias(self): + return False + + def get_layernorms_in_block(self, block): + return { + 'input_layernorm': block.input_layernorm, + 'post_attention_layernorm': block.post_attention_layernorm, + } + + def get_subsets_in_block(self, block): + return [ + { + 'layers': { + 'self_attn.q_proj': block.self_attn.q_proj, + 'self_attn.k_proj': block.self_attn.k_proj, + 'self_attn.v_proj': block.self_attn.v_proj, + }, + 'prev_op': [block.input_layernorm], + 'input': ['self_attn.q_proj'], + 'inspect': block.self_attn, + 'has_kwargs': True, + }, + { + 'layers': {'self_attn.o_proj': block.self_attn.o_proj}, + 'prev_op': [block.self_attn.v_proj], + 'input': ['self_attn.o_proj'], + 'inspect': block.self_attn.o_proj, + 'has_kwargs': False, + }, + { + 'layers': { + 'mlp.gate_proj': block.mlp.gate_proj, + 'mlp.up_proj': block.mlp.up_proj, + }, + 'prev_op': [block.post_attention_layernorm], + 'input': ['mlp.gate_proj'], + 'inspect': block.mlp, + 'has_kwargs': False, + 'is_mlp': True, + }, + { + 'layers': {'mlp.down_proj': block.mlp.down_proj}, + 'prev_op': [block.mlp.up_proj], + 'input': ['mlp.down_proj'], + 'inspect': block.mlp.down_proj, + 'has_kwargs': False, + 'is_mlp': True, + }, + ] diff --git a/llmc/models/stablelm.py b/llmc/models/stablelm.py new file mode 100644 index 00000000..640c800c --- /dev/null +++ b/llmc/models/stablelm.py @@ -0,0 +1,81 @@ +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +@MODEL_REGISTRY +class StableLm(BaseModel): + def __init__(self, model_path, torch_dtype): + super().__init__(model_path, torch_dtype) + + def find_blocks(self): + self.blocks = self.model.model.layers + + def find_embed_layers(self): + self.embed_tokens = self.model.model.embed_tokens + + def find_block_name(self): + self.block_name_prefix = 'model.layers' + self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'} + + def get_embed_layers(self): + return [self.embed_tokens] + + def get_head_layers(self): + return [self.model.lm_head] + + def get_pre_head_layernorm_layers(self): + return [self.model.model.norm] + + def get_layers_except_blocks(self): 
+ return [self.embed_tokens, self.model.model.norm, self.model.lm_head] + + def has_bias(self): + return False + + def get_layernorms_in_block(self, block): + return { + 'input_layernorm': block.input_layernorm, + 'post_attention_layernorm': block.post_attention_layernorm, + } + + def get_subsets_in_block(self, block): + return [ + { + 'layers': { + 'self_attn.q_proj': block.self_attn.q_proj, + 'self_attn.k_proj': block.self_attn.k_proj, + 'self_attn.v_proj': block.self_attn.v_proj, + }, + 'prev_op': [block.input_layernorm], + 'input': ['self_attn.q_proj'], + 'inspect': block.self_attn, + 'has_kwargs': True, + }, + { + 'layers': {'self_attn.o_proj': block.self_attn.o_proj}, + 'prev_op': [block.self_attn.v_proj], + 'input': ['self_attn.o_proj'], + 'inspect': block.self_attn.o_proj, + 'has_kwargs': False, + }, + { + 'layers': { + 'mlp.gate_proj': block.mlp.gate_proj, + 'mlp.up_proj': block.mlp.up_proj, + }, + 'prev_op': [block.post_attention_layernorm], + 'input': ['mlp.gate_proj'], + 'inspect': block.mlp, + 'has_kwargs': False, + 'is_mlp': True, + }, + { + 'layers': {'mlp.down_proj': block.mlp.down_proj}, + 'prev_op': [block.mlp.up_proj], + 'input': ['mlp.down_proj'], + 'inspect': block.mlp.down_proj, + 'has_kwargs': False, + 'is_mlp': True, + }, + ]
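
Review notes (illustrative sketches; helper names below are hypothetical and not part of llmc):

1) fuse_ln_fcs device fix. When a linear has no bias, the code creates a fresh
   zero bias with torch.zeros(...), which lands on the default (CPU) device even
   when the weights have been dispatched to GPU; the added .to(device=W.device)
   keeps the subsequent add on a single device. A simplified, single-linear
   sketch of the fused-bias update (b' = b + W @ ln.bias, float64 math as in the
   patch):

       import torch
       import torch.nn as nn

       def fuse_ln_bias_into_fc(ln, fc: nn.Linear) -> None:
           fc_dtype = fc.weight.dtype
           W = fc.weight.data.double()
           if fc.bias is None:
               # Created on the default device, hence the .to(W.device) below.
               fc.bias = nn.Parameter(torch.zeros(fc.out_features, dtype=torch.float64))
           # b' = b + W @ ln.bias, with both operands on W's device.
           fc.bias.data = fc.bias.data.double().to(device=W.device) \
               + torch.matmul(W, ln.bias.double())
           fc.bias.data = fc.bias.data.to(fc_dtype)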
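2) bake_mean_into_linear and the Opt/StableLm branch in subset_transform. Like
   OPT, StableLm normalizes with nn.LayerNorm, which subtracts the mean of its
   input, something RMSNorm does not do. Subtracting each weight column's mean
   (taken over the output dimension) and the bias mean makes the linear's output
   exactly zero-mean across the hidden dimension, so its contribution to the
   residual stream carries no mean for the downstream LayerNorm to remove
   (embeddings get the same treatment via remove_mean_from_embed), and the norm
   can then be handled like an RMSNorm during the QuaRot rewrite. A toy check of
   the zero-mean property (shapes and seed are arbitrary):

       import torch
       import torch.nn as nn

       torch.manual_seed(0)
       linear = nn.Linear(16, 32, bias=True)

       # Bake the mean, mirroring the new bake_mean_into_linear helper.
       W = linear.weight.data.double()
       linear.weight.data = (W - W.mean(dim=-2, keepdim=True)).to(linear.weight.dtype)
       b = linear.bias.data.double()
       linear.bias.data = (b - b.mean()).to(linear.bias.dtype)

       y = linear(torch.randn(4, 16))
       print(y.mean(dim=-1))  # ~0 for every row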
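3) Gemma2 norm-weight folding in the Gemma2 adapter. Gemma2RMSNorm scales the
   normalized activations by (1 + weight), while the shared LlmcLlamaRMSNorm
   replacement scales by weight alone. The adapter therefore stores weight + 1
   once at load time and swaps in gemma2_rms_norm_forward, which multiplies by
   the stored weight directly, so the two formulations stay numerically
   identical and the rest of the pipeline (e.g. norm-to-linear fusion) can treat
   Gemma2 like any other RMSNorm model. A minimal equivalence check:

       import torch

       def _norm(x, eps=1e-6):
           return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

       x, w = torch.randn(2, 8), torch.randn(8)
       gemma2_native = (_norm(x.float()) * (1.0 + w.float())).type_as(x)  # Gemma2RMSNorm semantics
       folded = (_norm(x.float()) * (w + 1.0).float()).type_as(x)         # adapter forward with weight := w + 1
       assert torch.allclose(gemma2_native, folded)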