diff --git a/amrlib/models/generate_t5/trainer.py b/amrlib/models/generate_t5/trainer.py
index 85eeaf6..09e9103 100644
--- a/amrlib/models/generate_t5/trainer.py
+++ b/amrlib/models/generate_t5/trainer.py
@@ -31,6 +31,8 @@ def __getitem__(self, idx):
 # prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
 # this is necessacry because the trainer directly passes this dict as arguments to the model
 # so make sure the keys match the parameter names of the forward method
+# Note*1: The original code (with transformers v3.4.0) returned a dict with "lm_labels".
+# Support for this was removed in transformers v4.0.0 and replaced with "labels".
 class T2TDataCollator:
     def __call__(self, batch):
         input_ids = torch.stack([example['input_ids'] for example in batch])
@@ -39,7 +41,7 @@ def __call__(self, batch):
         attention_mask = torch.stack([example['attention_mask'] for example in batch])
         decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])
         return {'input_ids': input_ids, 'attention_mask': attention_mask,
-                'lm_labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask }
+                'labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask }    # Note*1
 
 
 # Note that for save_steps, steps means gradient updates (not batch) so if
@@ -80,8 +82,11 @@ def train(self):
             len(valid_dataset), len(valid_dataset.bad_indexes)))
         # Train the model
         print('Training')
+        # trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
+        #                     eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+        # prediction_loss_only=True moved to training_args for compatibility with transformers v4.0.0
         trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
-                            eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+                            eval_dataset=valid_dataset, data_collator=T2TDataCollator())
         trainer.train()
         # Save the results
         print('Saving model')
diff --git a/amrlib/models/parse_gsii/modules/transformer.py b/amrlib/models/parse_gsii/modules/transformer.py
index f6fc1c4..cdf3140 100644
--- a/amrlib/models/parse_gsii/modules/transformer.py
+++ b/amrlib/models/parse_gsii/modules/transformer.py
@@ -200,6 +200,10 @@ def in_proj_qkv(self, query):
         # See release notes for v1.7 (torch.chunk) for an explanation. A temporary fix is to use unsafe_chunk instead.
         # See https://discuss.pytorch.org/t/runtimeerror-for-chunk-inplace-operation-new-with-torch-1-7/105334
         return self._in_proj(query).unsafe_chunk(3, dim=-1)
+        # Possible solution...
+        # proj = self._in_proj(query)
+        # sz = proj.size()[2] // 3
+        # return proj[:,:,:sz], proj[:,:,sz:2*sz], proj[:,:,2*sz:]
 
     def in_proj_kv(self, key):
         return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
diff --git a/configs/model_generate_t5.json b/configs/model_generate_t5.json
index ecfd2a1..4cd0053 100644
--- a/configs/model_generate_t5.json
+++ b/configs/model_generate_t5.json
@@ -13,8 +13,8 @@
     "output_dir" : "amrlib/data/model_generate_t5",
     "do_train" : true,
     "do_eval" : false,
-    "evaluate_during_training" : false,
     "overwrite_output_dir" : false,
+    "prediction_loss_only" : true,
     "num_train_epochs" : 8,
     "save_steps" : 1000,
     "save_total_limit" : 2,
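
A quick sanity check of the slice-based alternative commented into in_proj_qkv above (a self-contained sketch, not code from the repo): splitting the combined QKV projection by plain slicing produces the same three tensors as chunk(3, dim=-1), and it is the route the comment proposes for avoiding the in-place-modification error that torch >= 1.7 raises for chunked views. The tensor shape (tgt_len, batch, 3*embed_dim) mirrors what _in_proj would return; the embed_dim value is just for the demo.

    import torch

    embed_dim = 4                              # demo value only
    proj = torch.randn(2, 3, 3 * embed_dim)    # stand-in for self._in_proj(query)

    # chunk-based split (what the current code does via unsafe_chunk)
    q_c, k_c, v_c = proj.chunk(3, dim=-1)

    # slice-based split (the "Possible solution" sketched in the diff)
    sz = proj.size()[2] // 3
    q_s, k_s, v_s = proj[:, :, :sz], proj[:, :, sz:2 * sz], proj[:, :, 2 * sz:]

    assert torch.equal(q_c, q_s) and torch.equal(k_c, k_s) and torch.equal(v_c, v_s)

Similarly, under transformers v4.x the prediction_loss_only flag is read from TrainingArguments rather than passed to the Trainer constructor, which is why the config file gains the key and the Trainer call drops it. A minimal sketch of the resulting arguments object, using the values visible in configs/model_generate_t5.json:

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir='amrlib/data/model_generate_t5',
        do_train=True,
        do_eval=False,
        overwrite_output_dir=False,
        prediction_loss_only=True,     # formerly a Trainer() keyword argument
        num_train_epochs=8,
        save_steps=1000,
        save_total_limit=2)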