-
Notifications
You must be signed in to change notification settings - Fork 10
/
recurrent_fast.py
3103 lines (2773 loc) · 123 KB
/
recurrent_fast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Recurrent layers and their base classes.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.layers.edgeml.graph.rnn import FastRNNCell, FastGRNNCell
import uuid
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.layers.StackedRNNCells')
class StackedRNNCells(Layer):
  """Wrapper allowing a stack of RNN cells to behave as a single cell.

  Used to implement efficient stacked RNNs.

  Arguments:
    cells: List of RNN cell instances.

  Examples:

  ```python
    cells = [
        keras.layers.LSTMCell(output_dim),
        keras.layers.LSTMCell(output_dim),
        keras.layers.LSTMCell(output_dim),
    ]

    inputs = keras.Input((timesteps, input_dim))
    x = keras.layers.RNN(cells)(inputs)
  ```
  """

  @checkpointable.no_automatic_dependency_tracking
  def __init__(self, cells, **kwargs):
    """Validates the cells and stores them.

    Arguments:
      cells: Iterable of cell objects; each must expose `call` and
        `state_size`.
      **kwargs: Forwarded to `Layer.__init__`, except for
        `reverse_state_order`, which is consumed here.

    Raises:
      ValueError: If a cell lacks a `call` method or a `state_size`
        attribute.
    """
    for cell in cells:
      if not hasattr(cell, 'call'):
        raise ValueError('All cells must have a `call` method. '
                         'received cells:', cells)
      if not hasattr(cell, 'state_size'):
        raise ValueError('All cells must have a '
                         '`state_size` attribute. '
                         'received cells:', cells)
    self.cells = cells
    # reverse_state_order determines whether the state size will be in a reverse
    # order of the cells' state. User might want to set this to True to keep the
    # existing behavior. This is only useful when use RNN(return_state=True)
    # since the state will be returned as the same order of state_size.
    self.reverse_state_order = kwargs.pop('reverse_state_order', False)
    if self.reverse_state_order:
      logging.warning('reverse_state_order=True in StackedRNNCells will soon '
                      'be deprecated. Please update the code to work with the '
                      'natural order of states if you rely on the RNN states, '
                      'eg RNN(return_state=True).')
    super(StackedRNNCells, self).__init__(**kwargs)

  @property
  def state_size(self):
    """Tuple of per-cell `state_size`, reversed if `reverse_state_order`."""
    return tuple(c.state_size for c in
                 (self.cells[::-1] if self.reverse_state_order else self.cells))

  @property
  def output_size(self):
    """Output size of the last cell in the stack.

    Falls back to the last cell's `state_size` (its first entry when the cell
    has multiple states) if the cell does not define `output_size`.
    """
    if getattr(self.cells[-1], 'output_size', None) is not None:
      return self.cells[-1].output_size
    elif _is_multiple_state(self.cells[-1].state_size):
      return self.cells[-1].state_size[0]
    else:
      return self.cells[-1].state_size

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    """Collects the initial state of every cell into a tuple.

    Cells that implement `get_initial_state` are asked directly; for the others
    a zero-filled state is generated. The order follows `state_size`.
    """
    initial_states = []
    for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
      get_initial_state_fn = getattr(cell, 'get_initial_state', None)
      if get_initial_state_fn:
        initial_states.append(get_initial_state_fn(
            inputs=inputs, batch_size=batch_size, dtype=dtype))
      else:
        initial_states.append(_generate_zero_filled_state_for_cell(
            cell, inputs, batch_size, dtype))

    return tuple(initial_states)

  def call(self, inputs, states, constants=None, **kwargs):
    """Runs one step through every cell, feeding each output to the next.

    Arguments:
      inputs: Input to the first cell.
      states: States for all cells; flattened and re-packed per cell.
      constants: Optional constants, passed only to cells whose `call`
        accepts a `constants` argument.
      **kwargs: Extra keyword arguments forwarded to every `cell.call`.

    Returns:
      A pair `(output_of_last_cell, new_states)`.
    """
    # Recover per-cell states.
    state_size = (self.state_size[::-1]
                  if self.reverse_state_order else self.state_size)
    nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))

    # Call the cells in order and store the returned states.
    new_nested_states = []
    for cell, cell_states in zip(self.cells, nested_states):
      cell_states = (cell_states if nest.is_sequence(cell_states)
                     else [cell_states])
      if generic_utils.has_arg(cell.call, 'constants'):
        inputs, cell_states = cell.call(inputs, cell_states,
                                        constants=constants, **kwargs)
      else:
        inputs, cell_states = cell.call(inputs, cell_states, **kwargs)
      new_nested_states.append(cell_states)

    return inputs, nest.pack_sequence_as(state_size,
                                         nest.flatten(new_nested_states))

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Builds each `Layer` cell, chaining output shapes through the stack.

    Arguments:
      input_shape: Either a single input shape, or a list whose first entry is
        the input shape and whose remaining entries are constants shapes.
    """
    # `None` means no list was received, i.e. no constants shapes are known.
    # Previously the name was left undefined in that case, which raised a
    # NameError for cells whose `call` accepts `constants`.
    constants_shape = None
    if isinstance(input_shape, list):
      constants_shape = input_shape[1:]
      input_shape = input_shape[0]
    for cell in self.cells:
      if isinstance(cell, Layer):
        if (constants_shape is not None and
            generic_utils.has_arg(cell.call, 'constants')):
          cell.build([input_shape] + constants_shape)
        else:
          cell.build(input_shape)
      if getattr(cell, 'output_size', None) is not None:
        output_dim = cell.output_size
      elif _is_multiple_state(cell.state_size):
        output_dim = cell.state_size[0]
      else:
        output_dim = cell.state_size
      # The next cell sees the same leading dimension with this cell's output.
      input_shape = tuple([input_shape[0]] +
                          tensor_shape.as_shape(output_dim).as_list())
    self.built = True

  def get_config(self):
    """Returns the layer config, including each cell's class name and config."""
    cells = []
    for cell in self.cells:
      cells.append({
          'class_name': cell.__class__.__name__,
          'config': cell.get_config()
      })
    config = {'cells': cells}
    base_config = super(StackedRNNCells, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the stack from a config produced by `get_config`.

    Note: the `'cells'` entry is popped from `config` in place.
    """
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cells = []
    for cell_config in config.pop('cells'):
      cells.append(
          deserialize_layer(cell_config, custom_objects=custom_objects))
    return cls(cells, **config)

  @property
  def trainable_weights(self):
    """Trainable weights of all `Layer` cells; empty if not trainable."""
    if not self.trainable:
      return []
    weights = []
    for cell in self.cells:
      if isinstance(cell, Layer):
        weights += cell.trainable_weights
    return weights

  @property
  def non_trainable_weights(self):
    """Non-trainable weights of all `Layer` cells.

    When the stack itself is not trainable, the cells' trainable weights are
    reported here as well, ahead of their non-trainable weights.
    """
    weights = []
    for cell in self.cells:
      if isinstance(cell, Layer):
        weights += cell.non_trainable_weights
    if not self.trainable:
      trainable_weights = []
      for cell in self.cells:
        if isinstance(cell, Layer):
          trainable_weights += cell.trainable_weights
      return trainable_weights + weights
    return weights

  def get_weights(self):
    """Retrieves the weights of the model.

    Returns:
        A flat list of Numpy arrays.
    """
    weights = []
    for cell in self.cells:
      if isinstance(cell, Layer):
        weights += cell.weights
    return K.batch_get_value(weights)

  def set_weights(self, weights):
    """Sets the weights of the model.

    Arguments:
        weights: A list of Numpy arrays with shapes and types matching
            the output of `model.get_weights()`.
    """
    tuples = []
    remaining = weights
    for cell in self.cells:
      if isinstance(cell, Layer):
        cell_variables = cell.weights
        num_param = len(cell_variables)
        # Take this cell's share into a separate name. Rebinding `weights`
        # itself to the slice would leave nothing for the following cells.
        cell_values = remaining[:num_param]
        tuples.extend(zip(cell_variables, cell_values))
        remaining = remaining[num_param:]
    K.batch_set_value(tuples)

  @property
  def losses(self):
    """Losses of all `Layer` cells followed by this layer's own losses."""
    losses = []
    for cell in self.cells:
      if isinstance(cell, Layer):
        losses += cell.losses
    return losses + self._losses

  @property
  def updates(self):
    """Updates of all `Layer` cells followed by this layer's own updates."""
    updates = []
    for cell in self.cells:
      if isinstance(cell, Layer):
        updates += cell.updates
    return updates + self._updates
@tf_export('keras.layers.RNN')
class RNN(Layer):
"""Base class for recurrent layers.
Arguments:
cell: A RNN cell instance or a list of RNN cell instances.
A RNN cell is a class that has:
- a `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional argument `constants`, see
section "Note on passing external constants" below.
- a `state_size` attribute. This can be a single integer
(single state) in which case it is the size of the recurrent
state. This can also be a list/tuple of integers (one size per
state).
The `state_size` can also be TensorShape or tuple/list of
TensorShape, to represent high dimension state.
- a `output_size` attribute. This can be a single integer or a
TensorShape, which represent the shape of the output. For backward
compatible reason, if this attribute is not available for the
cell, the value will be inferred by the first element of the
`state_size`.
- a `get_initial_state(inputs=None, batch_size=None, dtype=None)`
method that creates a tensor meant to be fed to `call()` as the
initial state, if user didn't specify any initial state via other
means. The returned initial state should be in shape of
[batch, cell.state_size]. Cell might choose to create zero filled
tensor, or with other values based on the cell implementations.
`inputs` is the input tensor to the RNN layer, which should
contain the batch size as its shape[0], and also dtype. Note that
the shape[0] might be None during the graph construction. Either
`inputs` or the pair of `batch_size` and `dtype` is provided.
`batch_size` is a scalar tensor that represents the batch size
of the input. `dtype` is a `tf.DType` that represents the dtype of
the input.
For backward compatible reason, if this method is not implemented
by the cell, RNN layer will create a zero filled tensors with the
size of [batch, cell.state_size].
In the case that `cell` is a list of RNN cell instances, the cells
will be stacked one after the other in the RNN, implementing an
efficient stacked RNN.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
input_dim: dimensionality of the input (integer or tuple of integers).
This argument (or alternatively, the keyword argument `input_shape`)
is required when using this layer as the first layer in a model.
input_length: Length of input sequences, to be specified
when it is constant.
This argument is required if you are going to connect
`Flatten` then `Dense` layers upstream
(without it, the shape of the dense outputs cannot be computed).
Note that if the recurrent layer is not the first layer
in your model, you would need to specify the input length
at the level of the first layer
(e.g. via the `input_shape` argument)
time_major: The shape format of the `inputs` and `outputs` tensors.
If True, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of the
RNN calculation. However, most TensorFlow data is batch-major, so by
default this function accepts input and emits output in batch-major
form.
Input shape:
N-D tensor with shape `(batch_size, timesteps, ...)` or
`(timesteps, batch_size, ...)` when time_major is True.
Output shape:
- if `return_state`: a list of tensors. The first tensor is
the output. The remaining tensors are the last states,
each with shape `(batch_size, state_size)`, where `state_size` could
be a high dimension tensor shape.
- if `return_sequences`: N-D tensor with shape
`(batch_size, timesteps, output_size)`, where `output_size` could
be a high dimension tensor shape, or
`(timesteps, batch_size, output_size)` when `time_major` is True.
- else, N-D tensor with shape `(batch_size, output_size)`, where
`output_size` could be a high dimension tensor shape.
# Masking
This layer supports masking for input data with a variable number
of timesteps. To introduce masks to your data,
use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
set to `True`.
# Note on using statefulness in RNNs
You can set RNN layers to be 'stateful', which means that the states
computed for the samples in one batch will be reused as initial states
for the samples in the next batch. This assumes a one-to-one mapping
between samples in different successive batches.
To enable statefulness:
- specify `stateful=True` in the layer constructor.
- specify a fixed batch size for your model, by passing
if sequential model:
`batch_input_shape=(...)` to the first layer in your model.
else for functional model with 1 or more Input layers:
`batch_shape=(...)` to all the first layers in your model.
This is the expected shape of your inputs
*including the batch size*.
It should be a tuple of integers, e.g. `(32, 10, 100)`.
- specify `shuffle=False` when calling fit().
To reset the states of your model, call `.reset_states()` on either
a specific layer, or on your entire model.
# Note on specifying the initial state of RNNs
You can specify the initial state of RNN layers symbolically by
calling them with the keyword argument `initial_state`. The value of
`initial_state` should be a tensor or list of tensors representing
the initial state of the RNN layer.
You can specify the initial state of RNN layers numerically by
calling `reset_states` with the keyword argument `states`. The value of
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.
# Note on passing external constants to RNNs
You can pass "external" constants to the cell using the `constants`
keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
requires that the `cell.call` method accepts the same keyword argument
`constants`. Such constants can be used to condition the cell
transformation on additional static inputs (not changing over time),
a.k.a. an attention mechanism.
Examples:
```python
# First, let's define a RNN Cell, as a layer subclass.
class MinimalRNNCell(keras.layers.Layer):
def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(MinimalRNNCell, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = K.dot(inputs, self.kernel)
output = h + K.dot(prev_output, self.recurrent_kernel)
return output, [output]
# Let's use this cell in a RNN layer:
cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)
# Here's how to use the cell to build a stacked RNN:
cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)
```
"""
@checkpointable.no_automatic_dependency_tracking
def __init__(self,
             cell,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             unroll=False,
             time_major=False,
             **kwargs):
  """Stores the cell and the layer options.

  A list or tuple passed as `cell` is wrapped into a `StackedRNNCells`.
  The keyword `zero_output_for_mask` is consumed here; all remaining
  keyword arguments go to `Layer.__init__`.

  Raises:
    ValueError: If the cell has no `call` method or no `state_size`
      attribute.
  """
  if isinstance(cell, (list, tuple)):
    cell = StackedRNNCells(cell)

  has_call = hasattr(cell, 'call')
  if not has_call:
    raise ValueError('`cell` should have a `call` method. '
                     'The RNN was passed:', cell)
  has_state_size = hasattr(cell, 'state_size')
  if not has_state_size:
    raise ValueError('The RNN cell should have '
                     'an attribute `state_size` '
                     '(tuple of integers, '
                     'one integer per RNN state).')

  # When True, masked timesteps emit zeros; when False, they repeat the
  # output of the previous timestep.
  self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
  super(RNN, self).__init__(**kwargs)

  self.cell = cell
  if isinstance(cell, checkpointable.CheckpointableBase):
    self._track_checkpointable(self.cell, name='cell')

  # Layer options.
  self.return_sequences = return_sequences
  self.return_state = return_state
  self.go_backwards = go_backwards
  self.stateful = stateful
  self.unroll = unroll
  self.time_major = time_major
  self.supports_masking = True

  # The input shape is not known yet and the inputs may be nested, so the
  # input spec (a list of specs for the flattened inputs) is filled in later.
  self.input_spec = None
  self.state_spec = None
  self.constants_spec = None
  self._states = None
  self._num_constants = None
  self._num_inputs = None
@property
def states(self):
  """Current states, or `None` placeholders shaped like `cell.state_size`."""
  if self._states is not None:
    return self._states
  # No states stored yet: mirror the structure of the cell's state size.
  placeholder = nest.map_structure(lambda _: None, self.cell.state_size)
  if nest.is_sequence(self.cell.state_size):
    return placeholder
  return [placeholder]

@states.setter
def states(self, states):
  """Stores `states` without any validation."""
  self._states = states
def compute_output_shape(self, input_shape):
  """Computes the output shape(s) for the given input shape.

  Returns a single shape, or a list of the output shape(s) followed by the
  flattened state shapes when `return_state` is set.
  """
  if isinstance(input_shape, list):
    input_shape = input_shape[0]
  # The input shape may itself be nested, e.g.
  # (tensor_shape(1, 2), tensor_shape(3, 4)), or flat like (1, 2, 3) coming
  # from numpy inputs.
  try:
    input_shape = tensor_shape.as_shape(input_shape)
  except (ValueError, TypeError):
    # Nested tensor input: use the first flattened entry.
    input_shape = nest.flatten(input_shape)[0]

  if self.time_major:
    time_step, batch = input_shape[0], input_shape[1]
  else:
    batch, time_step = input_shape[0], input_shape[1]

  cell_state_size = self.cell.state_size
  if _is_multiple_state(cell_state_size):
    state_size = cell_state_size
  else:
    state_size = [cell_state_size]

  def _get_output_shape(flat_output_size):
    trailing_dims = tensor_shape.as_shape(flat_output_size).as_list()
    if not self.return_sequences:
      return tensor_shape.as_shape([batch] + trailing_dims)
    if self.time_major:
      leading_dims = [time_step, batch]
    else:
      leading_dims = [batch, time_step]
    return tensor_shape.as_shape(leading_dims + trailing_dims)

  if getattr(self.cell, 'output_size', None) is None:
    # Note that state_size[0] could be a tensor_shape or int.
    output_shape = _get_output_shape(state_size[0])
  else:
    # cell.output_size could be nested structure.
    flat_shapes = nest.flatten(nest.map_structure(
        _get_output_shape, self.cell.output_size))
    output_shape = flat_shapes[0] if len(flat_shapes) == 1 else flat_shapes

  if not self.return_state:
    return output_shape

  def _get_state_shape(flat_state):
    dims = [batch] + tensor_shape.as_shape(flat_state).as_list()
    return tensor_shape.as_shape(dims)

  state_shape = nest.map_structure(_get_state_shape, state_size)
  return generic_utils.to_list(output_shape) + nest.flatten(state_shape)
def compute_mask(self, inputs, mask):
  """Returns the output mask, plus one `None` per state if `return_state`."""
  if isinstance(mask, list):
    mask = mask[0]
  # The mask is only propagated when the full sequence is returned.
  output_mask = mask if self.return_sequences else None
  if not self.return_state:
    return output_mask
  return [output_mask] + [None] * len(self.states)
def build(self, input_shape):
  """Sets `input_spec` and `state_spec` and builds the cell.

  Arguments:
    input_shape: Shape of the inputs. When initial states or constants were
      passed in `__call__`, this is a list that also holds their shapes.
  """
  if self._num_constants is None:
    constants_shape = None
  else:
    # The trailing `_num_constants` entries are the constants' shapes.
    constants_shape = nest.map_structure(
        lambda s: tuple(tensor_shape.TensorShape(s).as_list()),
        input_shape[-self._num_constants:])  # pylint: disable=invalid-unary-operand-type

  if isinstance(input_shape, list):
    input_shape = input_shape[0]
    # The input_shape here could be a nest structure.

  batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)

  def get_input_spec(shape):
    # Converts one shape into an InputSpec with the time axis (and, unless
    # stateful, the batch axis) left unspecified.
    if isinstance(shape, tensor_shape.TensorShape):
      dims = shape.as_list()
    else:
      dims = list(shape)
    if not self.stateful:
      dims[batch_index] = None
    dims[time_step_index] = None
    return InputSpec(shape=tuple(dims))

  def get_step_input_shape(shape):
    # Drops the timestep axis to obtain the per-step input shape.
    if isinstance(shape, tensor_shape.TensorShape):
      shape = tuple(shape.as_list())
    if self.time_major:
      return shape[1:]
    return (shape[0],) + shape[2:]

  # The input shape may be nested, e.g.
  # (tensor_shape(1, 2), tensor_shape(3, 4)), or flat like (1, 2, 3) coming
  # from numpy inputs.
  try:
    input_shape = tensor_shape.as_shape(input_shape)
  except (ValueError, TypeError):
    # A nested tensor input
    pass

  if nest.is_sequence(input_shape):
    flat_specs = nest.map_structure(get_input_spec,
                                    nest.flatten(input_shape))
    assert len(flat_specs) == self._num_inputs
    if self.input_spec is None:
      self.input_spec = flat_specs
    else:
      self.input_spec[:self._num_inputs] = flat_specs
    step_input_shape = nest.map_structure(get_step_input_shape, input_shape)
  else:
    # This indicates that there is only one input.
    single_spec = get_input_spec(input_shape)
    if self.input_spec is None:
      self.input_spec = [single_spec]
    else:
      self.input_spec[0] = single_spec
    step_input_shape = get_step_input_shape(input_shape)

  # Let the cell (if it is a layer) build before state_spec is set or
  # validated.
  if isinstance(self.cell, Layer):
    if constants_shape is None:
      self.cell.build(step_input_shape)
    else:
      self.cell.build([step_input_shape] + constants_shape)

  # Set or validate state_spec.
  if _is_multiple_state(self.cell.state_size):
    state_size = list(self.cell.state_size)
  else:
    state_size = [self.cell.state_size]

  if self.state_spec is None:
    self.state_spec = [
        InputSpec(shape=[None] + tensor_shape.as_shape(dim).as_list())
        for dim in state_size
    ]
  else:
    # initial_state was passed in call, check compatibility
    self._validate_state_spec(state_size, self.state_spec)

  if self.stateful:
    self.reset_states()
  self.built = True
