diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py index 28d9e5ce..4f6aac20 100644 --- a/llmc/compression/quantization/quant.py +++ b/llmc/compression/quantization/quant.py @@ -433,7 +433,7 @@ def real_quant_weight_static(self, weight, args): else: dtype = torch.int32 weight = weight.to(dtype) - if zeros != torch.tensor(0.0) and self.round_zp: + if (zeros != torch.tensor(0.0)).all() and self.round_zp: zeros = zeros.to(dtype) else: zeros = None @@ -454,7 +454,7 @@ def real_quant_weight_dynamic(self, weight, args={}): else: dtype = torch.int32 weight = weight.to(dtype) - if zeros != torch.tensor(0.0) and self.round_zp: + if (zeros != torch.tensor(0.0)).all() and self.round_zp: zeros = zeros.to(dtype) else: zeros = None