Skip to content

Commit

Permalink
temp fix for parse model train with torch 1.7
Browse files Browse the repository at this point in the history
  • Loading branch information
bjascob committed Dec 6, 2020
1 parent d97ddf2 commit 95d0852
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion amrlib/models/parse_gsii/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,12 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need
return attn, attn_weights

def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
#return self._in_proj(query).chunk(3, dim=-1) # original code
# Note: As of torch 1.7 this line is failing with... RuntimeError: one of the variables needed for
# gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [24, 129, 1536]]
# See release notes for v1.7 (torch.chunk) for an explanation. A temporary fix is to use unsafe_chunk instead.
# See https://discuss.pytorch.org/t/runtimeerror-for-chunk-inplace-operation-new-with-torch-1-7/105334
return self._in_proj(query).unsafe_chunk(3, dim=-1)

def in_proj_kv(self, key):
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
Expand Down

0 comments on commit 95d0852

Please sign in to comment.