@staticmethod
def _validate_state_spec(cell_state_sizes, init_state_specs):
  """Checks that the initial state specs match the cell's state sizes.

  Args:
    cell_state_sizes: list, the `state_size` attribute from the cell.
    init_state_specs: list, the `state_spec` from the initial_state that is
      passed in call()

  Raises:
    ValueError: When initial state spec is not compatible with the state size.
  """
  validation_error = ValueError(
      'An `initial_state` was passed that is not compatible with '
      '`cell.state_size`. Received `state_spec`={}; '
      'however `cell.state_size` is '
      '{}'.format(init_state_specs, cell_state_sizes))

  if len(cell_state_sizes) != len(init_state_specs):
    raise validation_error

  for state_size, spec in zip(cell_state_sizes, init_state_specs):
    # The first axis of the initial state is the batch axis; skip it.
    spec_shape = tensor_shape.TensorShape(spec.shape[1:])
    expected_shape = tensor_shape.TensorShape(state_size)
    if not spec_shape.is_compatible_with(expected_shape):
      raise validation_error
def get_initial_state(self, inputs):
  """Creates the initial state list for `inputs`.

  Uses the cell's own `get_initial_state` when it has one, otherwise a
  zero-filled state. The result is always returned as a list.
  """
  cell_init_fn = getattr(self.cell, 'get_initial_state', None)

  # For nested inputs, the first flattened element supplies batch size and
  # dtype.
  if nest.is_sequence(inputs):
    inputs = nest.flatten(inputs)[0]

  dynamic_shape = array_ops.shape(inputs)
  batch_axis = 1 if self.time_major else 0
  batch_size = dynamic_shape[batch_axis]
  dtype = inputs.dtype

  if cell_init_fn:
    init_state = cell_init_fn(
        inputs=None, batch_size=batch_size, dtype=dtype)
  else:
    init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
                                             dtype)

  # Keras RNN expects a list of states, even for a single state tensor, and
  # a plain list rather than e.g. a namedtuple such as LSTMStateTuple.
  if nest.is_sequence(init_state):
    return list(init_state)
  return [init_state]
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
  """Calls the layer, routing `initial_state` and `constants` as needed.

  If the initial state and constants are Keras tensors they are appended to
  the layer inputs (with `input_spec` temporarily extended to cover them).
  Otherwise they are forwarded as keyword arguments.

  Arguments:
    inputs: Input tensor, or nested structure of tensors.
    initial_state: Optional tensor or list of tensors.
    constants: Optional tensor or list of tensors.
    **kwargs: Forwarded to `Layer.__call__`.

  Returns:
    Whatever `Layer.__call__` returns.

  Raises:
    ValueError: If initial states and constants mix Keras tensors and
      non-Keras tensors.
  """
  inputs, initial_state, constants = _standardize_args(inputs,
                                                       initial_state,
                                                       constants,
                                                       self._num_constants,
                                                       self._num_inputs)
  # in case the real inputs is a nested structure, set the size of flatten
  # input so that we can distinguish between real inputs, initial_state and
  # constants.
  self._num_inputs = len(nest.flatten(inputs))

  if initial_state is None and constants is None:
    return super(RNN, self).__call__(inputs, **kwargs)

  # If any of `initial_state` or `constants` are specified and are Keras
  # tensors, then add them to the inputs and temporarily modify the
  # input_spec to include them.
  additional_inputs = []
  additional_specs = []
  if initial_state is not None:
    additional_inputs += initial_state
    self.state_spec = [
        InputSpec(shape=K.int_shape(state)) for state in initial_state
    ]
    additional_specs += self.state_spec
  if constants is not None:
    additional_inputs += constants
    self.constants_spec = [
        InputSpec(shape=K.int_shape(constant)) for constant in constants
    ]
    self._num_constants = len(constants)
    additional_specs += self.constants_spec

  # at this point additional_inputs cannot be empty
  is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
  for tensor in additional_inputs:
    if K.is_keras_tensor(tensor) != is_keras_tensor:
      raise ValueError('The initial state or constants of an RNN'
                       ' layer cannot be specified with a mix of'
                       ' Keras tensors and non-Keras tensors'
                       ' (a "Keras tensor" is a tensor that was'
                       ' returned by a Keras layer, or by `Input`)')

  if is_keras_tensor:
    # Compute the full input spec, including state and constants
    full_input = [inputs] + additional_inputs
    # The original input_spec is None since there could be a nested tensor
    # input. Update the input_spec to match the inputs.
    full_input_spec = [None] * self._num_inputs + additional_specs
    # Perform the call with temporarily replaced input_spec. The original
    # spec is restored even when the call raises, so a failed call does not
    # leave the layer with the extended spec.
    original_input_spec = self.input_spec
    self.input_spec = full_input_spec
    try:
      return super(RNN, self).__call__(full_input, **kwargs)
    finally:
      self.input_spec = original_input_spec
  else:
    if initial_state is not None:
      kwargs['initial_state'] = initial_state
    if constants is not None:
      kwargs['constants'] = constants
    return super(RNN, self).__call__(inputs, **kwargs)
