From 339a9d6bb032389aa1192d37c10a3a102ab4b18e Mon Sep 17 00:00:00 2001
From: huangyushi
Date: Fri, 23 Aug 2024 06:15:58 +0800
Subject: [PATCH] remove redundant code

---
 llmc/compression/quantization/quarot.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/llmc/compression/quantization/quarot.py b/llmc/compression/quantization/quarot.py
index 99e780f3..458e239a 100644
--- a/llmc/compression/quantization/quarot.py
+++ b/llmc/compression/quantization/quarot.py
@@ -92,16 +92,6 @@ def block_transform(self, block):
         logger.info(f'block:{block}')
         logger.info(f'End transform the {self.block_idx+1}-th block')
 
-    def bake_mean_into_linear(self, linear):
-        linear_dtype = linear.weight.dtype
-        W_ = linear.weight.data.double()
-        linear.weight.data = W_ - W_.mean(dim=-2, keepdim=True)
-        linear.weight.data = linear.weight.data.to(linear_dtype)
-        if linear.bias is not None:
-            b_ = linear.bias.data.double()
-            linear.bias.data = b_ - b_.mean()
-            linear.bias.data = linear.bias.data.to(linear_dtype)
-
     @torch.no_grad()
     def subset_transform(self, block, subset):
         prev_op = subset['prev_op']
@@ -117,7 +107,7 @@
             self.rotate_pre_layers(layers, self.Q)
         else:
             if self.config['model']['type'] in ['Opt', 'StableLm']:
-                self.bake_mean_into_linear(layers[0])
+                self.bake_mean_into_fc(layers[0])
 
             if 'is_mlp' in subset and subset['is_mlp']:
                 self.rotate_post_layers(
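
Note: the deleted bake_mean_into_linear duplicated logic that the call site now reaches as self.bake_mean_into_fc. That name is taken from the patched line; I assume it is provided by llmc's base quantization class and behaves like the deleted helper. A minimal standalone sketch under that assumption, mirroring the removed body rather than llmc's exact implementation:

import torch
import torch.nn as nn

@torch.no_grad()
def bake_mean_into_fc(linear: nn.Linear) -> None:
    # Hypothetical standalone version of the helper; same semantics as the
    # deleted bake_mean_into_linear above.
    dtype = linear.weight.dtype
    W = linear.weight.data.double()  # work in float64 for numerical safety
    # Zero the mean of each weight column across the output dimension
    # (dim=-2 of the (out_features, in_features) matrix), so that W @ x is
    # mean-free over output channels for every input x.
    linear.weight.data = (W - W.mean(dim=-2, keepdim=True)).to(dtype)
    if linear.bias is not None:
        b = linear.bias.data.double()
        linear.bias.data = (b - b.mean()).to(dtype)

Making the layer's output zero-mean bakes the mean subtraction of a following LayerNorm into the weights, which is why the call is gated on LayerNorm-based model types ('Opt', 'StableLm') and skipped for RMSNorm models, where no mean subtraction occurs.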