From 95d0852a3e887049e10fa3566ca3eb408496fcd3 Mon Sep 17 00:00:00 2001 From: bjascob Date: Sun, 6 Dec 2020 14:35:17 -0700 Subject: [PATCH] temp fix for parse model train with torch 1.7 --- amrlib/models/parse_gsii/modules/transformer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/amrlib/models/parse_gsii/modules/transformer.py b/amrlib/models/parse_gsii/modules/transformer.py index 8fec1fe..f6fc1c4 100644 --- a/amrlib/models/parse_gsii/modules/transformer.py +++ b/amrlib/models/parse_gsii/modules/transformer.py @@ -194,7 +194,12 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need return attn, attn_weights def in_proj_qkv(self, query): - return self._in_proj(query).chunk(3, dim=-1) + #return self._in_proj(query).chunk(3, dim=-1) # original code + # Note: As of torch 1.7 this line is failing with... RuntimeError: one of the variables needed for + # gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [24, 129, 1536]] + # See release notes for v1.7 (torch.chunk) for an explanation. A temporary fix is to use unsafe_chunk instead. + # See https://discuss.pytorch.org/t/runtimeerror-for-chunk-inplace-operation-new-with-torch-1-7/105334 + return self._in_proj(query).unsafe_chunk(3, dim=-1) def in_proj_kv(self, key): return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)