def call(self,
         inputs,
         mask=None,
         training=None,
         initial_state=None,
         constants=None):
  """Runs the cell over the input sequence via `K.rnn`.

  Returns the last output (or the full output sequence when
  `return_sequences` is set), followed by the final states when
  `return_state` is set.

  Raises:
    ValueError: If `unroll` is set but the time dimension is undefined or 1,
      or if constants are given to a cell whose `call` does not accept them.
  """
  inputs, initial_state, constants = self._process_inputs(
      inputs, initial_state, constants)

  if isinstance(mask, list):
    mask = mask[0]

  # For nested inputs the first flattened element is used for the shape check.
  shape_source = nest.flatten(inputs)[0] if nest.is_sequence(inputs) else inputs
  input_shape = K.int_shape(shape_source)
  timesteps = input_shape[0] if self.time_major else input_shape[1]
  if self.unroll and timesteps in (None, 1):
    raise ValueError('Cannot unroll a RNN if the '
                     'time dimension is undefined or equal to 1. \n'
                     '- If using a Sequential model, '
                     'specify the time dimension by passing '
                     'an `input_shape` or `batch_input_shape` '
                     'argument to your first layer. If your '
                     'first layer is an Embedding, you can '
                     'also use the `input_length` argument.\n'
                     '- If using the functional API, specify '
                     'the time dimension by passing a `shape` '
                     'or `batch_shape` argument to your Input layer.')

  cell_kwargs = {}
  if generic_utils.has_arg(self.cell.call, 'training'):
    cell_kwargs['training'] = training

  # TF RNN cells expect single tensor as state instead of list wrapped tensor.
  is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None

  def _unwrap_single_state(step_states):
    # Hands a lone state to TF RNN cells as a bare tensor.
    if len(step_states) == 1 and is_tf_rnn_cell:
      return step_states[0]
    return step_states

  def _as_state_list(new_states):
    # Wraps a non-sequence state so the step always returns a sequence.
    return new_states if nest.is_sequence(new_states) else [new_states]

  if constants:
    if not generic_utils.has_arg(self.cell.call, 'constants'):
      raise ValueError('RNN cell does not support constants')

    def step(inputs, states):
      # The trailing `_num_constants` entries of `states` are the constants.
      step_constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
      step_states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
      step_states = _unwrap_single_state(step_states)
      output, new_states = self.cell.call(
          inputs, step_states, constants=step_constants, **cell_kwargs)
      return output, _as_state_list(new_states)
  else:

    def step(inputs, states):
      step_states = _unwrap_single_state(states)
      output, new_states = self.cell.call(inputs, step_states, **cell_kwargs)
      return output, _as_state_list(new_states)

  last_output, outputs, states = K.rnn(
      step,
      inputs,
      initial_state,
      constants=constants,
      go_backwards=self.go_backwards,
      mask=mask,
      unroll=self.unroll,
      input_length=timesteps,
      time_major=self.time_major,
      zero_output_for_mask=self.zero_output_for_mask)

  if self.stateful:
    # Assign the final states back to the layer's state variables.
    updates = [
        state_ops.assign(self.states[idx], new_state)
        for idx, new_state in enumerate(states)
    ]
    self.add_update(updates, inputs)

  output = outputs if self.return_sequences else last_output
  if not self.return_state:
    return output

  if isinstance(states, (list, tuple)):
    states = list(states)
  else:
    states = [states]
  return generic_utils.to_list(output) + states
