From fada39c9258faa3071cd8ccaf7afd0bf281f0afa Mon Sep 17 00:00:00 2001 From: gushiqiao Date: Thu, 17 Oct 2024 18:55:22 +0800 Subject: [PATCH] Fix vi calib data bug --- llmc/data/dataset/base_dataset.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py index 02a67bc9..c020bcee 100644 --- a/llmc/data/dataset/base_dataset.py +++ b/llmc/data/dataset/base_dataset.py @@ -110,8 +110,10 @@ def build_calib_dataset(self): for name in files: if name.endswith(('.jpg', '.png', '.JPEG')): img_path = os.path.join(root, name) - raw_image = Image.open(img_path) + raw_image = Image.open(img_path).convert('RGB') self.calib_dataset.append(raw_image) + if len(self.calib_dataset) == self.n_samples: + return else: raise ValueError(f'Unsupported data type: {self.calib_dataset_type}') @@ -262,6 +264,7 @@ def img_group_samples_wo_mask(self, samples): # without mask batch = {'pixel_values': torch.cat([sample['pixel_values'] for sample in batch], dim=0)} calib_samples.append(batch) + return calib_samples def get_calib_dataset(self): samples = self.get_calib_samples() @@ -292,11 +295,7 @@ def get_calib_dataset(self): elif self.calib_dataset_type == 'img_txt': calib_samples = self.img_txt_group_samples_wo_mask(samples) logger.info(f'len(calib_samples) : {len(calib_samples)}') - if self.padding: - padding_mask = [calib_sample['attention_mask'] for calib_sample in calib_samples] # noqa - else: - padding_mask = None - return calib_samples, padding_mask + return calib_samples def general_preproc(self, calib_dataset, tokenizer, n_samples, seq_len): dataset = calib_dataset.shuffle(seed=self.seed)