From e1bde16a3da212c3031688631d6db2441ea6f70d Mon Sep 17 00:00:00 2001 From: Valentin Berkes Date: Mon, 3 May 2021 16:57:12 +0000 Subject: [PATCH 1/5] [WIP] pseudo self attention for LM + load LM with PSA --- onmt/decoders/__init__.py | 37 ++++-- onmt/decoders/transformer.py | 209 +++++++++++++++++++++++++++++- onmt/modules/__init__.py | 4 +- onmt/modules/multi_headed_attn.py | 11 ++ onmt/opts.py | 3 +- 5 files changed, 250 insertions(+), 14 deletions(-) diff --git a/onmt/decoders/__init__.py b/onmt/decoders/__init__.py index 2b9a7acd34..1e50b7cd96 100644 --- a/onmt/decoders/__init__.py +++ b/onmt/decoders/__init__.py @@ -1,13 +1,32 @@ """Module defining decoders.""" -from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \ - StdRNNDecoder -from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder from onmt.decoders.cnn_decoder import CNNDecoder +from onmt.decoders.decoder import ( + DecoderBase, + InputFeedRNNDecoder, + StdRNNDecoder, +) +from onmt.decoders.transformer import ( + TransformerDecoder, + TransformerLMDecoder, + TransformerLMPseudoSelfAttentionDecoder, +) +str2dec = { + "rnn": StdRNNDecoder, + "ifrnn": InputFeedRNNDecoder, + "cnn": CNNDecoder, + "transformer": TransformerDecoder, + "transformer_lm": TransformerLMDecoder, + "transformer_lm_psa": TransformerLMPseudoSelfAttentionDecoder, +} -str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder, - "cnn": CNNDecoder, "transformer": TransformerDecoder, - "transformer_lm": TransformerLMDecoder} - -__all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder", - "InputFeedRNNDecoder", "str2dec", "TransformerLMDecoder"] +__all__ = [ + "DecoderBase", + "TransformerDecoder", + "StdRNNDecoder", + "CNNDecoder", + "InputFeedRNNDecoder", + "str2dec", + "TransformerLMDecoder", + "TransformerLMPseudoSelfAttentionDecoder", +] diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index a50e4a8e9c..c882a44802 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from onmt.decoders.decoder import DecoderBase -from onmt.modules import MultiHeadedAttention, AverageAttention +from onmt.modules import MultiHeadedAttention, AverageAttention, MultiHeadedPseudoSelfAttention from onmt.modules.position_ffn import PositionwiseFeedForward from onmt.modules.position_ffn import ActivationFunction from onmt.utils.misc import sequence_mask @@ -68,7 +68,13 @@ def __init__( self.self_attn = AverageAttention( d_model, dropout=attention_dropout, aan_useffn=aan_useffn ) - + elif self_attn_type == "pseudo-self": + self.self_attn = MultiHeadedPseudoSelfAttention( + heads, + d_model, + dropout=attention_dropout, + max_relative_positions=max_relative_positions, + ) self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, pos_ffn_activation_fn ) @@ -693,3 +699,202 @@ def _init_cache(self, memory_bank=None): if isinstance(layer.self_attn, AverageAttention): raise NotImplementedError self.state["cache"]["layer_{}".format(i)] = layer_cache + +class TransformerLMPseudoSelfAttentionDecoderLayer(TransformerDecoderLayerBase): + """Transformer Decoder only layer block in GPT style. + + .. mermaid:: + + graph LR + %% "*SubLayer" can be self-attn, src-attn or feed forward block + A(input) --> B[Norm] + B --> C["*SubLayer"] + C --> D[Drop] + D --> E((+)) + A --> E + E --> F(out) + + + Args: + See TransformerDecoderLayerBase + """ + + def _forward( + self, inputs, + src_memory_bank, + src_pad_mask, + tgt_pad_mask, layer_cache=None, step=None, future=False + ): + """A naive forward pass for transformer decoder. + + # T: could be 1 in the case of stepwise decoding or tgt_len + + Args: + inputs (FloatTensor): ``(batch_size, T, model_dim)`` + tgt_pad_mask (bool): ``(batch_size, 1, T)`` + layer_cache (dict or None): cached layer info when stepwise decode + step (int or None): stepwise decoding counter + future (bool): If set True, do not apply future_mask. + + Returns: + (FloatTensor, FloatTensor): + + * output ``(batch_size, T, model_dim)`` + * attns ``(batch_size, head, T, T)`` + + """ + dec_mask = None + + if inputs.size(1) > 1: + # masking is necessary when sequence length is greater than one + dec_mask = self._compute_dec_mask(tgt_pad_mask, future) + + inputs_norm = self.layer_norm_1(inputs) + + pseudo_key_value = torch.cat([src_memory_bank.transpose(0, 1), inputs], axis=1) + pseudo_mask = torch.cat([src_pad_mask.repeat(1, dec_mask.size(1), 1), dec_mask], axis=-1) + query, attns = self.self_attn( + pseudo_key_value, + pseudo_key_value, + inputs_norm, + mask=pseudo_mask, + layer_cache=layer_cache, + attn_type="self", + ) + + output = self.drop(query) + inputs + + output_feedforward = self.feed_forward(output) + + return output_feedforward, attns + +class TransformerLMPseudoSelfAttentionDecoder(TransformerDecoderBase): + """The Transformer decoder from GPT-2 with pseudo self attention + + .. mermaid:: + + graph BT + A[input] + B[multi-head self-attn] + C[feed forward] + O[output] + A --> B + B --> C + C --> O + + + Args: + num_layers (int): number of decoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + copy_attn (bool): if using a separate copy attention + self_attn_type (str): type of self-attention scaled-dot, average + dropout (float): dropout in residual, self-attn(dot) and feed-forward + attention_dropout (float): dropout in context_attn (and self-attn(avg)) + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + max_relative_positions (int): + Max distance between inputs in relative positions representations + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + """ + + def __init__( + self, + num_layers, + d_model, + heads, + d_ff, + copy_attn, + self_attn_type, + dropout, + attention_dropout, + embeddings, + max_relative_positions, + aan_useffn, + full_context_alignment=None, + alignment_layer=None, + alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + super(TransformerLMPseudoSelfAttentionDecoder, self).__init__( + d_model, copy_attn, embeddings, None + ) + self.transformer_layers = nn.ModuleList( + [ + TransformerLMPseudoSelfAttentionDecoderLayer( + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type="pseudo-self", + max_relative_positions=max_relative_positions, + aan_useffn=aan_useffn, + full_context_alignment=None, + alignment_heads=None, + pos_ffn_activation_fn=pos_ffn_activation_fn, + ) + for i in range(num_layers) + ] + ) + + def detach_state(self): + pass + + def forward(self, tgt, memory_bank=None, step=None, **kwargs): + """Decode, possibly stepwise.""" + if step == 0: + self._init_cache() + + tgt_words = tgt[:, :, 0].transpose(0, 1) + + emb = self.embeddings(tgt, step=step) + assert emb.dim() == 3 # len x batch x embedding_dim + + output = emb.transpose(0, 1).contiguous() + + pad_idx = self.embeddings.word_padding_idx + src_lens = kwargs["memory_lengths"] + src_max_len = self.state["src"].shape[0] + src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1) + tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] + + with_align = kwargs.pop("with_align", False) + assert not with_align, "TransformerLMDecoder does not support align" + + for i, layer in enumerate(self.transformer_layers): + layer_cache = ( + self.state["cache"]["layer_{}".format(i)] + if step is not None + else None + ) + output, attn, _ = layer( + output, + memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=layer_cache, + step=step, + with_align=with_align, + ) + + output = self.layer_norm(output) + dec_outs = output.transpose(0, 1).contiguous() + attn = attn.transpose(0, 1).contiguous() + + attns = {"std": attn} + if self._copy: + attns["copy"] = attn + + # TODO change the way attns is returned dict => list or tuple (onnx) + return dec_outs, attns + + def _init_cache(self, memory_bank=None): + self.state["cache"] = {} + + for i, layer in enumerate(self.transformer_layers): + layer_cache = {"self_keys": None, "self_values": None} + if isinstance(layer.self_attn, AverageAttention): + raise NotImplementedError + self.state["cache"]["layer_{}".format(i)] = layer_cache diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py index 0e789e5774..83ccf0d203 100644 --- a/onmt/modules/__init__.py +++ b/onmt/modules/__init__.py @@ -5,7 +5,7 @@ from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ CopyGeneratorLossCompute, CopyGeneratorLMLossCompute -from onmt.modules.multi_headed_attn import MultiHeadedAttention +from onmt.modules.multi_headed_attn import MultiHeadedAttention, MultiHeadedPseudoSelfAttention from onmt.modules.embeddings import Embeddings, PositionalEncoding from onmt.modules.weight_norm import WeightNormConv2d from onmt.modules.average_attn import AverageAttention @@ -15,4 +15,4 @@ "CopyGeneratorLoss", "CopyGeneratorLossCompute", "MultiHeadedAttention", "Embeddings", "PositionalEncoding", "WeightNormConv2d", "AverageAttention", - "CopyGeneratorLMLossCompute"] + "CopyGeneratorLMLossCompute", "MultiHeadedPseudoSelfAttention"] diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index a9b8b487e0..0887a329e9 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -228,3 +228,14 @@ def unshape(x): def update_dropout(self, dropout): self.dropout.p = dropout + + +class MultiHeadedPseudoSelfAttention(MultiHeadedAttention): + def __init__(self, head_count, model_dim, dropout=0.1, + max_relative_positions=0): + super().__init__(head_count, model_dim, dropout=0.1, + max_relative_positions=0) + self.linear_keys = nn.Linear(model_dim, + head_count * self.dim_per_head, 2) + self.linear_values = nn.Linear(model_dim, + head_count * self.dim_per_head, 2) diff --git a/onmt/opts.py b/onmt/opts.py index 6872e351a4..fd09ac4647 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -271,7 +271,8 @@ def model_opts(parser): "are experimental. Options are " "[rnn|brnn|ggnn|mean|transformer|cnn|transformer_lm].") group.add('--decoder_type', '-decoder_type', type=str, default='rnn', - choices=['rnn', 'transformer', 'cnn', 'transformer_lm'], + choices=['rnn', 'transformer', 'cnn', 'transformer_lm', + 'transformer_lm_psa'], help="Type of decoder layer to use. Non-RNN layers " "are experimental. Options are " "[rnn|transformer|cnn|transformer].") From 96cc29be6fad68ba91223924d21c3f84c705862b Mon Sep 17 00:00:00 2001 From: Valentin Berkes Date: Tue, 4 May 2021 11:01:34 +0000 Subject: [PATCH 2/5] implement tests --- .github/workflows/push.yml | 12 +++ onmt/decoders/transformer.py | 32 ++++-- onmt/modules/__init__.py | 36 +++++-- onmt/modules/multi_headed_attn.py | 114 +++++++++++--------- onmt/tests/pull_request_chk.sh | 14 +++ onmt/tests/test_pseudo_self_attention.py | 131 +++++++++++++++++++++++ 6 files changed, 272 insertions(+), 67 deletions(-) create mode 100644 onmt/tests/test_pseudo_self_attention.py diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index caa95631f7..96f2472eed 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -96,6 +96,18 @@ jobs: -word_vec_size 5 -report_every 5 \ -coverage_attn true -lambda_coverage 0.1 \ -rnn_size 10 -train_steps 10 + - name: Test Transformer training with pseudo self attention + run : | + python train.py \ + -config data/align_data.yaml \ + -src_vocab /tmp/onmt.vocab.src \ + -tgt_vocab /tmp/onmt.vocab.tgt \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -max_generator_batches 0 \ + -encoder_type transformer -decoder_type transformer_lm_psa \ + -layers 4 -word_vec_size 16 -rnn_size 16 -heads 2 -transformer_ff 64 \ + -report_every 5 -train_steps 10 - name: Test Transformer training with align run: | python train.py \ diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index c882a44802..f930278305 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -7,7 +7,8 @@ import torch.nn as nn from onmt.decoders.decoder import DecoderBase -from onmt.modules import MultiHeadedAttention, AverageAttention, MultiHeadedPseudoSelfAttention +from onmt.modules import MultiHeadedAttention, AverageAttention +from onmt.modules import MultiHeadedPseudoSelfAttention from onmt.modules.position_ffn import PositionwiseFeedForward from onmt.modules.position_ffn import ActivationFunction from onmt.utils.misc import sequence_mask @@ -126,7 +127,8 @@ def update_dropout(self, dropout, attention_dropout): def _forward(self, *args, **kwargs): raise NotImplementedError - def _compute_dec_mask(self, tgt_pad_mask, future): + @staticmethod + def _compute_dec_mask(tgt_pad_mask, future): tgt_len = tgt_pad_mask.size(-1) if not future: # apply future_mask, result mask in (B, T, T) future_mask = torch.ones( @@ -700,7 +702,10 @@ def _init_cache(self, memory_bank=None): raise NotImplementedError self.state["cache"]["layer_{}".format(i)] = layer_cache -class TransformerLMPseudoSelfAttentionDecoderLayer(TransformerDecoderLayerBase): + +class TransformerLMPseudoSelfAttentionDecoderLayer( + TransformerDecoderLayerBase +): """Transformer Decoder only layer block in GPT style. .. mermaid:: @@ -720,10 +725,14 @@ class TransformerLMPseudoSelfAttentionDecoderLayer(TransformerDecoderLayerBase): """ def _forward( - self, inputs, - src_memory_bank, - src_pad_mask, - tgt_pad_mask, layer_cache=None, step=None, future=False + self, + inputs, + src_memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=None, + step=None, + future=False, ): """A naive forward pass for transformer decoder. @@ -751,8 +760,12 @@ def _forward( inputs_norm = self.layer_norm_1(inputs) - pseudo_key_value = torch.cat([src_memory_bank.transpose(0, 1), inputs], axis=1) - pseudo_mask = torch.cat([src_pad_mask.repeat(1, dec_mask.size(1), 1), dec_mask], axis=-1) + pseudo_key_value = torch.cat( + [src_memory_bank.transpose(0, 1), inputs], + axis=1, + ) + pseudo_mask = torch.cat([src_pad_mask.repeat(1, dec_mask.size(1), 1), + dec_mask], axis=-1) query, attns = self.self_attn( pseudo_key_value, pseudo_key_value, @@ -768,6 +781,7 @@ def _forward( return output_feedforward, attns + class TransformerLMPseudoSelfAttentionDecoder(TransformerDecoderBase): """The Transformer decoder from GPT-2 with pseudo self attention diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py index 83ccf0d203..bbf835d46f 100644 --- a/onmt/modules/__init__.py +++ b/onmt/modules/__init__.py @@ -3,16 +3,34 @@ from onmt.modules.gate import context_gate_factory, ContextGate from onmt.modules.global_attention import GlobalAttention from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention -from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ - CopyGeneratorLossCompute, CopyGeneratorLMLossCompute -from onmt.modules.multi_headed_attn import MultiHeadedAttention, MultiHeadedPseudoSelfAttention +from onmt.modules.copy_generator import ( + CopyGenerator, + CopyGeneratorLoss, + CopyGeneratorLossCompute, + CopyGeneratorLMLossCompute, +) +from onmt.modules.multi_headed_attn import ( + MultiHeadedAttention, + MultiHeadedPseudoSelfAttention, +) from onmt.modules.embeddings import Embeddings, PositionalEncoding from onmt.modules.weight_norm import WeightNormConv2d from onmt.modules.average_attn import AverageAttention -__all__ = ["Elementwise", "context_gate_factory", "ContextGate", - "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", - "CopyGeneratorLoss", "CopyGeneratorLossCompute", - "MultiHeadedAttention", "Embeddings", "PositionalEncoding", - "WeightNormConv2d", "AverageAttention", - "CopyGeneratorLMLossCompute", "MultiHeadedPseudoSelfAttention"] +__all__ = [ + "Elementwise", + "context_gate_factory", + "ContextGate", + "GlobalAttention", + "ConvMultiStepAttention", + "CopyGenerator", + "CopyGeneratorLoss", + "CopyGeneratorLossCompute", + "MultiHeadedAttention", + "Embeddings", + "PositionalEncoding", + "WeightNormConv2d", + "AverageAttention", + "CopyGeneratorLMLossCompute", + "MultiHeadedPseudoSelfAttention", +] diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 0887a329e9..09f04cf3e9 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from onmt.utils.misc import generate_relative_positions_matrix,\ - relative_matmul +from onmt.utils.misc import generate_relative_positions_matrix, relative_matmul + # from onmt.utils.misc import aeq @@ -48,8 +48,9 @@ class MultiHeadedAttention(nn.Module): dropout (float): dropout parameter """ - def __init__(self, head_count, model_dim, dropout=0.1, - max_relative_positions=0): + def __init__( + self, head_count, model_dim, dropout=0.1, max_relative_positions=0 + ): assert model_dim % head_count == 0 self.dim_per_head = model_dim // head_count self.model_dim = model_dim @@ -57,12 +58,13 @@ def __init__(self, head_count, model_dim, dropout=0.1, super(MultiHeadedAttention, self).__init__() self.head_count = head_count - self.linear_keys = nn.Linear(model_dim, - head_count * self.dim_per_head) - self.linear_values = nn.Linear(model_dim, - head_count * self.dim_per_head) - self.linear_query = nn.Linear(model_dim, - head_count * self.dim_per_head) + self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) + self.linear_values = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.linear_query = nn.Linear( + model_dim, head_count * self.dim_per_head + ) self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.final_linear = nn.Linear(model_dim, model_dim) @@ -72,10 +74,12 @@ def __init__(self, head_count, model_dim, dropout=0.1, if max_relative_positions > 0: vocab_size = max_relative_positions * 2 + 1 self.relative_positions_embeddings = nn.Embedding( - vocab_size, self.dim_per_head) + vocab_size, self.dim_per_head + ) - def forward(self, key, value, query, mask=None, - layer_cache=None, attn_type=None): + def forward( + self, key, value, query, mask=None, layer_cache=None, attn_type=None + ): """ Compute the context vector and the attention vectors. @@ -120,42 +124,49 @@ def forward(self, key, value, query, mask=None, def shape(x): """Projection.""" - return x.view(batch_size, -1, head_count, dim_per_head) \ - .transpose(1, 2) + return x.view(batch_size, -1, head_count, dim_per_head).transpose( + 1, 2 + ) def unshape(x): """Compute context.""" - return x.transpose(1, 2).contiguous() \ - .view(batch_size, -1, head_count * dim_per_head) + return ( + x.transpose(1, 2) + .contiguous() + .view(batch_size, -1, head_count * dim_per_head) + ) # 1) Project key, value, and query. if layer_cache is not None: if attn_type == "self": - query, key, value = self.linear_query(query),\ - self.linear_keys(query),\ - self.linear_values(query) + query, key, value = ( + self.linear_query(query), + self.linear_keys(query), + self.linear_values(query), + ) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: - key = torch.cat( - (layer_cache["self_keys"], key), - dim=2) + key = torch.cat((layer_cache["self_keys"], key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat( - (layer_cache["self_values"], value), - dim=2) + (layer_cache["self_values"], value), dim=2 + ) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif attn_type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: - key, value = self.linear_keys(key),\ - self.linear_values(value) + key, value = self.linear_keys(key), self.linear_values( + value + ) key = shape(key) value = shape(value) else: - key, value = layer_cache["memory_keys"],\ - layer_cache["memory_values"] + key, value = ( + layer_cache["memory_keys"], + layer_cache["memory_values"], + ) layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: @@ -169,14 +180,18 @@ def unshape(x): key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( - key_len, self.max_relative_positions, - cache=True if layer_cache is not None else False) + key_len, + self.max_relative_positions, + cache=True if layer_cache is not None else False, + ) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( - relative_positions_matrix.to(key.device)) + relative_positions_matrix.to(key.device) + ) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( - relative_positions_matrix.to(key.device)) + relative_positions_matrix.to(key.device) + ) query = shape(query) @@ -201,14 +216,13 @@ def unshape(x): # 3) Apply attention dropout and compute context vectors. attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) - context_original = torch.matmul(drop_attn, value) if self.max_relative_positions > 0 and attn_type == "self": - context = unshape(context_original - + relative_matmul(drop_attn, - relations_values, - False)) + context = unshape( + context_original + + relative_matmul(drop_attn, relations_values, False) + ) else: context = unshape(context_original) @@ -220,9 +234,7 @@ def unshape(x): # aeq(d, d_) # Return multi-head attn - attns = attn \ - .view(batch_size, head_count, - query_len, key_len) + attns = attn.view(batch_size, head_count, query_len, key_len) return output, attns @@ -231,11 +243,15 @@ def update_dropout(self, dropout): class MultiHeadedPseudoSelfAttention(MultiHeadedAttention): - def __init__(self, head_count, model_dim, dropout=0.1, - max_relative_positions=0): - super().__init__(head_count, model_dim, dropout=0.1, - max_relative_positions=0) - self.linear_keys = nn.Linear(model_dim, - head_count * self.dim_per_head, 2) - self.linear_values = nn.Linear(model_dim, - head_count * self.dim_per_head, 2) + def __init__( + self, head_count, model_dim, dropout=0.1, max_relative_positions=0 + ): + super().__init__( + head_count, model_dim, dropout, max_relative_positions + ) + self.linear_keys = nn.Linear( + model_dim, head_count * self.dim_per_head, 2 + ) + self.linear_values = nn.Linear( + model_dim, head_count * self.dim_per_head, 2 + ) diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh index b282cc7f1e..83a65fd8e3 100755 --- a/onmt/tests/pull_request_chk.sh +++ b/onmt/tests/pull_request_chk.sh @@ -117,6 +117,20 @@ ${PYTHON} onmt/bin/train.py \ [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} +echo -n " [+] Testing NMT training w/ pseudo self attention..." +${PYTHON} onmt/bin/train.py \ + -config ${DATA_DIR}/align_data.yaml \ + -src_vocab $TMP_OUT_DIR/onmt.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -max_generator_batches 0 \ + -encoder_type transformer -decoder_type transformer_lm_psa \ + -layers 4 -word_vec_size 16 -rnn_size 16 -heads 2 -transformer_ff 64 \ + -report_every 5 -train_steps 10 >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + echo -n " [+] Testing NMT training w/ align..." ${PYTHON} onmt/bin/train.py \ -config ${DATA_DIR}/align_data.yaml \ diff --git a/onmt/tests/test_pseudo_self_attention.py b/onmt/tests/test_pseudo_self_attention.py new file mode 100644 index 0000000000..c5581d49b5 --- /dev/null +++ b/onmt/tests/test_pseudo_self_attention.py @@ -0,0 +1,131 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.modules import ( + MultiHeadedAttention, + MultiHeadedPseudoSelfAttention, +) +from onmt.utils.misc import sequence_mask +from onmt.decoders.transformer import TransformerDecoderLayerBase + + +class TestPseudoSelfAttention(unittest.TestCase): + @classmethod + def setUpClass(cls): + max_relative_positions = 0 + heads = 2 + cls.d_model = 16 + cls.pseudo_self_attention = MultiHeadedPseudoSelfAttention( + heads, + cls.d_model, + dropout=0, + max_relative_positions=max_relative_positions, + ) + cls.self_attention = MultiHeadedAttention( + heads, + cls.d_model, + dropout=0, + max_relative_positions=max_relative_positions, + ) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_keys.weight, 1 + ) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_values.weight, 1 + ) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_query.weight, 1 + ) + torch.nn.init.constant_(cls.self_attention.linear_keys.weight, 1) + torch.nn.init.constant_(cls.self_attention.linear_values.weight, 1) + torch.nn.init.constant_(cls.self_attention.linear_query.weight, 1) + + torch.nn.init.constant_(cls.pseudo_self_attention.linear_keys.bias, 0) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_values.bias, 0 + ) + torch.nn.init.constant_(cls.pseudo_self_attention.linear_query.bias, 0) + torch.nn.init.constant_(cls.self_attention.linear_keys.bias, 0) + torch.nn.init.constant_(cls.self_attention.linear_values.bias, 0) + torch.nn.init.constant_(cls.self_attention.linear_query.bias, 0) + + torch.nn.init.constant_( + cls.pseudo_self_attention.final_linear.weight, 1 + ) + torch.nn.init.constant_(cls.pseudo_self_attention.final_linear.bias, 1) + torch.nn.init.constant_(cls.self_attention.final_linear.weight, 1) + torch.nn.init.constant_(cls.self_attention.final_linear.bias, 1) + + def test_pseudo_self_attention_is_self_attention_without_encoding(self): + X = torch.zeros( + (3, 5, self.d_model) + ) # (batch_size, seq_len, dim_model) + Y = torch.ones((3, 8, self.d_model)) + pseudo_key_value = torch.cat([X, Y], axis=1) + + output_self_attn, _ = self.self_attention(Y, Y, Y, attn_type="self") + output_pseudo_self_attn, _ = self.pseudo_self_attention( + pseudo_key_value, pseudo_key_value, Y, attn_type="self" + ) + self.assertTrue(output_self_attn.equal(output_pseudo_self_attn)) + + def test_masked_pseudo_self_attention_equals_premasked_encoder(self): + X = 0.3 * torch.ones( + (4, 5, self.d_model) + ) # (batch_size, seq_len, dim_model) + X[0, 4:, :] = 1000 + X[1, 3:, :] = 1000 + + X_premasked = 0.3 * torch.ones( + (4, 5, self.d_model) + ) # (batch_size, seq_len, dim_model) + X_premasked[0, 4:, :] = 0 + X_premasked[1, 3:, :] = 0 + + Y = torch.ones((4, 8, self.d_model)) + + pseudo_key_value = torch.cat([X, Y], axis=1) + masked_pseudo_key_value = torch.cat([X_premasked, Y], axis=1) + + src_pad_mask = ~sequence_mask(torch.tensor([4, 3, 1, 5]), 5).unsqueeze( + 1 + ) + no_mask_src_pad_mask = ~sequence_mask( + torch.tensor([5, 5, 5, 5]), 5 + ).unsqueeze(1) + tgt_pad_mask = ~sequence_mask(torch.tensor([8, 3, 8, 1]), 8).unsqueeze( + 1 + ) + + dec_mask = TransformerDecoderLayerBase._compute_dec_mask( + tgt_pad_mask, future=False + ) + + pseudo_mask = torch.cat( + [src_pad_mask.repeat(1, dec_mask.size(-1), 1), dec_mask], axis=-1 + ) + no_mask_pseudo_mask = torch.cat( + [no_mask_src_pad_mask.repeat(1, dec_mask.size(-1), 1), dec_mask], + axis=-1, + ) + + output, _ = self.pseudo_self_attention( + pseudo_key_value, + pseudo_key_value, + Y, + mask=pseudo_mask, + attn_type="self", + ) + + output_masked, _ = self.pseudo_self_attention( + masked_pseudo_key_value, + masked_pseudo_key_value, + Y, + mask=no_mask_pseudo_mask, + attn_type="self", + ) + + self.assertTrue(output.equal(output_masked)) From e0571302159e85505c4102f096b82ebb37a19a84 Mon Sep 17 00:00:00 2001 From: Valentin Berkes Date: Tue, 4 May 2021 15:43:12 +0000 Subject: [PATCH 3/5] fix pseudo self attention --- onmt/decoders/transformer.py | 26 ++-- onmt/modules/multi_headed_attn.py | 160 +++++++++++++++++++++-- onmt/tests/test_pseudo_self_attention.py | 16 +-- 3 files changed, 169 insertions(+), 33 deletions(-) diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index f930278305..96d9861eef 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -76,9 +76,9 @@ def __init__( dropout=attention_dropout, max_relative_positions=max_relative_positions, ) - self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, - pos_ffn_activation_fn - ) + self.feed_forward = PositionwiseFeedForward( + d_model, d_ff, dropout, pos_ffn_activation_fn + ) self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) self.drop = nn.Dropout(dropout) self.full_context_alignment = full_context_alignment @@ -760,20 +760,16 @@ def _forward( inputs_norm = self.layer_norm_1(inputs) - pseudo_key_value = torch.cat( - [src_memory_bank.transpose(0, 1), inputs], - axis=1, + pseudo_mask = torch.cat( + [src_pad_mask.repeat(1, dec_mask.size(1), 1), dec_mask], axis=-1 ) - pseudo_mask = torch.cat([src_pad_mask.repeat(1, dec_mask.size(1), 1), - dec_mask], axis=-1) query, attns = self.self_attn( - pseudo_key_value, - pseudo_key_value, - inputs_norm, - mask=pseudo_mask, - layer_cache=layer_cache, - attn_type="self", - ) + src_memory_bank.transpose(0, 1), + inputs_norm, + mask=pseudo_mask, + layer_cache=layer_cache, + attn_type="self", + ) output = self.drop(query) + inputs diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 09f04cf3e9..59effad0d2 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -242,16 +242,160 @@ def update_dropout(self, dropout): self.dropout.p = dropout -class MultiHeadedPseudoSelfAttention(MultiHeadedAttention): +class MultiHeadedPseudoSelfAttention(nn.Module): def __init__( self, head_count, model_dim, dropout=0.1, max_relative_positions=0 ): - super().__init__( - head_count, model_dim, dropout, max_relative_positions - ) - self.linear_keys = nn.Linear( - model_dim, head_count * self.dim_per_head, 2 - ) + assert model_dim % head_count == 0 + self.dim_per_head = model_dim // head_count + self.model_dim = model_dim + + super(MultiHeadedPseudoSelfAttention, self).__init__() + self.head_count = head_count + + self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) self.linear_values = nn.Linear( - model_dim, head_count * self.dim_per_head, 2 + model_dim, head_count * self.dim_per_head ) + self.linear_keys_src = nn.Linear(model_dim, head_count * self.dim_per_head) + self.linear_values_src = nn.Linear( model_dim, head_count * self.dim_per_head) + self.linear_query = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.final_linear = nn.Linear(model_dim, model_dim) + + self.max_relative_positions = max_relative_positions + + if max_relative_positions > 0: + vocab_size = max_relative_positions * 2 + 1 + self.relative_positions_embeddings = nn.Embedding( + vocab_size, self.dim_per_head + ) + + def forward(self, src, tgt, mask=None, layer_cache=None, attn_type=None): + batch_size = tgt.size(0) + dim_per_head = self.dim_per_head + head_count = self.head_count + key_len = tgt.size(1) + query_len = tgt.size(1) + + def shape(x): + """Projection.""" + return x.view(batch_size, -1, head_count, dim_per_head).transpose( + 1, 2 + ) + + def unshape(x): + """Compute context.""" + return ( + x.transpose(1, 2) + .contiguous() + .view(batch_size, -1, head_count * dim_per_head) + ) + + # 1) Project key, value, and query. + # if layer_cache is not None: + # if attn_type == "self": + # query, key, value = ( + # self.linear_query(query), + # self.linear_keys(query), + # self.linear_values(query), + # ) + # key = shape(key) + # value = shape(value) + # if layer_cache["self_keys"] is not None: + # key = torch.cat((layer_cache["self_keys"], key), dim=2) + # if layer_cache["self_values"] is not None: + # value = torch.cat( + # (layer_cache["self_values"], value), dim=2 + # ) + # layer_cache["self_keys"] = key + # layer_cache["self_values"] = value + # elif attn_type == "context": + # query = self.linear_query(query) + # if layer_cache["memory_keys"] is None: + # key, value = self.linear_keys(key), self.linear_values( + # value + # ) + # key = shape(key) + # value = shape(value) + # else: + # key, value = ( + # layer_cache["memory_keys"], + # layer_cache["memory_values"], + # ) + # layer_cache["memory_keys"] = key + # layer_cache["memory_values"] = value + # else: + key = torch.cat((self.linear_keys_src(src), self.linear_keys(tgt)),dim=1) + value = torch.cat((self.linear_values_src(src), self.linear_values(tgt)), dim=1) + query = self.linear_query(tgt) + key = shape(key) + value = shape(value) + + if self.max_relative_positions > 0 and attn_type == "self": + key_len = key.size(2) + # 1 or key_len x key_len + relative_positions_matrix = generate_relative_positions_matrix( + key_len, + self.max_relative_positions, + cache=True if layer_cache is not None else False, + ) + # 1 or key_len x key_len x dim_per_head + relations_keys = self.relative_positions_embeddings( + relative_positions_matrix.to(key.device) + ) + # 1 or key_len x key_len x dim_per_head + relations_values = self.relative_positions_embeddings( + relative_positions_matrix.to(key.device) + ) + + query = shape(query) + + key_len = key.size(2) + query_len = query.size(2) + + # 2) Calculate and scale scores. + query = query / math.sqrt(dim_per_head) + # batch x num_heads x query_len x key_len + query_key = torch.matmul(query, key.transpose(2, 3)) + + if self.max_relative_positions > 0 and attn_type == "self": + scores = query_key + relative_matmul(query, relations_keys, True) + else: + scores = query_key + scores = scores.float() + + if mask is not None: + mask = mask.unsqueeze(1) # [B, 1, 1, T_values] + scores = scores.masked_fill(mask, -1e18) + + # 3) Apply attention dropout and compute context vectors. + attn = self.softmax(scores).to(query.dtype) + drop_attn = self.dropout(attn) + context_original = torch.matmul(drop_attn, value) + + if self.max_relative_positions > 0 and attn_type == "self": + context = unshape( + context_original + + relative_matmul(drop_attn, relations_values, False) + ) + else: + context = unshape(context_original) + + output = self.final_linear(context) + # CHECK + # batch_, q_len_, d_ = output.size() + # aeq(q_len, q_len_) + # aeq(batch, batch_) + # aeq(d, d_) + + # Return multi-head attn + attns = attn.view(batch_size, head_count, query_len, key_len) + + return output, attns + + def update_dropout(self, dropout): + self.dropout.p = dropout diff --git a/onmt/tests/test_pseudo_self_attention.py b/onmt/tests/test_pseudo_self_attention.py index c5581d49b5..8bf7a22d13 100644 --- a/onmt/tests/test_pseudo_self_attention.py +++ b/onmt/tests/test_pseudo_self_attention.py @@ -59,17 +59,16 @@ def setUpClass(cls): torch.nn.init.constant_(cls.self_attention.final_linear.weight, 1) torch.nn.init.constant_(cls.self_attention.final_linear.bias, 1) - def test_pseudo_self_attention_is_self_attention_without_encoding(self): + def test_pseudo_self_attention_equals_self_attention_without_encoding( + self, + ): X = torch.zeros( (3, 5, self.d_model) ) # (batch_size, seq_len, dim_model) Y = torch.ones((3, 8, self.d_model)) - pseudo_key_value = torch.cat([X, Y], axis=1) output_self_attn, _ = self.self_attention(Y, Y, Y, attn_type="self") - output_pseudo_self_attn, _ = self.pseudo_self_attention( - pseudo_key_value, pseudo_key_value, Y, attn_type="self" - ) + output_pseudo_self_attn, _ = self.pseudo_self_attention(X, Y) self.assertTrue(output_self_attn.equal(output_pseudo_self_attn)) def test_masked_pseudo_self_attention_equals_premasked_encoder(self): @@ -87,7 +86,6 @@ def test_masked_pseudo_self_attention_equals_premasked_encoder(self): Y = torch.ones((4, 8, self.d_model)) - pseudo_key_value = torch.cat([X, Y], axis=1) masked_pseudo_key_value = torch.cat([X_premasked, Y], axis=1) src_pad_mask = ~sequence_mask(torch.tensor([4, 3, 1, 5]), 5).unsqueeze( @@ -113,16 +111,14 @@ def test_masked_pseudo_self_attention_equals_premasked_encoder(self): ) output, _ = self.pseudo_self_attention( - pseudo_key_value, - pseudo_key_value, + X, Y, mask=pseudo_mask, attn_type="self", ) output_masked, _ = self.pseudo_self_attention( - masked_pseudo_key_value, - masked_pseudo_key_value, + X_premasked, Y, mask=no_mask_pseudo_mask, attn_type="self", From 318ed17e4df6b222d5de019732809d1b5a6a61c8 Mon Sep 17 00:00:00 2001 From: Valentin Berkes Date: Mon, 10 May 2021 10:11:35 +0000 Subject: [PATCH 4/5] fix PSA + write tests --- onmt/decoders/transformer.py | 30 +++++--- onmt/modules/multi_headed_attn.py | 84 +++++++++++---------- onmt/tests/test_base_transformer.py | 83 +++++++++++++++++++++ onmt/tests/test_lm_transformer_decoder.py | 70 +++++++++++++++++ onmt/tests/test_psa_transformer_decoder.py | 87 ++++++++++++++++++++++ onmt/tests/test_pseudo_self_attention.py | 4 +- 6 files changed, 305 insertions(+), 53 deletions(-) create mode 100644 onmt/tests/test_base_transformer.py create mode 100644 onmt/tests/test_lm_transformer_decoder.py create mode 100644 onmt/tests/test_psa_transformer_decoder.py diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index 96d9861eef..b7c7f4ef90 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -261,7 +261,7 @@ def _forward( """ dec_mask = None - if inputs.size(1) > 1: + if step is None: # masking is necessary when sequence length is greater than one dec_mask = self._compute_dec_mask(tgt_pad_mask, future) @@ -556,7 +556,7 @@ def _forward( """ dec_mask = None - if inputs.size(1) > 1: + if step is None: # masking is necessary when sequence length is greater than one dec_mask = self._compute_dec_mask(tgt_pad_mask, future) @@ -753,16 +753,27 @@ def _forward( """ dec_mask = None - - if inputs.size(1) > 1: + pseudo_mask = None + if step is None: # masking is necessary when sequence length is greater than one dec_mask = self._compute_dec_mask(tgt_pad_mask, future) - + pseudo_mask = torch.cat( + [src_pad_mask.repeat(1, inputs.size(1), 1), dec_mask], axis=-1 + ) + else: + pseudo_mask = torch.cat( + ( + src_pad_mask.repeat(1, inputs.size(1), 1), + torch.zeros( + (inputs.size(0), inputs.size(1), step + 1), + dtype=torch.bool, + device=src_pad_mask.device, + ), + ), + axis=-1, + ) inputs_norm = self.layer_norm_1(inputs) - pseudo_mask = torch.cat( - [src_pad_mask.repeat(1, dec_mask.size(1), 1), dec_mask], axis=-1 - ) query, attns = self.self_attn( src_memory_bank.transpose(0, 1), inputs_norm, @@ -904,7 +915,8 @@ def _init_cache(self, memory_bank=None): self.state["cache"] = {} for i, layer in enumerate(self.transformer_layers): - layer_cache = {"self_keys": None, "self_values": None} + layer_cache = {"self_keys": None, "self_values": None, + "src_keys": None, "src_values": None} if isinstance(layer.self_attn, AverageAttention): raise NotImplementedError self.state["cache"]["layer_{}".format(i)] = layer_cache diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 59effad0d2..0faf74ec0b 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -257,8 +257,12 @@ def __init__( self.linear_values = nn.Linear( model_dim, head_count * self.dim_per_head ) - self.linear_keys_src = nn.Linear(model_dim, head_count * self.dim_per_head) - self.linear_values_src = nn.Linear( model_dim, head_count * self.dim_per_head) + self.linear_keys_src = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.linear_values_src = nn.Linear( + model_dim, head_count * self.dim_per_head + ) self.linear_query = nn.Linear( model_dim, head_count * self.dim_per_head ) @@ -295,45 +299,43 @@ def unshape(x): .view(batch_size, -1, head_count * dim_per_head) ) - # 1) Project key, value, and query. - # if layer_cache is not None: - # if attn_type == "self": - # query, key, value = ( - # self.linear_query(query), - # self.linear_keys(query), - # self.linear_values(query), - # ) - # key = shape(key) - # value = shape(value) - # if layer_cache["self_keys"] is not None: - # key = torch.cat((layer_cache["self_keys"], key), dim=2) - # if layer_cache["self_values"] is not None: - # value = torch.cat( - # (layer_cache["self_values"], value), dim=2 - # ) - # layer_cache["self_keys"] = key - # layer_cache["self_values"] = value - # elif attn_type == "context": - # query = self.linear_query(query) - # if layer_cache["memory_keys"] is None: - # key, value = self.linear_keys(key), self.linear_values( - # value - # ) - # key = shape(key) - # value = shape(value) - # else: - # key, value = ( - # layer_cache["memory_keys"], - # layer_cache["memory_values"], - # ) - # layer_cache["memory_keys"] = key - # layer_cache["memory_values"] = value - # else: - key = torch.cat((self.linear_keys_src(src), self.linear_keys(tgt)),dim=1) - value = torch.cat((self.linear_values_src(src), self.linear_values(tgt)), dim=1) - query = self.linear_query(tgt) - key = shape(key) - value = shape(value) + if layer_cache is not None: + query, self_key, self_value = ( + self.linear_query(tgt), + self.linear_keys(tgt), + self.linear_values(tgt), + ) + self_key = shape(self_key) + self_value = shape(self_value) + if layer_cache["self_keys"] is not None: + self_key = torch.cat( + (layer_cache["self_keys"], self_key), dim=2 + ) + if layer_cache["self_values"] is not None: + self_value = torch.cat( + (layer_cache["self_values"], self_value), dim=2 + ) + if layer_cache["src_keys"] is None: + layer_cache["src_keys"] = shape(self.linear_keys_src(src)) + layer_cache["src_values"] = shape(self.linear_values_src(src)) + layer_cache["self_keys"] = self_key + layer_cache["self_values"] = self_value + key = torch.cat( + (layer_cache["src_keys"], layer_cache["self_keys"]), dim=2 + ) + value = torch.cat( + (layer_cache["src_values"], layer_cache["self_values"]), dim=2 + ) + else: + key = torch.cat( + (self.linear_keys_src(src), self.linear_keys(tgt)), dim=1 + ) + value = torch.cat( + (self.linear_values_src(src), self.linear_values(tgt)), dim=1 + ) + query = self.linear_query(tgt) + key = shape(key) + value = shape(value) if self.max_relative_positions > 0 and attn_type == "self": key_len = key.size(2) diff --git a/onmt/tests/test_base_transformer.py b/onmt/tests/test_base_transformer.py new file mode 100644 index 0000000000..0c0ca87f73 --- /dev/null +++ b/onmt/tests/test_base_transformer.py @@ -0,0 +1,83 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.decoders.transformer import TransformerDecoder +from onmt.modules import Embeddings +from onmt.modules.position_ffn import ActivationFunction + + +class TestTransformerDecoder(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(42) + emb = Embeddings( + word_vec_size=100, + position_encoding=True, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + dropout=0, + word_padding_idx=1, + feat_padding_idx=[], + word_vocab_size=100, + feat_vocab_sizes=[], + sparse=False, + freeze_word_vecs=False, + ) + cls.transformer_decoder = TransformerDecoder( + num_layers=2, + d_model=100, + heads=2, + d_ff=100, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=0, + attention_dropout=0, + embeddings=emb, + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=None, + alignment_layer=None, + alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ) + cls.memory_bank = torch.rand([58, 2, 100]) + cls.tgt = torch.randint(3, 99, [12, 2, 1]) + cls.src = torch.randint(3, 99, [58, 2, 1]) + cls.memory_lengths = torch.tensor([58, 58]) + cls.transformer_decoder.init_state( + cls.src, cls.memory_bank, cls.memory_bank + ) + + def test_transformer_caching_equals_no_caching( + self, + ): + dec_outs, _ = self.transformer_decoder( + self.tgt[1:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=None, + ) + dec_outs_step_0, _ = self.transformer_decoder( + self.tgt[1:2], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=0, + ) + dec_outs_step_1, _ = self.transformer_decoder( + self.tgt[2:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=1, + ) + # randomness might cause failing (seed is set to avoid that) + # small differences are expected due to masking with huge negative + # float but not infinite + self.assertTrue(dec_outs_step_1.allclose(dec_outs[1:])) + + +if __name__ == "__main__": + unittest.main() diff --git a/onmt/tests/test_lm_transformer_decoder.py b/onmt/tests/test_lm_transformer_decoder.py new file mode 100644 index 0000000000..a8794006cc --- /dev/null +++ b/onmt/tests/test_lm_transformer_decoder.py @@ -0,0 +1,70 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.decoders.transformer import TransformerLMDecoder +from onmt.modules import Embeddings +from onmt.modules.position_ffn import ActivationFunction + + +class TestLMTransformerDecoder(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(42) + emb = Embeddings( + word_vec_size=100, + position_encoding=True, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + dropout=0, + word_padding_idx=1, + feat_padding_idx=[], + word_vocab_size=100, + feat_vocab_sizes=[], + sparse=False, + freeze_word_vecs=False, + ) + cls.lm_transformer_decoder = TransformerLMDecoder( + num_layers=2, + d_model=100, + heads=2, + d_ff=100, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=0, + attention_dropout=0, + embeddings=emb, + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=None, + alignment_layer=None, + alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ) + cls.tgt = torch.randint(3, 99, [12, 3, 1]) + cls.lm_transformer_decoder.init_state(None, None, None) + + def test_lm_transformer_caching_equals_no_caching( + self, + ): + dec_outs, _ = self.lm_transformer_decoder( + self.tgt[1:3], None, memory_lengths=None, step=None + ) + dec_outs_step_0, _ = self.lm_transformer_decoder( + self.tgt[1:2], None, memory_lengths=None, step=0 + ) + dec_outs_step_1, _ = self.lm_transformer_decoder( + self.tgt[2:3], None, memory_lengths=None, step=1 + ) + + # randomness might cause failing (seed is set to avoid that) + # small differences are expected due to masking with huge negative + # float but not infinite + self.assertTrue(dec_outs_step_1.allclose(dec_outs[1:])) + + +if __name__ == "__main__": + unittest.main() diff --git a/onmt/tests/test_psa_transformer_decoder.py b/onmt/tests/test_psa_transformer_decoder.py new file mode 100644 index 0000000000..10c093936a --- /dev/null +++ b/onmt/tests/test_psa_transformer_decoder.py @@ -0,0 +1,87 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.decoders.transformer import TransformerLMPseudoSelfAttentionDecoder +from onmt.modules import Embeddings +from onmt.modules.position_ffn import ActivationFunction + + +class TestPSADecoder(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(42) + emb = Embeddings( + word_vec_size=100, + position_encoding=True, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + dropout=0, + word_padding_idx=1, + feat_padding_idx=[], + word_vocab_size=100, + feat_vocab_sizes=[], + sparse=False, + freeze_word_vecs=False, + ) + cls.psa_transformer_decoder = TransformerLMPseudoSelfAttentionDecoder( + num_layers=2, + d_model=100, + heads=2, + d_ff=100, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=0, + attention_dropout=0, + embeddings=emb, + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=None, + alignment_layer=None, + alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ) + batch_size = 3 + src_len = 58 + tgt_len = 12 + cls.memory_bank = torch.rand([src_len, batch_size, 100]) + cls.tgt = torch.randint(3, 99, [tgt_len, batch_size, 1]) + cls.src = torch.randint(3, 99, [src_len, batch_size, 1]) + cls.memory_lengths = torch.tensor([src_len] * batch_size) + cls.memory_lengths[0] -= 3 + cls.psa_transformer_decoder.init_state( + cls.src, cls.memory_bank, cls.memory_bank + ) + + def test_psa_transformer_caching_equals_no_caching( + self, + ): + dec_outs, _ = self.psa_transformer_decoder( + self.tgt[1:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=None, + ) + dec_outs_step_0, _ = self.psa_transformer_decoder( + self.tgt[1:2], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=0, + ) + dec_outs_step_1, _ = self.psa_transformer_decoder( + self.tgt[2:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=1, + ) + # randomness might cause failing (seed is set to avoid that) + # small differences are expected due to masking with huge negative + # float but not infinite + self.assertTrue(dec_outs_step_1.allclose(dec_outs[1:])) + + +if __name__ == "__main__": + unittest.main() diff --git a/onmt/tests/test_pseudo_self_attention.py b/onmt/tests/test_pseudo_self_attention.py index 8bf7a22d13..0aab256119 100644 --- a/onmt/tests/test_pseudo_self_attention.py +++ b/onmt/tests/test_pseudo_self_attention.py @@ -12,7 +12,7 @@ from onmt.decoders.transformer import TransformerDecoderLayerBase -class TestPseudoSelfAttention(unittest.TestCase): +class TestMultiHeadedPseudoSelfAttention(unittest.TestCase): @classmethod def setUpClass(cls): max_relative_positions = 0 @@ -86,8 +86,6 @@ def test_masked_pseudo_self_attention_equals_premasked_encoder(self): Y = torch.ones((4, 8, self.d_model)) - masked_pseudo_key_value = torch.cat([X_premasked, Y], axis=1) - src_pad_mask = ~sequence_mask(torch.tensor([4, 3, 1, 5]), 5).unsqueeze( 1 ) From 628b47261788591d197cf14bb0cbb717e4585ffc Mon Sep 17 00:00:00 2001 From: Valentin Berkes Date: Mon, 10 May 2021 15:51:59 +0000 Subject: [PATCH 5/5] fix lm cache decoding that can receive long input requiring masking at step 0 --- onmt/decoders/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index b7c7f4ef90..48ba0fc0a1 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -556,7 +556,7 @@ def _forward( """ dec_mask = None - if step is None: + if step is None or inputs.size(1) > 1: # masking is necessary when sequence length is greater than one dec_mask = self._compute_dec_mask(tgt_pad_mask, future)