def _process_inputs(self, inputs, initial_state, constants):
  """Unpacks list inputs and resolves which initial state to use.

  When `inputs` is a list, its first element is the actual input, the
  trailing `self._num_constants` elements (if any) are the constants, and
  whatever lies in between is the initial state. If no initial state is
  available afterwards, it falls back to `self.states` for stateful layers
  or to `self.get_initial_state(inputs)` otherwise.

  Arguments:
    inputs: The input, or a list packing input, initial states and constants.
    initial_state: Explicit initial state, or None.
    constants: Explicit constants, or None.

  Returns:
    A tuple `(inputs, initial_state, constants)`.

  Raises:
    ValueError: If the number of initial states differs from the number of
      states held by the layer.
  """
  # input shape: `(samples, time (padded with zeros), input_dim)`
  # note that the .build() method of subclasses MUST define
  # self.input_spec and self.state_spec with complete input shapes.
  if isinstance(inputs, list):
    # get initial_state from full input spec
    # as they could be copied to multiple GPU.
    num_constants = self._num_constants
    if not num_constants:
      # Covers both None and 0. Slicing with `-0` would yield an empty state
      # list and would treat the entire input list as constants.
      initial_state = inputs[1:]
    else:
      initial_state = inputs[1:-num_constants]
      constants = inputs[-num_constants:]
    if not initial_state:
      initial_state = None
    inputs = inputs[0]

  if initial_state is None:
    if self.stateful:
      initial_state = self.states
    else:
      initial_state = self.get_initial_state(inputs)

  if len(initial_state) != len(self.states):
    raise ValueError('Layer has ' + str(len(self.states)) +
                     ' states but was passed ' + str(len(initial_state)) +
                     ' initial states.')
  return inputs, initial_state, constants
