diff --git a/caffe2/operators/recurrent_op_cudnn.cc b/caffe2/operators/recurrent_op_cudnn.cc index 6951852db80..e36c2e3d29a 100644 --- a/caffe2/operators/recurrent_op_cudnn.cc +++ b/caffe2/operators/recurrent_op_cudnn.cc @@ -295,7 +295,7 @@ template bool RecurrentGradientOp::RunOnDevice() { const int seqLength = Input(INPUT).dim32(0); if (Input(INPUT).dims() != cachedInputDims_) { - initialize(Input(INPUT)); + initialize(Input(INPUT), Output(DROPOUT_STATES)); cachedInputDims_ = Input(INPUT).dims(); } CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(