Skip to content

Commit f5bbf09

Browse files
lingvo-botcopybara-github
authored andcommitted
Use tf.slice instead of tf.strided_slice (via tf.Tensor.__getitem__ / indexing syntax) for state tensors.
This ensures that tensors in the output state have the same static shape as tensors in the input state. Before this change the equivalence is only guaranteed for the dynamic shape. PiperOrigin-RevId: 480168486
1 parent 2934dfe commit f5bbf09

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

lingvo/core/batch_major_attention.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,7 +2271,8 @@ def StreamStepAddSkipConnection(self, input_to_add, output, state0, state1):
22712271
concat_input_to_add = tf.concat([state0.skip_conn_input, input_to_add],
22722272
axis=1)
22732273
final_output = output + concat_input_to_add[:, :seqlen]
2274-
state1.skip_conn_input = concat_input_to_add[:, seqlen:]
2274+
state1.skip_conn_input = tf.slice(concat_input_to_add, [0, seqlen, 0],
2275+
tf.shape(state0.skip_conn_input))
22752276
return final_output, state1
22762277

22772278
def _StreamStepDimensions(self, inputs):
@@ -2664,12 +2665,14 @@ def _StreamStepDynamicLength(self, theta, query_vec, query_paddings, key_vec,
26642665
output = self.post.FProp(theta.post, output)
26652666

26662667
state1 = py_utils.NestedMap(
2667-
key=key[:, k:, :, :],
2668-
value=value[:, k:, :, :],
2669-
masks=state_masks[:, k:])
2668+
key=tf.slice(key, [0, k, 0, 0], tf.shape(state0.key)),
2669+
value=tf.slice(value, [0, k, 0, 0], tf.shape(state0.value)),
2670+
masks=tf.slice(state_masks, [0, k], tf.shape(state0.masks)))
26702671
if p.right_context > 0:
2671-
state1.query = concat_query[:, q:]
2672-
state1.out_masks = concat_out_masks[:, q:]
2672+
state1.query = tf.slice(concat_query, [0, q, 0, 0],
2673+
tf.shape(state0.query))
2674+
state1.out_masks = tf.slice(concat_out_masks, [0, q],
2675+
tf.shape(state0.out_masks))
26732676
return output, out_paddings, state1
26742677

26752678
@classmethod

lingvo/core/conv_layers_with_time_padding.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,8 @@ def StreamStep(self, theta, inputs, paddings, state0):
596596
padding='VALID')
597597
if p.bias:
598598
outputs = tf.nn.bias_add(outputs, theta.b)
599-
new_context = concat_inputs[:, q:]
599+
new_context = tf.slice(concat_inputs, [0, q, 0, 0],
600+
tf.shape(state0.context))
600601
return outputs, paddings, py_utils.NestedMap(context=new_context)
601602

602603

@@ -773,7 +774,8 @@ def StreamStep(self, theta, inputs, paddings, state0):
773774
padding='VALID')
774775
if p.bias:
775776
outputs = tf.nn.bias_add(outputs, theta.b)
776-
new_context = concat_inputs[:, q:]
777+
new_context = tf.slice(concat_inputs, [0, q, 0, 0],
778+
tf.shape(state0.context))
777779
return outputs, paddings, py_utils.NestedMap(context=new_context)
778780

779781

0 commit comments

Comments
 (0)