Skip to content

Commit

Permalink
Fix vi calib data bug
Browse files Browse the repository at this point in the history
  • Loading branch information
gushiqiao committed Oct 17, 2024
1 parent 7e94a7c commit fada39c
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions llmc/data/dataset/base_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -110,8 +110,10 @@ def build_calib_dataset(self):
for name in files:
if name.endswith(('.jpg', '.png', '.JPEG')):
img_path = os.path.join(root, name)
- raw_image = Image.open(img_path)
+ raw_image = Image.open(img_path).convert('RGB')
self.calib_dataset.append(raw_image)
+ if len(self.calib_dataset) == self.n_samples:
+ return
else:
raise ValueError(f'Unsupported data type: {self.calib_dataset_type}')

Expand Down Expand Up @@ -262,6 +264,7 @@ def img_group_samples_wo_mask(self, samples): # without mask
batch = {'pixel_values': torch.cat([sample['pixel_values']
for sample in batch], dim=0)}
calib_samples.append(batch)
+ return calib_samples

def get_calib_dataset(self):
samples = self.get_calib_samples()
Expand Down Expand Up @@ -292,11 +295,7 @@ def get_calib_dataset(self):
elif self.calib_dataset_type == 'img_txt':
calib_samples = self.img_txt_group_samples_wo_mask(samples)
logger.info(f'len(calib_samples) : {len(calib_samples)}')
- if self.padding:
- padding_mask = [calib_sample['attention_mask'] for calib_sample in calib_samples] # noqa
- else:
- padding_mask = None
- return calib_samples, padding_mask
+ return calib_samples

def general_preproc(self, calib_dataset, tokenizer, n_samples, seq_len):
dataset = calib_dataset.shuffle(seed=self.seed)
Expand Down

0 comments on commit fada39c

Please sign in to comment.