def reset_states(self, states=None):
  """Creates, zero-fills or overwrites the states of a stateful layer.

  If the layer's first state is still None, new zero-valued states are
  created and `states` is not consulted. Otherwise the existing states are
  set to zeros when `states` is None, or to the given values.

  Arguments:
    states: Optional value, or list/tuple of values, to assign to the
      existing states. Each value must expose a `shape` equal to the batch
      size followed by the corresponding state size.

  Raises:
    AttributeError: If the layer is not stateful.
    ValueError: If the batch size is unknown, if the number of values does
      not match the number of states, or if a value has the wrong shape.
  """
  if not self.stateful:
    raise AttributeError('Layer must be stateful.')

  # The batch dimension is the second axis for time-major inputs.
  batch_axis = 1 if self.time_major else 0
  batch_size = self.input_spec[0].shape[batch_axis]
  if not batch_size:
    raise ValueError('If a RNN is stateful, it needs to know '
                     'its batch size. Specify the batch size '
                     'of your input tensors: \n'
                     '- If using a Sequential model, '
                     'specify the batch size by passing '
                     'a `batch_input_shape` '
                     'argument to your first layer.\n'
                     '- If using the functional API, specify '
                     'the batch size by passing a '
                     '`batch_shape` argument to your Input layer.')

  def full_shape(dim):
    # Shape of one state: batch size followed by the state's own dimensions.
    return [batch_size] + tensor_shape.as_shape(dim).as_list()

  # initialize state if None
  if self.states[0] is None:
    if _is_multiple_state(self.cell.state_size):
      self.states = [K.zeros(full_shape(dim)) for dim in self.cell.state_size]
    else:
      self.states = [K.zeros(full_shape(self.cell.state_size))]
    return

  # No explicit values: zero out the existing states in place.
  if states is None:
    if _is_multiple_state(self.cell.state_size):
      for state, dim in zip(self.states, self.cell.state_size):
        K.set_value(state, np.zeros(full_shape(dim)))
    else:
      K.set_value(self.states[0], np.zeros(full_shape(self.cell.state_size)))
    return

  # Explicit values: validate count and shape before assigning each one.
  if not isinstance(states, (list, tuple)):
    states = [states]
  if len(states) != len(self.states):
    raise ValueError('Layer ' + self.name + ' expects ' +
                     str(len(self.states)) + ' states, '
                     'but it received ' + str(len(states)) +
                     ' state values. Input received: ' + str(states))
  for index, (value, state) in enumerate(zip(states, self.states)):
    dim = (self.cell.state_size[index]
           if _is_multiple_state(self.cell.state_size)
           else self.cell.state_size)
    if value.shape != tuple(full_shape(dim)):
      raise ValueError(
          'State ' + str(index) + ' is incompatible with layer ' +
          self.name + ': expected shape=' + str(
              (batch_size, dim)) + ', found shape=' + str(value.shape))
    # TODO(fchollet): consider batch calls to `set_value`.
    K.set_value(state, value)
