Support SLM, e.g., Phi, Qwen2, Gemma2, Internlm2.5, MiniCPM, SmolLM, … #33

Merged (1 commit) on Aug 20, 2024
3 changes: 2 additions & 1 deletion llmc/compression/quantization/base_blockwise_quantization.py
@@ -748,7 +748,8 @@ def fuse_ln_fcs(self, ln, fcs):
fc.bias = torch.nn.Parameter(
torch.zeros(fc.out_features, dtype=torch.float64)
)
fc.bias.data = fc.bias.data.double() + torch.matmul(W, ln.bias.double())
fc.bias.data = fc.bias.data.double().to(device=W.device) \
+ torch.matmul(W, ln.bias.double())
fc.bias.data = fc.bias.data.to(fc_dtype)

def remove_mean_from_embed(self):
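Note on this hunk: the zero bias created just above is allocated on the CPU by default, while W may already live on the GPU, so the old one-line addition could raise a cross-device error; moving the bias onto W's device before adding the matmul result avoids that. For context, the sketch below (simplified, with invented parameter names, not the repository's exact fuse_ln_fcs signature) illustrates the norm-into-linear fusion this method performs:

from typing import Optional

import torch


def fuse_norm_into_linear(fc: torch.nn.Linear, norm_weight: torch.Tensor,
                          norm_bias: Optional[torch.Tensor] = None) -> None:
    """Fold a norm's affine weight (and optional bias) into the following linear."""
    fc_dtype = fc.weight.dtype
    W = fc.weight.data.double()  # (out_features, in_features)
    # Scale each input channel of the linear by the norm's per-channel weight.
    fc.weight.data = (W * norm_weight.double()).to(fc_dtype)
    if norm_bias is not None:
        if fc.bias is None:
            fc.bias = torch.nn.Parameter(
                torch.zeros(fc.out_features, dtype=torch.float64)
            )
        # Same device handling as the patch: keep the accumulation on W's device.
        fused_bias = fc.bias.data.double().to(device=W.device) \
            + torch.matmul(W, norm_bias.double())
        fc.bias.data = fused_bias.to(fc_dtype)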
39 changes: 37 additions & 2 deletions llmc/compression/quantization/module_utils.py
@@ -9,6 +9,17 @@
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

try:
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
except Exception:
logger.info(
'Gemma2RMSNorm not installed. '
'If you need it, please update your transformers lib.'
)

class Gemma2RMSNorm(nn.Module):
pass
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

try:
@@ -17,7 +28,7 @@
from .hadamard_utils import matmul_hadU_cuda
except Exception:
logger.info(
'fast_hadamard_transform not installed.'
'fast_hadamard_transform not installed. '
'If you need it, please install it firstly.'
)

@@ -122,7 +133,10 @@ def forward(self, hidden_states):
@classmethod
@torch.no_grad()
def new(cls, module):
eps = module.variance_epsilon
if hasattr(module, 'eps'):
eps = module.eps
else:
eps = module.variance_epsilon
weight = module.weight
new_module = cls(weight, eps)
return new_module
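Note on this hunk: Hugging Face RMSNorm implementations do not agree on the attribute name for the epsilon. Llama/Qwen2-style classes expose variance_epsilon, while some others (Gemma2RMSNorm, for instance) store it as eps, so the lookup now tries eps first. An equivalent one-liner would be (illustrative only; the trailing 1e-6 fallback is my assumption, not something the patch adds):

eps = getattr(module, 'eps', getattr(module, 'variance_epsilon', 1e-6))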
@@ -163,6 +177,22 @@ def __repr__(self):
return 'LlmcInternLM2RMSNorm()'


class LlmcGemma2RMSNorm(LlmcLlamaRMSNorm):
def __init__(self, weight, eps=1e-6):
super().__init__(weight, eps)

def __repr__(self):
return 'LlmcGemma2RMSNorm()'


class LlmcMiniCPMRMSNorm(LlmcLlamaRMSNorm):
def __init__(self, weight, eps=1e-6):
super().__init__(weight, eps)

def __repr__(self):
return 'LlmcMiniCPMRMSNorm()'


class OriginFloatLinear(nn.Module):
def __init__(self, weight, bias, ori_module):
super().__init__()
@@ -616,6 +646,7 @@ def __repr__(self):
MixtralRMSNorm,
Qwen2RMSNorm,
LlamaRMSNorm,
Gemma2RMSNorm,
nn.LayerNorm,
]
_TRANSFORMERS_LINEAR_TYPES_ = [nn.Linear]
@@ -627,6 +658,8 @@ def __repr__(self):
'Mixtral': LlmcMixtralRMSNorm,
'Interlm2': LlmcInternLM2RMSNorm,
'Qwen2': LlmcQwen2RMSNorm,
'Gemma2': LlmcGemma2RMSNorm,
'MiniCPM': LlmcMiniCPMRMSNorm,
'Starcoder': LlmcLayerNorm,
'Opt': LlmcLayerNorm,
'Bloom': LlmcLayerNorm,
@@ -641,6 +674,8 @@ def __repr__(self):
LlmcMistralRMSNorm,
LlmcMixtralRMSNorm,
LlmcInternLM2RMSNorm,
LlmcGemma2RMSNorm,
LlmcMiniCPMRMSNorm,
]


19 changes: 17 additions & 2 deletions llmc/compression/quantization/quarot.py
@@ -21,7 +21,9 @@ def __init__(self, model, quant_config, input, config):
self.preprocess()

def preprocess(self):
assert self.config['model']['type'] in ['Opt', 'Llama', 'Qwen2']
assert self.config['model']['type'] in [
'Opt', 'Llama', 'Qwen2', 'InternLM2',
'MiniCPM', 'StableLm', 'SmolLM']
# if self.config["model"]["type"] in ["Opt"]:
if torch.equal(
self.model.get_head_layers()[0].weight,
@@ -83,6 +85,16 @@ def block_transform(self, block):
logger.info(f'block:{block}')
logger.info(f'End transform the {self.block_idx+1}-th block')

def bake_mean_into_linear(self, linear):
linear_dtype = linear.weight.dtype
W_ = linear.weight.data.double()
linear.weight.data = W_ - W_.mean(dim=-2, keepdim=True)
linear.weight.data = linear.weight.data.to(linear_dtype)
if linear.bias is not None:
b_ = linear.bias.data.double()
linear.bias.data = b_ - b_.mean()
linear.bias.data = linear.bias.data.to(linear_dtype)

@torch.no_grad()
def subset_transform(self, block, subset):
prev_op = subset['prev_op']
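Note on bake_mean_into_linear: subtracting the column means of the weight (mean over the output dimension, dim=-2) and the mean of the bias makes the layer's output zero-mean across its output features. This is presumably what lets LayerNorm-based models such as Opt and StableLm be handled like RMSNorm models in the rotation pipeline, since the mean subtraction that LayerNorm would perform is baked into the linear itself. A quick illustrative check (toy sizes of my choosing, not llmc code):

import torch

torch.manual_seed(0)
lin = torch.nn.Linear(16, 32, bias=True).double()
x = torch.randn(4, 16, dtype=torch.float64)

W = lin.weight.data
lin.weight.data = W - W.mean(dim=-2, keepdim=True)    # center each weight column
lin.bias.data = lin.bias.data - lin.bias.data.mean()  # center the bias

y = lin(x)
print(y.mean(dim=-1).abs().max())  # ~1e-16: every output row is zero-mean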
@@ -97,14 +109,17 @@ def subset_transform(self, block, subset):
self.fuse_ln_fcs(prev_op[0], layers)
self.rotate_pre_layers(layers, self.Q)
else:
if self.config['model']['type'] in ['Opt']:
if self.config['model']['type'] in ['Opt', 'StableLm']:
self.bake_mean_into_linear(layers[0])

if 'is_mlp' in subset and subset['is_mlp']:
self.rotate_post_layers(
layers, self.Q, exact_had=True if self.online_rotate else False
)
else:
for n, m in layers_dict.items():
logger.info(f'layer: {n} {m.weight.shape}')
logger.info(f'{self.Q.shape}')
self.rotate_post_layers(layers, self.Q, exact_had=False)
if self.online_rotate:
apply_exact_had_to_linear(
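Note on the subset_transform changes: subsets flagged with is_mlp (the down_proj / w2 inputs added in the model files below) go through rotate_post_layers with an exact Hadamard when online rotation is enabled, while the other branch now also logs the layer and Q shapes before folding the rotation in. The reason any of this preserves the model is the usual QuaRot invariance: folding an orthogonal Q into one linear's output and into the next linear's input leaves the composed function unchanged. A small illustrative check (toy layers, not llmc's rotate_* helpers):

import torch

torch.manual_seed(0)
d = 64
Q, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))  # random orthogonal matrix

up = torch.nn.Linear(d, d, bias=False).double()    # stand-in for the "pre" layer
down = torch.nn.Linear(d, d, bias=False).double()  # stand-in for the "post" layer
x = torch.randn(8, d, dtype=torch.float64)
ref = down(up(x))

up.weight.data = Q.T @ up.weight.data      # rotate the hidden activations by Q
down.weight.data = down.weight.data @ Q    # undo the rotation inside the next layer
print(torch.allclose(down(up(x)), ref, atol=1e-10))  # True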
4 changes: 4 additions & 0 deletions llmc/models/__init__.py
@@ -4,8 +4,12 @@
from .internlm2 import InternLM2
from .llama import Llama
from .llava import Llava
from .minicpm import MiniCPM
from .mistral import Mistral
from .mixtral import Mixtral
from .opt import Opt
from .phi import Phi
from .qwen2 import Qwen2
from .smollm import SmolLM
from .stablelm import StableLm
from .starcoder import Starcoder
30 changes: 30 additions & 0 deletions llmc/models/gemma2.py
@@ -1,12 +1,34 @@
from loguru import logger

from llmc.utils.registry_factory import MODEL_REGISTRY

try:
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
except Exception:
logger.warning('Gemma2 not found')
from types import MethodType

import torch.nn as nn

from .base_model import BaseModel


def gemma2_rms_norm_forward(self, x):
output = self._norm(x.float())
output = output * self.weight.float()
return output.type_as(x)


@MODEL_REGISTRY
class Gemma2(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)
for m in self.model.modules():
if isinstance(m, Gemma2RMSNorm):
w = m.weight.data
del m.weight
m.weight = nn.Parameter(w + 1.0)
m.forward = MethodType(gemma2_rms_norm_forward, m)

def find_blocks(self):
self.blocks = self.model.model.layers
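Note on the constructor above: stock Gemma2RMSNorm multiplies the normalized activations by (1.0 + weight) rather than by weight. Folding the +1 into the stored parameter and swapping in the Llama-style forward defined above means the module now computes output * weight like every other RMSNorm llmc handles, which is presumably what keeps the LlmcGemma2RMSNorm replacement and the weight-based transforms numerically equivalent. An illustrative equivalence check (toy tensors, not llmc code):

import torch

def _norm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 5, 16)
w = torch.randn(16)

stock = _norm(x.float()) * (1.0 + w.float())    # what Gemma2RMSNorm.forward computes
patched = _norm(x.float()) * (w + 1.0).float()  # gemma2_rms_norm_forward with the folded weight
print(torch.allclose(stock, patched))           # True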
@@ -21,6 +43,12 @@ def find_block_name(self):
def get_embed_layers(self):
return [self.embed_tokens]

def get_head_layers(self):
return [self.model.lm_head]

def get_pre_head_layernorm_layers(self):
return [self.model.model.norm]

def get_layers_except_blocks(self):
return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

@@ -62,12 +90,14 @@ def get_subsets_in_block(self, block):
'input': ['mlp.gate_proj'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.down_proj': block.mlp.down_proj},
'prev_op': [block.mlp.up_proj],
'input': ['mlp.down_proj'],
'inspect': block.mlp.down_proj,
'has_kwargs': False,
'is_mlp': True,
},
]
11 changes: 11 additions & 0 deletions llmc/models/internlm2.py
@@ -1,3 +1,4 @@
from llmc.compression.quantization.module_utils import _TRANSFORMERS_LN_TYPES_
from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel
@@ -7,6 +8,8 @@
class InternLM2(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)
global _TRANSFORMERS_LN_TYPES_
_TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)]

def find_blocks(self):
self.blocks = self.model.model.layers
@@ -20,6 +23,12 @@ def find_block_name(self):
def get_embed_layers(self):
return [self.tok_embeddings]

def get_head_layers(self):
return [self.model.output]

def get_pre_head_layernorm_layers(self):
return [self.model.model.norm]

def get_layers_except_blocks(self):
return [self.tok_embeddings, self.model.model.norm, self.model.output]

@@ -57,12 +66,14 @@ def get_subsets_in_block(self, block):
'input': ['feed_forward.w1'],
'inspect': block.feed_forward,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'feed_forward.w2': block.feed_forward.w2},
'prev_op': [block.feed_forward.w3],
'input': ['feed_forward.w2'],
'inspect': block.feed_forward.w2,
'has_kwargs': False,
'is_mlp': True,
},
]
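Note on the InternLM2 (and MiniCPM) constructors: these models are typically loaded with trust_remote_code, so their RMSNorm classes cannot be imported from transformers at module-load time the way LlamaRMSNorm or Gemma2RMSNorm can. Instead, the norm class is discovered from the already-loaded model (type(self.model.model.norm)) and appended to _TRANSFORMERS_LN_TYPES_ at runtime. The pattern, reduced to its essentials (illustrative names, not llmc's code):

KNOWN_NORM_TYPES = []  # stands in for _TRANSFORMERS_LN_TYPES_

def register_runtime_norm(model):
    # The concrete class only exists after the remote modeling code has run.
    norm_cls = type(model.model.norm)
    if norm_cls not in KNOWN_NORM_TYPES:
        KNOWN_NORM_TYPES.append(norm_cls)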
84 changes: 84 additions & 0 deletions llmc/models/minicpm.py
@@ -0,0 +1,84 @@
from llmc.compression.quantization.module_utils import _TRANSFORMERS_LN_TYPES_
from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel


@MODEL_REGISTRY
class MiniCPM(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)
global _TRANSFORMERS_LN_TYPES_
_TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)]

def find_blocks(self):
self.blocks = self.model.model.layers

def find_embed_layers(self):
self.embed_tokens = self.model.model.embed_tokens

def find_block_name(self):
self.block_name_prefix = 'model.layers'
self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

def get_embed_layers(self):
return [self.embed_tokens]

def get_head_layers(self):
return [self.model.lm_head]

def get_pre_head_layernorm_layers(self):
return [self.model.model.norm]

def get_layers_except_blocks(self):
return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

def has_bias(self):
return False

def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
}

def get_subsets_in_block(self, block):
return [
{
'layers': {
'self_attn.q_proj': block.self_attn.q_proj,
'self_attn.k_proj': block.self_attn.k_proj,
'self_attn.v_proj': block.self_attn.v_proj,
},
'prev_op': [block.input_layernorm],
'input': ['self_attn.q_proj'],
'inspect': block.self_attn,
'has_kwargs': True,
},
{
'layers': {'self_attn.o_proj': block.self_attn.o_proj},
'prev_op': [block.self_attn.v_proj],
'input': ['self_attn.o_proj'],
'inspect': block.self_attn.o_proj,
'has_kwargs': False,
},
{
'layers': {
'mlp.gate_proj': block.mlp.gate_proj,
'mlp.up_proj': block.mlp.up_proj,
},
'prev_op': [block.post_attention_layernorm],
'input': ['mlp.gate_proj'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.down_proj': block.mlp.down_proj},
'prev_op': [block.mlp.up_proj],
'input': ['mlp.down_proj'],
'inspect': block.mlp.down_proj,
'has_kwargs': False,
'is_mlp': True,
},
]
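Note on the new model files: each class is registered through the @MODEL_REGISTRY decorator and re-exported from llmc/models/__init__.py, so it can be selected by the model type string in a config (the same string checked in quarot.py's preprocess). Assuming MODEL_REGISTRY behaves like a simple name-to-class registry (a guess at the pattern, not llmc's actual registry_factory implementation), the lookup would work roughly like this:

class Registry:
    def __init__(self):
        self._classes = {}

    def __call__(self, cls):
        # Used as a class decorator: @MODEL_REGISTRY above MiniCPM, Gemma2, etc.
        self._classes[cls.__name__] = cls
        return cls

    def __getitem__(self, name):
        return self._classes[name]

MODEL_REGISTRY = Registry()

# A config with model type 'MiniCPM' would then resolve to the class defined in
# this file; the (model_path, torch_dtype) constructor signature matches the
# __init__ shown above. The config key names below are illustrative assumptions.
# model_cls = MODEL_REGISTRY[cfg['model']['type']]
# model = model_cls(cfg['model']['path'], cfg['model']['torch_dtype'])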