train generate model updates for transformers 4.0.0
bjascob committed Dec 6, 2020
1 parent 95d0852 commit eb2f13c
Showing 3 changed files with 12 additions and 3 deletions.
9 changes: 7 additions & 2 deletions amrlib/models/generate_t5/trainer.py
@@ -31,6 +31,8 @@ def __getitem__(self, idx):
# prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
# this is necessary because the trainer directly passes this dict as arguments to the model
# so make sure the keys match the parameter names of the forward method
+ # Note*1: The original code (with transformers v3.4.0) returned a dict with "lm_labels".
+ #         Support for this was removed in transformers v4.0.0 and replaced with "labels".
class T2TDataCollator:
def __call__(self, batch):
input_ids = torch.stack([example['input_ids'] for example in batch])
@@ -39,7 +41,7 @@ def __call__(self, batch):
attention_mask = torch.stack([example['attention_mask'] for example in batch])
decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])
return {'input_ids': input_ids, 'attention_mask': attention_mask,
-                'lm_labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask }
+                'labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask }    # Note*1


# Note that for save_steps, steps means gradient updates (not batch) so if
@@ -80,8 +82,11 @@ def train(self):
len(valid_dataset), len(valid_dataset.bad_indexes)))
# Train the model
print('Training')
+ # trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
+ #                     eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+ # prediction_loss_only=True moved to training_args for compatibility with transformers v4.0.0
trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
-                     eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+                     eval_dataset=valid_dataset, data_collator=T2TDataCollator())
trainer.train()
# Save the results
print('Saving model')
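For reference, a minimal sketch (not part of the commit) of how the collator's output is consumed: in transformers v4.0.0 the model's forward() computes the loss from a 'labels' entry in the batch dict, whereas v3.4.0 still accepted 'lm_labels'. The input strings and the 't5-base' checkpoint below are illustrative only.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

src = tokenizer('some source text', return_tensors='pt')
tgt = tokenizer('some target text', return_tensors='pt')
batch = {'input_ids'              : src['input_ids'],
         'attention_mask'         : src['attention_mask'],
         'labels'                 : tgt['input_ids'],          # was 'lm_labels' in transformers v3.4.0
         'decoder_attention_mask' : tgt['attention_mask']}
loss = model(**batch).loss    # forward() returns the LM loss when 'labels' is supplied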
4 changes: 4 additions & 0 deletions amrlib/models/parse_gsii/modules/transformer.py
@@ -200,6 +200,10 @@ def in_proj_qkv(self, query):
# See release notes for v1.7 (torch.chunk) for an explanation. A temporary fix is to use unsafe_chunk instead.
# See https://discuss.pytorch.org/t/runtimeerror-for-chunk-inplace-operation-new-with-torch-1-7/105334
return self._in_proj(query).unsafe_chunk(3, dim=-1)
+ # Possible solution...
+ # proj = self._in_proj(query)
+ # sz = proj.size()[2] // 3
+ # return proj[:,:,:sz], proj[:,:,sz:2*sz], proj[:,:,2*sz:]

def in_proj_kv(self, key):
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
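A standalone sketch (not part of the commit) of why the slicing in the "Possible solution" comment is interchangeable with the chunked projection: both split the last dimension into three equal pieces, but explicit slicing sidesteps the torch 1.7 chunk/in-place issue that unsafe_chunk works around. The tensor shape below is a stand-in for self._in_proj(query).

import torch

proj = torch.randn(4, 2, 9)                      # last dim must be divisible by 3
q1, k1, v1 = proj.unsafe_chunk(3, dim=-1)        # the fix applied in this commit
sz = proj.size()[2] // 3
q2, k2, v2 = proj[:, :, :sz], proj[:, :, sz:2*sz], proj[:, :, 2*sz:]   # commented-out alternative
assert all(torch.equal(a, b) for a, b in [(q1, q2), (k1, k2), (v1, v2)])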
2 changes: 1 addition & 1 deletion configs/model_generate_t5.json
@@ -13,8 +13,8 @@
"output_dir" : "amrlib/data/model_generate_t5",
"do_train" : true,
"do_eval" : false,
"evaluate_during_training" : false,
"overwrite_output_dir" : false,
"prediction_loss_only" : true,
"num_train_epochs" : 8,
"save_steps" : 1000,
"save_total_limit" : 2,
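A hedged sketch (not part of the commit) of how a config like this maps onto transformers v4.0.0 TrainingArguments: 'evaluate_during_training' is gone in v4.0.0 (evaluation is controlled by 'evaluation_strategy' instead), while 'prediction_loss_only' is now a TrainingArguments field rather than a Trainer keyword argument.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir           = 'amrlib/data/model_generate_t5',
    do_train             = True,
    do_eval              = False,
    overwrite_output_dir = False,
    prediction_loss_only = True,      # moved here from the Trainer constructor
    num_train_epochs     = 8,
    save_steps           = 1000,
    save_total_limit     = 2)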