def get_config(self):
  """Returns the layer configuration as a dictionary.

  The result contains the base class configuration, the RNN flags, the
  serialized cell (class name plus its own config) and, only when set,
  `num_constants` and `zero_output_for_mask`.
  """
  rnn_config = {
      'return_sequences': self.return_sequences,
      'return_state': self.return_state,
      'go_backwards': self.go_backwards,
      'stateful': self.stateful,
      'unroll': self.unroll,
      'time_major': self.time_major
  }
  # Optional entries are only emitted when they carry information.
  if self._num_constants is not None:
    rnn_config['num_constants'] = self._num_constants
  if self.zero_output_for_mask:
    rnn_config['zero_output_for_mask'] = self.zero_output_for_mask

  rnn_config['cell'] = {
      'class_name': self.cell.__class__.__name__,
      'config': self.cell.get_config()
  }

  # Entries defined here take precedence over same-named base entries.
  merged = dict(super(RNN, self).get_config().items())
  merged.update(rnn_config)
  return merged
@classmethod
def from_config(cls, config, custom_objects=None):
  """Builds a layer from a configuration dictionary.

  Arguments:
    config: Dictionary holding a `'cell'` entry to deserialize, an optional
      `'num_constants'` entry, and the remaining constructor keyword
      arguments. It is left unmodified.
    custom_objects: Optional value forwarded to the layer deserializer when
      rebuilding the cell.

  Returns:
    A new instance of `cls` wrapping the deserialized cell, with
    `_num_constants` restored from the config (None when absent).
  """
  from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
  # Work on a shallow copy so the `pop` calls below do not strip keys from
  # the caller's dictionary, which would break reuse of the same config.
  config = dict(config)
  cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
  num_constants = config.pop('num_constants', None)
  layer = cls(cell, **config)
  layer._num_constants = num_constants
  return layer
@property
def trainable_weights(self):
  """Trainable weights of the cell.

  Returns an empty list when the layer is not trainable or when the cell is
  not a `Layer`.
  """
  if self.trainable and isinstance(self.cell, Layer):
    return self.cell.trainable_weights
  return []
@property
def non_trainable_weights(self):
  """Non-trainable weights of the cell.

  When the layer itself is not trainable, every weight of the cell is
  reported. Returns an empty list when the cell is not a `Layer`.
  """
  if not isinstance(self.cell, Layer):
    return []
  if self.trainable:
    return self.cell.non_trainable_weights
  return self.cell.weights
@property
def losses(self):
  """Losses of the cell (if it is a `Layer`) followed by the layer's own."""
  own_losses = super(RNN, self).losses
  if not isinstance(self.cell, Layer):
    return own_losses
  return self.cell.losses + own_losses
@property
def updates(self):
  """Updates of the cell (if it is a `Layer`) followed by the layer's own."""
  cell_updates = []
  if isinstance(self.cell, Layer):
    cell_updates.extend(self.cell.updates)
  return cell_updates + self._updates
@tf_export('keras.layers.SimpleRNNCell')
class SimpleRNNCell(Layer):
"""Cell class for SimpleRNN.
Arguments: