From bc80937c4e3cbe9a3c94e067edf09e5dfdc2f1e4 Mon Sep 17 00:00:00 2001
From: gushiqiao
Date: Thu, 17 Oct 2024 17:16:17 +0800
Subject: [PATCH 1/3] Update mse quant

---
 .../quantization/methods/GPTQ/gptq_w_only.yml |  2 +
 llmc/compression/quantization/quant.py        | 95 +++++++++++--------
 2 files changed, 56 insertions(+), 41 deletions(-)

diff --git a/configs/quantization/methods/GPTQ/gptq_w_only.yml b/configs/quantization/methods/GPTQ/gptq_w_only.yml
index d9c5bcc5..7971a57a 100644
--- a/configs/quantization/methods/GPTQ/gptq_w_only.yml
+++ b/configs/quantization/methods/GPTQ/gptq_w_only.yml
@@ -29,6 +29,8 @@ quant:
         symmetric: False
         granularity: per_group
         group_size: 128
+        # calib_algo: mse
+        # mse_b_num: 2
     special:
         actorder: True
         static_groups: False
diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py
index 321287c5..f49f204f 100644
--- a/llmc/compression/quantization/quant.py
+++ b/llmc/compression/quantization/quant.py
@@ -9,21 +9,22 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
         self.sym = symmetric
         self.granularity = granularity
         self.kwargs = kwargs
-        if 'calib_algo' in self.kwargs:
-            self.calib_algo = self.kwargs['calib_algo']
-        else:
-            self.calib_algo = 'minmax'
+
+        self.calib_algo = self.kwargs.get('calib_algo', 'minmax')
 
         if self.granularity == 'per_group':
             self.group_size = self.kwargs['group_size']
         elif self.granularity == 'per_head':
             self.head_num = self.kwargs['head_num']
 
-        if 'ste' in self.kwargs and self.kwargs['ste']:
+        self.mse_b_num = self.kwargs.get('mse_b_num', 1)
+
+        if self.kwargs.get('ste', False):
             self.round_func = lambda x: (x.round() - x).detach() + x
         else:
             self.round_func = torch.round
-        self.round_zp = 'round_zp' not in self.kwargs or self.kwargs['round_zp']
+
+        self.round_zp = self.kwargs.get('round_zp', True)
         self.sigmoid = torch.nn.Sigmoid()
 
     def get_tensor_range(self, tensor, args={}):
@@ -34,7 +35,7 @@ def get_tensor_range(self, tensor, args={}):
         elif self.calib_algo == 'learnable':
             return self.get_learnable_range(tensor, **args)
         else:
-            logger.info('Calibration Algorithm Not Found!')
+            raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')
 
     def get_minmax_range(self, tensor):
         if self.granularity == 'per_tensor':
@@ -47,20 +48,16 @@ def get_minmax_range(self, tensor):
         return (min_val, max_val)
 
     def get_mse_range(self, tensor, grid=100, norm=2.4, maxshrink=0.8, bs=256):
-        if tensor.shape[0] % bs != 0:
-            logger.warning(
-                'Batch size is not a multiple of the tensor size,'
-                'set batch size to {}'.format(
-                    tensor.shape[0]
-                )
-            )
-            bs = tensor.shape[0]
+
+        assert self.mse_b_num >= 1 and tensor.shape[0] % self.mse_b_num == 0, \
+            'tensor.shape[0] must be divisible by mse_b_num.'
+        bs = tensor.shape[0] // self.mse_b_num
 
         tensor = tensor.float()
         min_val, max_val = self.get_minmax_range(tensor)
         dev = tensor.device
-        for b_num in range(tensor.shape[0] // bs):
+        for b_num in range(self.mse_b_num):
             _tensor = tensor[b_num * bs: (b_num + 1) * bs, :]
             _min_val, _max_val = (
                 min_val[b_num * bs: (b_num + 1) * bs, :],
                 max_val[b_num * bs: (b_num + 1) * bs, :],
@@ -113,6 +110,22 @@ def get_mse_range(self, tensor, grid=100, norm=2.4, maxshrink=0.8, bs=256):
 
         return (min_val, max_val)
 
+    def get_learnable_range(self, tensor, lowbound_factor=None, upbound_factor=None):
+        min_val, max_val = self.get_minmax_range(tensor)
+        if self.sym:
+            if upbound_factor is not None:
+                abs_max = torch.max(max_val.abs(), min_val.abs())
+                abs_max = abs_max.clamp(min=1e-5)
+                abs_max = self.sigmoid(upbound_factor) * abs_max
+                min_val = -abs_max
+                max_val = abs_max
+        else:
+            if upbound_factor is not None and lowbound_factor is not None:
+                min_val = self.sigmoid(lowbound_factor) * min_val
+                max_val = self.sigmoid(upbound_factor) * max_val
+
+        return (min_val, max_val)
+
     def get_qparams(self, tensor_range, device):
         min_val, max_val = tensor_range[0], tensor_range[1]
         qmin = self.qmin
@@ -436,33 +449,33 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
         if self.use_qtorch:
             try:
                 from qtorch.quant import float_quantize
-                self.float_quantize = float_quantize
+            except ImportError:
+                logger.error('qtorch not found, please install qtorch.')
+                raise ImportError('Please install qtorch (pip install qtorch).')
 
-                if 'float_range' in self.kwargs:
-                    self.qmin, self.qmax = self.kwargs['float_range']
-                else:
-                    bit_ranges = {
-                        ('e4m3', 8): torch.float8_e4m3fn,
-                        ('e5m2', 8): torch.float8_e5m2,
-                        ('e3m2', 6): (-28, 28),
-                        ('e4m7', 12): (-510, 510),
-                        ('e2m1', 4): (-6, 6),
-                    }
-
-                    key = (self.bit, self.num_bits)
-                    if key in bit_ranges:
-                        if isinstance(bit_ranges[key], tuple):
-                            self.qmin, self.qmax = bit_ranges[key]
-                        else:
-                            finfo = torch.finfo(bit_ranges[key])
-                            self.qmin, self.qmax = finfo.min, finfo.max
-                    else:
-                        raise NotImplementedError('Only 4, 6, 8, and \
-                            12-bit quantization is supported.')
+            self.float_quantize = float_quantize
 
-            except ImportError:
-                raise ImportError('Please install qtorch \
-                    (pip install qtorch) to use this function.')
+            if 'float_range' in self.kwargs:
+                self.qmin, self.qmax = self.kwargs['float_range']
+            else:
+                bit_ranges = {
+                    ('e4m3', 8): torch.float8_e4m3fn,
+                    ('e5m2', 8): torch.float8_e5m2,
+                    ('e3m2', 6): (-28, 28),
+                    ('e4m7', 12): (-510, 510),
+                    ('e2m1', 4): (-6, 6),
+                }
+
+                key = (self.bit, self.num_bits)
+                if key in bit_ranges:
+                    if isinstance(bit_ranges[key], tuple):
+                        self.qmin, self.qmax = bit_ranges[key]
+                    else:
+                        finfo = torch.finfo(bit_ranges[key])
+                        self.qmin, self.qmax = finfo.min, finfo.max
+                else:
+                    raise NotImplementedError('Only 4, 6, 8, and \
+                        12-bit quantization is supported.')
 
     def get_float_qparams(self, tensor, tensor_range, device):
         min_val, max_val = tensor_range[0], tensor_range[1]
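Usage sketch for the new MSE-calibration knobs above (illustrative only; the enclosing quantizer class is not visible in the hunks, so the class name, import path, and the pre-reshape step below are assumptions):

    # Hedged sketch: exercising calib_algo='mse' with mse_b_num.
    # 'IntegerQuantizer' and its import path are assumed names.
    import torch
    from llmc.compression.quantization.quant import IntegerQuantizer

    quantizer = IntegerQuantizer(
        bit=4, symmetric=False, granularity='per_group', group_size=128,
        calib_algo='mse',  # mirrors the commented-out yml keys
        mse_b_num=2,       # MSE search runs over 2 row batches
    )
    w = torch.randn(4096, 4096)
    w_grouped = w.reshape(-1, 128)  # per_group callers reshape rows into groups first (assumed)
    # Dispatches to get_mse_range; requires w_grouped.shape[0] % mse_b_num == 0.
    min_val, max_val = quantizer.get_tensor_range(w_grouped)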
From fada39c9258faa3071cd8ccaf7afd0bf281f0afa Mon Sep 17 00:00:00 2001
From: gushiqiao
Date: Thu, 17 Oct 2024 18:55:22 +0800
Subject: [PATCH 2/3] Fix vit calib data bug

---
 llmc/data/dataset/base_dataset.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py
index 02a67bc9..c020bcee 100644
--- a/llmc/data/dataset/base_dataset.py
+++ b/llmc/data/dataset/base_dataset.py
@@ -110,8 +110,10 @@ def build_calib_dataset(self):
                 for name in files:
                     if name.endswith(('.jpg', '.png', '.JPEG')):
                         img_path = os.path.join(root, name)
-                        raw_image = Image.open(img_path)
+                        raw_image = Image.open(img_path).convert('RGB')
                         self.calib_dataset.append(raw_image)
+                        if len(self.calib_dataset) == self.n_samples:
+                            return
         else:
             raise ValueError(f'Unsupported data type: {self.calib_dataset_type}')
 
@@ -262,6 +264,7 @@ def img_group_samples_wo_mask(self, samples):  # without mask
             batch = {'pixel_values': torch.cat([sample['pixel_values'] for sample in batch], dim=0)}
             calib_samples.append(batch)
+        return calib_samples
 
     def get_calib_dataset(self):
         samples = self.get_calib_samples()
@@ -292,11 +295,7 @@ def get_calib_dataset(self):
         elif self.calib_dataset_type == 'img_txt':
             calib_samples = self.img_txt_group_samples_wo_mask(samples)
         logger.info(f'len(calib_samples) : {len(calib_samples)}')
-        if self.padding:
-            padding_mask = [calib_sample['attention_mask'] for calib_sample in calib_samples]  # noqa
-        else:
-            padding_mask = None
-        return calib_samples, padding_mask
+        return calib_samples
 
     def general_preproc(self, calib_dataset, tokenizer, n_samples, seq_len):
         dataset = calib_dataset.shuffle(seed=self.seed)
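The added .convert('RGB') in PATCH 2 guards against grayscale, palette, and RGBA files: PIL keeps each file's native mode, and a 1- or 4-channel image later breaks vision preprocessors that expect three channels. A minimal illustration (the file path is hypothetical):

    from PIL import Image

    raw = Image.open('calib_imgs/sample.png')  # hypothetical path; mode may be 'L', 'P', or 'RGBA'
    rgb = raw.convert('RGB')                   # normalize to 3 channels
    print(raw.mode, '->', rgb.mode)            # e.g. 'L' -> 'RGB'

The early return likewise caps the directory walk at n_samples images instead of loading the whole tree into memory.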
From 15535df1f05eafb864ca49c2062c47a5c2480bf6 Mon Sep 17 00:00:00 2001
From: gushiqiao
Date: Thu, 17 Oct 2024 19:00:31 +0800
Subject: [PATCH 3/3] Fix vit calib data bug

---
 llmc/data/dataset/base_dataset.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py
index c020bcee..1388491b 100644
--- a/llmc/data/dataset/base_dataset.py
+++ b/llmc/data/dataset/base_dataset.py
@@ -295,7 +295,11 @@ def get_calib_dataset(self):
         elif self.calib_dataset_type == 'img_txt':
             calib_samples = self.img_txt_group_samples_wo_mask(samples)
         logger.info(f'len(calib_samples) : {len(calib_samples)}')
-        return calib_samples
+        if self.padding:
+            padding_mask = [calib_sample['attention_mask'] for calib_sample in calib_samples]  # noqa
+        else:
+            padding_mask = None
+        return calib_samples, padding_mask
 
     def general_preproc(self, calib_dataset, tokenizer, n_samples, seq_len):
         dataset = calib_dataset.shuffle(seed=self.seed)
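After PATCH 3, get_calib_dataset again returns a (calib_samples, padding_mask) pair, restoring the contract that PATCH 2 had dropped; padding_mask is None unless padding is enabled. A sketch of the caller-side contract (dataset construction is elided and the variable names are assumptions):

    # 'dataset' is an already-constructed BaseDataset instance (construction elided).
    calib_samples, padding_mask = dataset.get_calib_dataset()
    if padding_mask is not None:  # set only when self.padding is True
        assert len(padding_mask) == len(calib_samples)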