From 99d11e1761e2eb6a19432a8e5b0e4439cd75c5a1 Mon Sep 17 00:00:00 2001 From: sid Date: Wed, 13 Jan 2021 23:18:26 +0100 Subject: [PATCH 01/43] add basic sampling code --- src/dalle_mtf/__init__.py | 3 +- src/dalle_mtf/models.py | 22 +---- src/dalle_mtf/sample.py | 176 ++++++++++++++++++++++++++++++++++++++ src/model_fns.py | 56 ++++++++++-- src/model_fns_tf.py | 3 +- 5 files changed, 232 insertions(+), 28 deletions(-) create mode 100644 src/dalle_mtf/sample.py diff --git a/src/dalle_mtf/__init__.py b/src/dalle_mtf/__init__.py index a53a710..d9d604b 100644 --- a/src/dalle_mtf/__init__.py +++ b/src/dalle_mtf/__init__.py @@ -1 +1,2 @@ -from .models import DALLE, DiscreteVAE \ No newline at end of file +from .models import DALLE, DiscreteVAE +from .sample import sample_autoregressive \ No newline at end of file diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 7bc7474..59b83be 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -179,6 +179,7 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.activation_fn = activation_fn if self.is_incremental_inference: assert self.context is not None, "must have context in incremental inference" + assert self.context['mode'] == 'incremental' if params is None: # extra params params = {} self.params = defaultdict(lambda: None, params) @@ -254,25 +255,7 @@ def attention(self, x, n_state, mask, attention_type="global", name="attn"): self.context.record_new_states([k, v]) with tf.variable_scope("attention"): - if attention_type == "local": - # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights. - radius = self.params.get("local_attention_radius", 256) - if self.is_incremental_inference: - q *= one_hot - a = mtf_transformer.attention.local_attention_1d( - q, k, v, - length_dim=k.shape[1], - key_dim=self.dimensions["kv_dim"], - value_dim=self.dimensions["kv_dim"], - radius=radius, - length_dim_num_splits=1, - fully_autoregressive=True, - attention_kwargs={}, - ) - if self.is_incremental_inference: - a = mtf.gather(a, self.context.position - 1, seq_dim) - - elif attention_type == "global": + if attention_type == "global": if exists(mask): if not self.is_incremental_inference: broadcasted_mask = mtf.broadcast(mask, @@ -402,6 +385,7 @@ def forward(self, features, return_loss=True, return_logits=False): out = self.transformer(tokens, mask=mask) logits = self.to_logits(out) if not return_loss: + logits = mtf.cast(logits, self.variable_dtype.master_dtype) return logits labels = pad(inputs, [0, 1], dim_name="total_seq_dim", pad_value=self.eos_token_id) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py new file mode 100644 index 0000000..e184ac8 --- /dev/null +++ b/src/dalle_mtf/sample.py @@ -0,0 +1,176 @@ +import mesh_tensorflow as mtf +import tensorflow.compat.v1 as tf +import mesh_tensorflow.transformer as mtf_transformer + + +def sample_autoregressive(inputs, + model, + params, + stop_at_token=50256, + max_steps=None, + temperature=0.9, + variable_dtype=mtf.VariableDType(tf.float32), + has_partial_sequences=True, + remove_partial_sequences=False, + sampling_keep_top_k=-1, + ): + """Sample randomly one token at a time. + + The partial_sequences represent partial sequences to be continued. The + first tokens of each sequence are nonzero representing the given partial + sequences and the last tokens of each sequence are zeros, representing what + needs to be filled in. + + If there are no partial sequences (you want to sample from the beginning), + then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and + has_partial_sequences=False (so we can skip computation). + + Args: + inputs: an int32 Tensor with shape [, length_dim], + model: DALL-E model + params: model paramers. + stop_at_token: an optional integer eos id. Stop when we produce it. + max_steps: an optional integer, the max number of steps to decode. + temperature: an optional floating point value between 0.0 and 1.0 0.0 + means argmax, 1.0 means sample according to predicted distribution. + variable_dtype: a mtf.VariableDType + has_partial_sequences: a boolean + decoding, one per each input layer + the embedding layer + remove_partial_sequences: a boolean - whether to remove the partial + sequences from the output + sampling_keep_top_k: an integer - if not -1, only sample from the top k + logits. + + Returns: + a Tensor with shape [, length_dim] + """ + + # with dalle, inputs will be a text sequence of len 256, then the rest image tokens. + # the parts we want to fill in will be <|pad_token|>, which we should assign in the input + + batch_dims = inputs.shape.dims[:-1] + length_dim = inputs.shape.dims[-1] + padding_id = params.get("padding_id", 0) + + initial_position = mtf.reduce_sum( + mtf.to_int32(mtf.not_equal(inputs, padding_id)), + reduced_dim=length_dim) # Gets position where zero padding starts + + length_range = mtf.range(inputs.mesh, length_dim, tf.int32) + + # Builds context to pass around internally + # The 'first part' context records initial states of k / v / x + + context_first_part = mtf_transformer.transformer.Context( + model=None, + mesh=inputs.mesh, + batch_dims=batch_dims, + length_dim=length_dim, + variable_dtype=variable_dtype, + mode="first_part", + position=length_range, + position_is_default=True, + new_states=[], + initial_position=initial_position, + sequence_id=None, + constant_states=[], + inputs=inputs) + model.context = context_first_part + + with tf.variable_scope('dall-e'): + logits = model.forward({'tokens': inputs}, return_loss=False, return_logits=True) + del logits + + if not has_partial_sequences: + initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] + else: + initial_states = context_first_part.new_states + + if not has_partial_sequences: + partial_sequences_eos_count = 0 + + if stop_at_token is not None: + partial_sequences_eos_count = mtf.reduce_sum( + mtf.to_int32(mtf.equal(inputs, stop_at_token)), + reduced_dim=length_dim) + + def cond_fn(position, ids, *unused_states): + """Should we run another loop iteration?""" + past_end = mtf.greater_equal(position, length_dim.size) + if max_steps: + past_end = mtf.logical_or( + past_end, mtf.greater_equal(position - initial_position, max_steps)) + + is_done = past_end + if stop_at_token is not None: + eos_count = mtf.reduce_sum( + mtf.to_int32(mtf.equal(ids, stop_at_token)), + reduced_dim=length_dim) + has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) + is_done = mtf.logical_or(is_done, has_additional_eos) + all_done = mtf.reduce_all(is_done) + return mtf.logical_not(all_done) + + def body_fn(position, ids, *states): + """One step in the decode loop.""" + nonlocal sampling_keep_top_k + + context = mtf_transformer.transformer.Context( + model=None, + mesh=inputs.mesh, + batch_dims=batch_dims, + length_dim=length_dim, + variable_dtype=variable_dtype, + mode="incremental", + position=position, + position_is_default=True, + states=states, + new_states=[], + initial_position=position, + sequence_id=None, + inputs=ids) + + model.is_incremental_inference = True + model.context = context + with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE): + logits = model.forward({'tokens': inputs}, return_loss=False, return_logits=True) + + # By default, do top_k sampling of 0.9 + if sampling_keep_top_k == -2: + sampling_keep_top_k = int(logits.shape[-1].size * 0.1) + + if sampling_keep_top_k != -1: + if sampling_keep_top_k <= 0: + raise ValueError("sampling_keep_top_k must either be -1 or positive.") + k_largest = mtf.nth_largest_element( + logits, n=sampling_keep_top_k, + reduced_dim=model.dimensions['final_vocab_dim']) + logits = mtf.where(mtf.less_equal(logits, k_largest), + mtf.ones_like(logits) * -1e6, logits) + + # temperature sampling + ids_this_step = mtf.sample_with_temperature( + logits, model.dimensions['final_vocab_dim'], temperature) + + # reshape & assign results + ids_this_step = mtf.reshape(ids_this_step, batch_dims) + one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) + one_new_id = ids_this_step * one_hot + new_ids = (1 - one_hot) * ids + one_new_id + new_position = position + 1 + ret = [new_position, new_ids] + ret += context.new_states + return ret + + while_loop_inputs = [initial_position, inputs] + initial_states + final_position, outputs = mtf.while_loop( + cond_fn, body_fn, while_loop_inputs)[:2] + del final_position + if has_partial_sequences and remove_partial_sequences: + # Remove partial sequences from outputs + partial_length = mtf.reduce_sum( + mtf.to_int32(mtf.not_equal(inputs, padding_id)), + reduced_dim=length_dim) + outputs = mtf.dynamic_shift( + outputs, -partial_length, length_dim, wrap=False) + return outputs diff --git a/src/model_fns.py b/src/model_fns.py index f89ebb8..2c38147 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -4,8 +4,9 @@ import mesh_tensorflow.transformer as mtf_transformer from .optimizers import get_optimizer from .utils import mode_to_str, get_graph_info, create_host_call, simd_mesh_setup, scalar_summary -from .dalle_mtf import DALLE +from .dalle_mtf import DALLE, sample_autoregressive from .vae_tf import DiscreteVAE +from tensorflow.python.ops import resources def initialize_vae_weights(checkpoint_path, scope="vae"): @@ -16,7 +17,7 @@ def initialize_vae_weights(checkpoint_path, scope="vae"): vars_to_restore = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, scope=scope) ckpt_vars = [ - name for name, _ in tf.train.list_variables(checkpoint_path)] + name for name, _ in tf.train.list_variables(checkpoint_path)] tf.logging.info(f"RESTORING {len(vars_to_restore)} VAE VARS FROM CHECKPOINT: ") tf.logging.info(f"CHECKPOINT PATH: {checkpoint_path}") tf.logging.info(f"CHECKPOINT VARS:") @@ -132,7 +133,48 @@ def dalle_model_fn(features, labels, mode, params): scalar_summary("input_image", mtf_features["image_inputs"]) if mode == tf.estimator.ModeKeys.PREDICT: - raise NotImplementedError + # Set up the model for prediction + inputs = mtf_features["tokens"] + + mtf_samples = sample_autoregressive(inputs, + model, + params, + stop_at_token=model.eos_token_id, + max_steps=None, + temperature=0.9, + variable_dtype=model.variable_dtype, + has_partial_sequences=True, + remove_partial_sequences=True, + sampling_keep_top_k=-1, + ) + + mtf_samples = mtf.anonymize(mtf_samples) + inputs = mtf.anonymize(inputs) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True) + inputs = lowering.export_to_tf_tensor(inputs) + outputs = lowering.export_to_tf_tensor(mtf_samples) + # predictions_decoded = vae.decode(outputs) + predictions = { + "inputs": inputs, + "outputs": outputs} + + def scaffold_fn(): + return tf.train.Scaffold( + local_init_op=tf.group( + tf.train.Scaffold.default_local_init_op(), + lowering.copy_masters_to_slices(), + name="mtf_local_init_op"), + ready_op=tf.concat( + [tf.report_uninitialized_variables(), + resources.report_uninitialized_resources()], + axis=0, + name="mtf_ready_op")) + + return tpu_estimator.TPUEstimatorSpec( + mode=tf.estimator.ModeKeys.PREDICT, + predictions=predictions, + scaffold_fn=scaffold_fn, + prediction_hooks=[mtf.MtfRestoreHook(lowering)]) # We're not predicting, so we better be training or evaluating assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL) @@ -155,8 +197,9 @@ def dalle_model_fn(features, labels, mode, params): if num_microbatches > 1: # For serialize_training_step we need to modify the model to output results in a dict def serialized_fn(mtf_features): - loss, loss_batch = model.forward(mtf_features, return_loss=True) - return {"loss": loss, "loss_batch": loss_batch} + with tf.variable_scope('dall-e'): + loss, loss_batch = model.forward(mtf_features, return_loss=True) + return {"loss": loss, "loss_batch": loss_batch} # Serialize the training step - Gradients are accumulated locally and reduced once. var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, model.dimensions["batch_dim"], @@ -164,7 +207,8 @@ def serialized_fn(mtf_features): loss = output_dict["loss"] loss_batch = output_dict["loss_batch"] else: - loss, loss_batch = model.forward(mtf_features, return_loss=True) + with tf.variable_scope('dall-e'): + loss, loss_batch = model.forward(mtf_features, return_loss=True) del loss_batch # TODO: may need this for some metrics - otherwise, remove from output diff --git a/src/model_fns_tf.py b/src/model_fns_tf.py index 784a702..937dbdb 100644 --- a/src/model_fns_tf.py +++ b/src/model_fns_tf.py @@ -1,9 +1,8 @@ import tensorflow.compat.v1 as tf import tensorflow.compat.v2 as tf2 from tensorflow.python.tpu import tpu_estimator -from .optimizers import get_optimizer from .vae_tf import DiscreteVAE -from .utils import scalar_summary, mode_to_str, create_host_call +from .utils import mode_to_str def vae_model_fn(features, labels, mode, params): From b9fd039cc6639feb4348a14566fe3652942ebc23 Mon Sep 17 00:00:00 2001 From: sid Date: Thu, 14 Jan 2021 00:15:49 +0100 Subject: [PATCH 02/43] add prediction input / output fns --- src/input_fns.py | 29 +++++++++++++++++- src/model_fns.py | 75 ++++++++++++++++++++++++++-------------------- src/utils/utils.py | 6 +++- train_dalle.py | 13 +++++++- 4 files changed, 88 insertions(+), 35 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index ee35bfc..d503647 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -38,6 +38,31 @@ def truncate_or_pad_label(label, params): return label +def pred_input(params, tokenizer, prompt='a cat in a hat'): + tokens = tokenizer.encode(prompt).ids + if len(tokens) > params["total_seq_len"]: + tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating " + "input.") + tokens = tokens[len(tokens) - params["total_seq_len"]:] # TODO: left or right truncate here? + if len(tokens) < params["total_seq_len"]: + tokens = tf.pad(tokens, [[0, params["total_seq_len"] - len(tokens)]], constant_values=params["padding_id"]) + t = tf.broadcast_to(tokens, [params["batch_size"], params["total_seq_len"]]) + dataset = tf.data.Dataset.from_tensors(t) + + def _dummy_labels(x): + return x, x + + dataset = dataset.map(_dummy_labels) + return dataset + + +def pred_output(predictions, out_name='test'): + with tf.gfile.Open(f"{out_name}.txt", "w") as f: + for i, p in enumerate(predictions): + p = p["outputs"] + f.write(str(p["outputs"])) + + def read_labeled_tfrecord(params): def read_fn(example): features = { @@ -103,6 +128,7 @@ def _process_path(file_path): dataset = configure_for_performance(dataset, params, eval) return dataset.repeat() + def dalle_input_fn(params, eval=False): path = params["dataset"]["train_path"] if not eval else params["dataset"]["eval_path"] files = tf.io.gfile.glob(path) @@ -113,7 +139,8 @@ def dalle_input_fn(params, eval=False): if not eval: dataset = dataset.shuffle(file_count, reshuffle_each_iteration=False) - dataset = dataset.apply(tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False)) + dataset = dataset.apply( + tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False)) parse_fn = read_labeled_tfrecord(params) dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) dataset = configure_for_performance(dataset, params, eval) diff --git a/src/model_fns.py b/src/model_fns.py index 53a8975..776b09a 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -64,18 +64,17 @@ def dalle_model_fn(features, labels, mode, params): vae, vae_checkpoint_path = load_vae_model(params, mode_str) initialize_vae_weights(vae_checkpoint_path) - H = W = params["dataset"]["image_size"] - image_seq_len = (vae.H // (2 ** len(vae.convblocks))) ** 2 // (vae.stack_factor ** 2) # TODO: check this is correct batch_size = params[f"{mode_str}_batch_size"] n_channels = params.get("input_channels", 3) + if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: - with tf.variable_scope("vae"): - vae_logits = vae.forward(features, return_logits=True) + with tf.variable_scope("vae"): + vae_logits = vae.forward(features, return_logits=True) - # TODO: using argmax sampling for now, but is that optimal? - tokens = tf.math.argmax(vae_logits, -1) - img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, image_seq_len)), tf.int32) + # TODO: using argmax sampling for now, but is that optimal? + tokens = tf.math.argmax(vae_logits, -1) + img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, params['image_seq_len'])), tf.int32) # Construct mtf graph + mesh from params graph = mtf.Graph() @@ -99,7 +98,7 @@ def dalle_model_fn(features, labels, mode, params): text_vocab_size=params["text_vocab_size"], image_vocab_size=params["image_vocab_size"], text_seq_len=params["text_seq_len"], - image_seq_len=image_seq_len, + image_seq_len=params['image_seq_len'], n_layers=params["n_layers"], n_heads=params["n_heads"], batch_size=batch_size, @@ -110,29 +109,41 @@ def dalle_model_fn(features, labels, mode, params): # Build mtf_features & seq length dict for getting number of microbatches # We need to pack inputs into a dict to pass into serialize_training_step - features_dict = {"image_inputs": features, - "text_inputs": labels} - mtf_features = {} - for key, x in features_dict.items(): - if x is not None: - if key == "text_inputs": - text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) - x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) - mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) - - mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - - if key == "image_inputs": - mtf_shape = mtf.Shape([ - model.dimensions["batch_dim"], - mtf.Dimension("img_height_dim", vae.H), - mtf.Dimension("img_width_dim", vae.W), - mtf.Dimension("img_channel_dim", vae.num_ch), - ]) - x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC - mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - - scalar_summary("input_image", mtf_features["image_inputs"]) + if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: + features_dict = {"image_inputs": features, + "text_inputs": labels} + mtf_features = {} + for key, x in features_dict.items(): + if x is not None: + if key == "text_inputs": + text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) + x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) + + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + + if key == "image_inputs": + mtf_shape = mtf.Shape([ + model.dimensions["batch_dim"], + mtf.Dimension("img_height_dim", vae.H), + mtf.Dimension("img_width_dim", vae.W), + mtf.Dimension("img_channel_dim", vae.num_ch), + ]) + x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC + mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + scalar_summary("input_image", mtf_features["image_inputs"]) + else: + features_dict = {"text_inputs": labels} + mtf_features = {} + for key, x in features_dict.items(): + if x is not None: + if key == "text_inputs": + text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) + x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) + + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + if mode == tf.estimator.ModeKeys.PREDICT: # Set up the model for prediction inputs = mtf_features["tokens"] @@ -151,7 +162,7 @@ def dalle_model_fn(features, labels, mode, params): mtf_samples = mtf.anonymize(mtf_samples) inputs = mtf.anonymize(inputs) - lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) # predictions_decoded = vae.decode(outputs) diff --git a/src/utils/utils.py b/src/utils/utils.py index 45178d6..95e8751 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -224,4 +224,8 @@ def scalar_summary(name, x): Returns: a Tensor which is identical in value to x """ - return ScalarSummaryOperation(name, x) \ No newline at end of file + return ScalarSummaryOperation(name, x) + +def get_image_seq_len(dalle_params): + return (dalle_params["vae_params"]['dataset']['image_size'] // (2 ** len(dalle_params["vae_params"]['convblocks']))) ** 2 // ( + dalle_params.get("vae_params").get("stack_factor", 1) ** 2) \ No newline at end of file diff --git a/train_dalle.py b/train_dalle.py index e5c4428..9432604 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -6,7 +6,7 @@ import argparse from src.utils import * from src.model_fns import dalle_model_fn -from src.input_fns import dalle_input_fn +from src.input_fns import dalle_input_fn, pred_input, pred_output from src.data import get_tokenizer def parse_args(): @@ -18,6 +18,8 @@ def parse_args(): parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.") parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " "starts a new training run") + parser.add_argument('--predict', action='store_true', help='run model in predict mode') + parser.add_argument('--prompt', type=str, default='a cat in a hat') args = parser.parse_args() assert args.model is not None, "Model must be set" return args @@ -46,6 +48,8 @@ def main(): params["gpu_ids"] = args.gpu_ids tokenizer = get_tokenizer(params["tokenizer"]) assert len(tokenizer) == params["text_vocab_size"], f"tokenizer vocab size {len(tokenizer)} must equal model vocab size {params['text_vocab_size']}" + params['image_seq_len'] = get_image_seq_len(params) + params['total_seq_len'] = params['image_seq_len'] + params['text_seq_len'] params["padding_id"] = tokenizer.encode(tokenizer.pad_token)[0] # Set up TPUs and Estimator if args.tpu == "colab": @@ -76,6 +80,13 @@ def main(): eval_batch_size=params["eval_batch_size"], predict_batch_size=params["predict_batch_size"], params=params) + if args.predict: + # Predict + pred_input_fn = partial(pred_input, params, tokenizer, args.prompt) + predictions = estimator.predict(input_fn=pred_input_fn) + logging.info("Predictions generated") + pred_output(predictions, 'test') + return has_predict_or_eval_steps = params["predict_steps"] > 0 or params["eval_steps"] > 0 if has_predict_or_eval_steps: From cf76c6c3d5734e9b7b98c2ef338680845571a9d3 Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 14 Jan 2021 00:16:16 +0000 Subject: [PATCH 03/43] get sample_autoregressive working --- configs/dalle_coco.json | 8 ++++---- src/dalle_mtf/models.py | 14 ++++++++++++-- src/input_fns.py | 3 +-- src/model_fns.py | 6 ++---- train_dalle.py | 2 +- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/configs/dalle_coco.json b/configs/dalle_coco.json index 7d2c768..d71e6a6 100644 --- a/configs/dalle_coco.json +++ b/configs/dalle_coco.json @@ -7,19 +7,19 @@ }, "train_batch_size": 128, "eval_batch_size": 128, - "predict_batch_size": 128, + "predict_batch_size": 16, "steps_per_checkpoint": 5000, "iterations": 1000, "train_steps": 100000, "predict_steps": 0, "eval_steps": 0, "n_channels": 3, - "bf_16": false, + "bf_16": true, "recompute_grad": true, "lr": 0.0001, - "model_path": "gs://neo-models/dalle_coco/", + "model_path": "gs://neo-models/dalle_coco_sample/", "mesh_shape": "data:16,model:2", - "layout": "batch_dim:data", + "layout": "batch_dim:data,embed_dim:model", "n_embd": 1024, "text_vocab_size": 50258, "image_vocab_size": 512, diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 59b83be..f55ea54 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -228,8 +228,13 @@ def get_attn_mask(self, mesh, nd, ns): return self.attn_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): - # x :: [batch, seq, n_embd] - batch_dim, seq_dim, embd_dim = x_shape = x.shape + if not self.is_incremental_inference: + # x :: [batch, seq, n_embd] + batch_dim, seq_dim, embd_dim = x_shape = x.shape + else: + batch_dim, embd_dim = x_shape = x.shape + seq_dim = self.dimensions['total_seq_dim'] + assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads" with tf.variable_scope(name): # Compute attention inputs @@ -379,6 +384,11 @@ def to_logits(self, x): def forward(self, features, return_loss=True, return_logits=False): inputs = features["tokens"] + if self.is_incremental_inference: + # reshape inputs if in inference mode + inputs = mtf.gather(inputs, self.context.position - 1, self.dimensions['total_seq_dim']) + inputs = mtf.reshape(inputs, [self.dimensions['batch_dim']]) + tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding") mask = self.get_attn_mask(tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"]) diff --git a/src/input_fns.py b/src/input_fns.py index d503647..74e6ab8 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -39,7 +39,7 @@ def truncate_or_pad_label(label, params): def pred_input(params, tokenizer, prompt='a cat in a hat'): - tokens = tokenizer.encode(prompt).ids + tokens = tokenizer.encode(prompt) if len(tokens) > params["total_seq_len"]: tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating " "input.") @@ -59,7 +59,6 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): with tf.gfile.Open(f"{out_name}.txt", "w") as f: for i, p in enumerate(predictions): - p = p["outputs"] f.write(str(p["outputs"])) diff --git a/src/model_fns.py b/src/model_fns.py index 776b09a..4aabd5f 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -138,11 +138,9 @@ def dalle_model_fn(features, labels, mode, params): for key, x in features_dict.items(): if x is not None: if key == "text_inputs": - text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) - x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) + text_tokens = tf.reshape(x, [batch_size, params["total_seq_len"]]) mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) - - mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, text_tokens, mtf_shape, name=key) if mode == tf.estimator.ModeKeys.PREDICT: # Set up the model for prediction diff --git a/train_dalle.py b/train_dalle.py index 9432604..6ad8fe6 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -82,7 +82,7 @@ def main(): params=params) if args.predict: # Predict - pred_input_fn = partial(pred_input, params, tokenizer, args.prompt) + pred_input_fn = partial(pred_input, tokenizer=tokenizer, prompt=args.prompt) predictions = estimator.predict(input_fn=pred_input_fn) logging.info("Predictions generated") pred_output(predictions, 'test') From c7ff6c4549bf25a3c5a9721e345f09ea18ecdf09 Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 14 Jan 2021 00:21:22 +0000 Subject: [PATCH 04/43] truncate text tokens properly --- src/input_fns.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index 74e6ab8..0862147 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -40,10 +40,10 @@ def truncate_or_pad_label(label, params): def pred_input(params, tokenizer, prompt='a cat in a hat'): tokens = tokenizer.encode(prompt) - if len(tokens) > params["total_seq_len"]: + if len(tokens) > params["text_seq_len"]: tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating " "input.") - tokens = tokens[len(tokens) - params["total_seq_len"]:] # TODO: left or right truncate here? + tokens = tokens[len(tokens) - params["text_seq_len"]:] # TODO: left or right truncate here? if len(tokens) < params["total_seq_len"]: tokens = tf.pad(tokens, [[0, params["total_seq_len"] - len(tokens)]], constant_values=params["padding_id"]) t = tf.broadcast_to(tokens, [params["batch_size"], params["total_seq_len"]]) @@ -59,7 +59,7 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): with tf.gfile.Open(f"{out_name}.txt", "w") as f: for i, p in enumerate(predictions): - f.write(str(p["outputs"])) + f.write(str(p["outputs"].tolist())) def read_labeled_tfrecord(params): From 4d51fd94bf74ad1424445da12bcfb9ffba0e5d85 Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 14 Jan 2021 00:59:10 +0000 Subject: [PATCH 05/43] log model params to tensorboard --- src/utils/utils.py | 33 ++++++++++++++++++++++++++++++++- train_dalle.py | 1 + train_vae.py | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/utils/utils.py b/src/utils/utils.py index 95e8751..e1ea317 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -9,6 +9,8 @@ import logging import sys from mesh_tensorflow.ops import Operation, Tensor +import re + def fetch_model_params(model): model_path = model if model.endswith(".json") else f"./configs/{model}.json" @@ -226,6 +228,35 @@ def scalar_summary(name, x): """ return ScalarSummaryOperation(name, x) + def get_image_seq_len(dalle_params): return (dalle_params["vae_params"]['dataset']['image_size'] // (2 ** len(dalle_params["vae_params"]['convblocks']))) ** 2 // ( - dalle_params.get("vae_params").get("stack_factor", 1) ** 2) \ No newline at end of file + dalle_params.get("vae_params").get("stack_factor", 1) ** 2) + +def save_config(params_dict, logdir): + tf.logging.info(f"Saving config to {logdir}") + text = "{\n\n" + total_params = len(params_dict) + for count, key in enumerate(params_dict): + config_value = str(params_dict[key]) + if re.search('[a-zA-Z]', config_value): + if config_value.lower() != 'true': + if config_value.lower() != 'false': + if config_value[0] != '[': + # TODO: Making a manual exception for parsing epsilon right now since it's the only number in + # scientific notation. Should fix this. + if key != "epsilon": + config_value = f'"{config_value}"' + if count == total_params - 1: + text += f'"{str(key)}"' + ' : ' + config_value + '\n\n' + else: + text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n' + text += '\n\n}' + sess = tf.InteractiveSession() + summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text)) + summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph) + text = sess.run(summary_op) + summary_writer.add_summary(text, 0) + summary_writer.flush() + summary_writer.close() + tf.reset_default_graph() \ No newline at end of file diff --git a/train_dalle.py b/train_dalle.py index 6ad8fe6..8ad9a8e 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -31,6 +31,7 @@ def main(): logging = setup_logging(args) params = fetch_model_params(args.model) params["vae_params"] = fetch_model_params(params["vae_model"]) + save_config(params, params['model_dir']) assert params["model_type"].lower() == "dalle", f'model_type {params["model_type"]} not recognized' # Confirm deletion of checkpoint files if --new flag is set diff --git a/train_vae.py b/train_vae.py index 835cbbc..053cc79 100644 --- a/train_vae.py +++ b/train_vae.py @@ -28,6 +28,7 @@ def main(): args = parse_args() logging = setup_logging(args) params = fetch_model_params(args.model) + save_config(params, params['model_dir']) assert params["model_type"].lower() == "vae", f'model_type {params["model_type"]} not recognized' # get current step From 346871f5b2f0dad5019785881f2cbcdf62555a8d Mon Sep 17 00:00:00 2001 From: Ben Wang Date: Thu, 14 Jan 2021 15:49:50 +1100 Subject: [PATCH 06/43] add vae decoding and write to jpeg --- src/input_fns.py | 8 +++++--- src/model_fns.py | 9 +++++++-- src/vae_tf/models.py | 22 +++++++++++++++++++--- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index 0862147..3559227 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -1,3 +1,5 @@ +import imageio +import numpy as np import tensorflow.compat.v1 as tf @@ -57,9 +59,9 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): - with tf.gfile.Open(f"{out_name}.txt", "w") as f: - for i, p in enumerate(predictions): - f.write(str(p["outputs"].tolist())) + for i, p in enumerate(predictions): + denormalize = lambda x: (((x + 1) / 2) * 255.0).astype(np.uint8) + imageio.imwrite(f"{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) def read_labeled_tfrecord(params): diff --git a/src/model_fns.py b/src/model_fns.py index 4aabd5f..ef022e6 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -163,10 +163,15 @@ def dalle_model_fn(features, labels, mode, params): lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) - # predictions_decoded = vae.decode(outputs) + + img_outputs = outputs[:, -model.image_seq_len:] + predictions_decoded = vae.decode(img_outputs) + predictions = { "inputs": inputs, - "outputs": outputs} + "outputs": outputs, + "predictions_decoded": predictions_decoded + } def scaffold_fn(): return tf.train.Scaffold( diff --git a/src/vae_tf/models.py b/src/vae_tf/models.py index d7dd073..6c8e8f6 100644 --- a/src/vae_tf/models.py +++ b/src/vae_tf/models.py @@ -75,6 +75,8 @@ def __init__(self, self.recompute_grad = recompute_grad self.bf16 = use_bf16 + self.n_hid = convblocks[-1][1] + assert math.log2(stack_factor).is_integer() # maybe you don't actually need this? self.stack_factor = stack_factor @@ -109,7 +111,6 @@ def encoder_block(x, channels=channels): x = x + res_out with tf.variable_scope(f"codebook"): - self.n_hid = x.shape[-1] embedding = tf.get_variable("codebook", shape=[self.n_hid, self.num_tokens], dtype=tf.float32) if self.bf16: @@ -119,9 +120,8 @@ def encoder_block(x, channels=channels): return output - def decoder(self, x): - with tf.variable_scope(f"codebook", reuse=True): + with tf.variable_scope(f"codebook", reuse=tf.AUTO_REUSE): embedding = tf.get_variable("codebook", shape=[self.n_hid, self.num_tokens], dtype=tf.float32) x = tf.matmul(x, embedding, transpose_b=True) @@ -162,6 +162,22 @@ def decoder_block(x, channels=channels): return x + def decode(self, input_indices): + batch, seqlen = input_indices.shape + + print(f"seqlen {seqlen}") + print(f"side expected {self.W // (2 ** len(self.convblocks))}") + + assert seqlen == (self.W // (2 ** len(self.convblocks))) * (self.H // (2 ** len(self.convblocks))) + + input_onehot = tf.one_hot(input_indices, self.num_tokens) + input_reshaped = tf.reshape(input_onehot, [batch, + self.H // (2 ** len(self.convblocks)), + self.W // (2 ** len(self.convblocks)), + self.num_tokens]) # NHWC + + return self.decoder(input_reshaped) + def forward(self, features, return_recon_loss=False, return_logits=False, hard_gumbel=True, temperature=1.): if isinstance(features, dict): img = features["inputs"] From d13c3309ac2124be2a9d3fe68e4e927e420a5883 Mon Sep 17 00:00:00 2001 From: Ben Wang Date: Thu, 14 Jan 2021 16:38:17 +1100 Subject: [PATCH 07/43] unshift image outputs at decode time --- src/model_fns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model_fns.py b/src/model_fns.py index ef022e6..a51e4dc 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -164,7 +164,7 @@ def dalle_model_fn(features, labels, mode, params): inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) - img_outputs = outputs[:, -model.image_seq_len:] + img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size predictions_decoded = vae.decode(img_outputs) predictions = { From f8a744929bb4189cf3efd0378daf027996ce7c74 Mon Sep 17 00:00:00 2001 From: Ben Wang Date: Thu, 14 Jan 2021 19:42:46 +1100 Subject: [PATCH 08/43] dirty hack to use vae decoder params when training dalle --- src/model_fns.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/model_fns.py b/src/model_fns.py index a51e4dc..4214274 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -76,6 +76,10 @@ def dalle_model_fn(features, labels, mode, params): tokens = tf.math.argmax(vae_logits, -1) img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, params['image_seq_len'])), tf.int32) + # TODO: get rid of this ugly hack, its just to pull the decoder parameters in during training + with tf.variable_scope('vae'): + vae.decoder(tf.zeros_like(vae_logits)) + # Construct mtf graph + mesh from params graph = mtf.Graph() mesh_shape = mtf.convert_to_shape(params["mesh_shape"]) @@ -165,7 +169,8 @@ def dalle_model_fn(features, labels, mode, params): outputs = lowering.export_to_tf_tensor(mtf_samples) img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size - predictions_decoded = vae.decode(img_outputs) + with tf.variable_scope('vae'): + predictions_decoded = vae.decode(img_outputs) predictions = { "inputs": inputs, From ff56d1265482f65afd57df25b9bb1dbdbccab373 Mon Sep 17 00:00:00 2001 From: Leo Gao <54557097+leogao2@users.noreply.github.com> Date: Sat, 16 Jan 2021 19:13:35 -0700 Subject: [PATCH 09/43] Move initialize_vae_weights to after lowering --- src/model_fns.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model_fns.py b/src/model_fns.py index 4214274..b9731bd 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -63,7 +63,6 @@ def dalle_model_fn(features, labels, mode, params): # load vae in tensorflow graph before mtf vae, vae_checkpoint_path = load_vae_model(params, mode_str) - initialize_vae_weights(vae_checkpoint_path) H = W = params["dataset"]["image_size"] batch_size = params[f"{mode_str}_batch_size"] n_channels = params.get("input_channels", 3) @@ -168,6 +167,8 @@ def dalle_model_fn(features, labels, mode, params): inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) + initialize_vae_weights(vae_checkpoint_path) + img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size with tf.variable_scope('vae'): predictions_decoded = vae.decode(img_outputs) From 4c4e0e06fa211036cf4b41d4942cba467b07a529 Mon Sep 17 00:00:00 2001 From: connor Date: Sun, 17 Jan 2021 14:13:55 +0000 Subject: [PATCH 10/43] fix vae checkpoint load in training --- src/input_fns.py | 3 +++ src/model_fns.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index 3559227..b07098a 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -61,6 +61,9 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): for i, p in enumerate(predictions): denormalize = lambda x: (((x + 1) / 2) * 255.0).astype(np.uint8) + # to debug: + # with open(f"{out_name}_{i}.txt", 'w') as f: + # f.write(str(p["outputs"].tolist())) imageio.imwrite(f"{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) diff --git a/src/model_fns.py b/src/model_fns.py index b9731bd..52cf870 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -134,7 +134,8 @@ def dalle_model_fn(features, labels, mode, params): ]) x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - scalar_summary("input_image", mtf_features["image_inputs"]) + denormalize = lambda x: (x + 1) / 2 + scalar_summary("input_image", denormalize(mtf_features["image_inputs"])) else: features_dict = {"text_inputs": labels} mtf_features = {} @@ -163,19 +164,21 @@ def dalle_model_fn(features, labels, mode, params): mtf_samples = mtf.anonymize(mtf_samples) inputs = mtf.anonymize(inputs) - lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=params.get('autostack', True)) + inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) initialize_vae_weights(vae_checkpoint_path) - + img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size + with tf.variable_scope('vae'): predictions_decoded = vae.decode(img_outputs) predictions = { "inputs": inputs, - "outputs": outputs, + "outputs": img_outputs, "predictions_decoded": predictions_decoded } @@ -250,11 +253,12 @@ def serialized_fn(mtf_features): get_graph_info(graph) # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors - lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=params.get('autostack', True)) tf_loss = lowering.export_to_tf_tensor(loss) tf_loss = tf.cast(tf_loss, tf.float32) + if mode == tf.estimator.ModeKeys.TRAIN: # Use our patched version until mtf updates theirs host_call = create_host_call(params['model_path']) @@ -264,8 +268,12 @@ def serialized_fn(mtf_features): tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) # Need to manually increment global_step train_op = tf.group(tf_update_ops) + with mtf.utils.outside_all_rewrites(): + # only *now* can we initialize vae weights (stupid tensorflow) + initialize_vae_weights(vae_checkpoint_path) + # Copy master variables to slices. Must be called first. restore_hook = mtf.MtfRestoreHook(lowering) if mode == tf.estimator.ModeKeys.TRAIN: From 2c14bde296835ae7c16b4dfca6da712de9fb4c84 Mon Sep 17 00:00:00 2001 From: connor Date: Mon, 18 Jan 2021 19:45:25 +0000 Subject: [PATCH 11/43] fix parameter count logging --- src/utils/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils/utils.py b/src/utils/utils.py index e1ea317..70869d5 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -68,7 +68,7 @@ def get_n_trainable_vars(graph): for dim in shape: variable_parameters *= dim.size total_parameters += variable_parameters - print(f"\n\nN PARAMS:\n{total_parameters:,}\n\n") + tf.logging.info(f"\n\nN PARAMS:\n{total_parameters:,}\n\n") def print_dim_names(graph): @@ -85,10 +85,10 @@ def print_dim_names(graph): # Print all dim names in graph & write to file all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims unique_dims = list(set(all_dim_names)) - print("ALL DIM NAMES:") + tf.logging.info("ALL DIM NAMES:") for dim_name in unique_dims: - print(dim_name) - print('\n') + tf.logging.info(dim_name) + tf.logging.info('\n') def get_graph_info(graph): From 130c26e3126659f1724d8c1bafad85b85b7c09dc Mon Sep 17 00:00:00 2001 From: connor Date: Mon, 18 Jan 2021 19:46:22 +0000 Subject: [PATCH 12/43] fix image vocab size --- configs/dalle_coco.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/dalle_coco.json b/configs/dalle_coco.json index d71e6a6..0454b98 100644 --- a/configs/dalle_coco.json +++ b/configs/dalle_coco.json @@ -22,7 +22,7 @@ "layout": "batch_dim:data,embed_dim:model", "n_embd": 1024, "text_vocab_size": 50258, - "image_vocab_size": 512, + "image_vocab_size": 2048, "text_seq_len": 256, "n_layers": 12, "n_heads": 8, From 4652ef24f26b6893c5671ff7f70f3fc1096c40c2 Mon Sep 17 00:00:00 2001 From: connor Date: Tue, 19 Jan 2021 15:30:47 +0000 Subject: [PATCH 13/43] revert to separate embeddings for image and text --- src/dalle_mtf/models.py | 74 ++++++++++++++++++++++++++++++++--------- src/dalle_mtf/ops.py | 6 ++++ src/dalle_mtf/sample.py | 58 +++++++++++++++++--------------- src/input_fns.py | 10 +++--- src/model_fns.py | 55 ++++++++++++++---------------- 5 files changed, 126 insertions(+), 77 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index f55ea54..20841cd 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -5,7 +5,7 @@ from collections import defaultdict import math -from .ops import pad, exists, get_variable_dtype +from .ops import pad, exists, get_variable_dtype, expand_tile from .layers import gumbel_softmax, mse_loss, norm @@ -154,13 +154,16 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.n_layers = n_layers self.n_heads = n_heads self.attn_mask = attn_mask + self.logits_mask = None self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size), "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size), "final_vocab_dim": mtf.Dimension("vocab_dim", self.total_tokens), - "total_seq_dim": mtf.Dimension("total_seq_dim", self.total_seq_dim), + "text_sequence_dim": mtf.Dimension("sequence_dim", text_seq_len), + "image_sequence_dim": mtf.Dimension("sequence_dim", image_seq_len), + "total_seq_dim": mtf.Dimension("sequence_dim", self.total_seq_dim), "embed_seq_dim": mtf.Dimension("embed_seq_dim", self.total_seq_dim), "memory_len_dim": mtf.Dimension("memory_len_dim", self.total_seq_dim), "heads_dim": mtf.Dimension("heads", n_heads), @@ -186,7 +189,10 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq def embedding(self, x, name): embd_dim = self.dimensions["embed_dim"] - vocab_dim = self.dimensions["final_vocab_dim"] + if "text" in name: + vocab_dim = self.dimensions["text_vocab_dim"] + else: + vocab_dim = self.dimensions["image_vocab_dim"] with tf.variable_scope(name): wte = mtf.get_variable(x.mesh, "wte", mtf.Shape([vocab_dim, embd_dim]), @@ -202,6 +208,10 @@ def embedding(self, x, name): return x def positional_embedding(self, x, name): + if "text" in name: + sequence_dim = self.dimensions["text_sequence_dim"] + else: + sequence_dim = self.dimensions["image_sequence_dim"] with tf.variable_scope(name): # Positional embedding wpe = mtf.get_variable(x.mesh, "wpe", @@ -210,7 +220,7 @@ def positional_embedding(self, x, name): master_dtype=self.variable_dtype.master_dtype, slice_dtype=self.variable_dtype.slice_dtype, activation_dtype=self.variable_dtype.activation_dtype) - position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \ + position_indices = mtf.range(x.mesh, sequence_dim, tf.int64) if not \ self.is_incremental_inference else (self.context.position - 1) pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0]) embed_dropout = self.params.get("embed_dropout", 0) @@ -226,10 +236,25 @@ def get_attn_mask(self, mesh, nd, ns): i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j)) self.attn_mask = mtf.cast(mtf.less(i, j), self.variable_dtype.activation_dtype) * -1e10 return self.attn_mask + + def get_logits_mask(self, mesh): + if not exists(self.logits_mask): + t = mtf.ones(mesh, mtf.Shape([self.dimensions['text_vocab_dim']]), tf.int32) + i = mtf.zeros(mesh, mtf.Shape([self.dimensions['image_vocab_dim']]), tf.int32) + eos = mtf.ones(mesh, mtf.Shape([mtf.Dimension(self.dimensions['image_vocab_dim'].name, 1)]), tf.int32) + logits_mask = mtf.concat([t,i], self.dimensions['image_vocab_dim'].name) + logits_mask = mtf.concat([logits_mask, eos], self.dimensions['image_vocab_dim'].name) + new_shape = mtf.Shape([self.dimensions['batch_dim'], self.dimensions['total_seq_dim'], logits_mask.shape.dims[-1]]) + logits_mask = mtf.broadcast(logits_mask, new_shape) + logits_mask = mtf.cast(mtf.equal(logits_mask, 1), tf.float32) * -1e10 + logits_mask += 1 + self.logits_mask = logits_mask + return self.logits_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): if not self.is_incremental_inference: # x :: [batch, seq, n_embd] + print(x.shape) batch_dim, seq_dim, embd_dim = x_shape = x.shape else: batch_dim, embd_dim = x_shape = x.shape @@ -380,28 +405,47 @@ def to_logits(self, x): with tf.variable_scope("to_logits"): logits = self.linear(self.layer_norm(x), self.dimensions["final_vocab_dim"], name="linear_out") # Go to full precision for the logits + if self.is_incremental_inference: + # add seq dim in inference mode + logits = expand_tile(logits, mtf.Dimension("sequence_dim", 1), axis=1) return mtf.cast(logits, tf.float32) + def shift_labels(self, labels): + print(labels.shape) + labels = pad(labels, [0, 1], dim_name="sequence_dim", pad_value=self.eos_token_id) + indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1 + labels = mtf.gather(labels, indices, dim=labels.shape[1]) + labels = mtf.rename_dimension(labels, "range", "sequence_dim") + return labels + def forward(self, features, return_loss=True, return_logits=False): - inputs = features["tokens"] - if self.is_incremental_inference: + if features.get('text_inputs') is not None: + text = features["text_inputs"] + text_emb = self.positional_embedding(self.embedding(text, "text_embd"), "text_pos_emb") + else: + assert self.is_incremental_inference + image = features.get("image_inputs", None) + if not self.is_incremental_inference: + image_emb = self.positional_embedding(self.embedding(image, "image_embd"), "image_pos_emb") + tokens = mtf.concat([text_emb, image_emb], concat_dim_name="sequence_dim") # [batch, seq, n_embd] + else: # reshape inputs if in inference mode - inputs = mtf.gather(inputs, self.context.position - 1, self.dimensions['total_seq_dim']) - inputs = mtf.reshape(inputs, [self.dimensions['batch_dim']]) + image = mtf.gather(image, self.context.position - 1, self.dimensions["image_sequence_dim"]) + image = mtf.reshape(image, [self.dimensions["batch_dim"]]) + tokens = self.positional_embedding(self.embedding(image, "image_embd"), "image_pos_emb") - tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding") - - mask = self.get_attn_mask(tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"]) + mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) logits = self.to_logits(out) + logits *= self.get_logits_mask(tokens.mesh) if not return_loss: logits = mtf.cast(logits, self.variable_dtype.master_dtype) return logits - labels = pad(inputs, [0, 1], dim_name="total_seq_dim", pad_value=self.eos_token_id) - indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1 - labels = mtf.gather(labels, indices, dim=labels.shape[1]) - labels = mtf.rename_dimension(labels, "range", "total_seq_dim") + assert exists(image), 'when training, image must be supplied' + offset_image = image + self.text_vocab_size + labels = mtf.concat([text, offset_image], concat_dim_name="sequence_dim") + labels = self.shift_labels(labels) loss, loss_batch = self._loss(logits, labels) if return_logits and return_loss: # Cast back to checkpoint dtype diff --git a/src/dalle_mtf/ops.py b/src/dalle_mtf/ops.py index f679170..e62eff6 100644 --- a/src/dalle_mtf/ops.py +++ b/src/dalle_mtf/ops.py @@ -80,3 +80,9 @@ def get_variable_dtype(bf_16=True): return mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.bfloat16) else: return mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32) + +def expand_tile(value, newdim, axis=0): + """Add a new axis of given size.""" + new_shape = value.shape.dims + new_shape.insert(axis, newdim) + return mtf.broadcast(value, new_shape) # shape.dims gets us a list which we need in order to concat diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index e184ac8..331c90f 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -26,7 +26,7 @@ def sample_autoregressive(inputs, has_partial_sequences=False (so we can skip computation). Args: - inputs: an int32 Tensor with shape [, length_dim], + inputs: an input dictionary containing 'text_inputs' and 'image_inputs', model: DALL-E model params: model paramers. stop_at_token: an optional integer eos id. Stop when we produce it. @@ -48,24 +48,29 @@ def sample_autoregressive(inputs, # with dalle, inputs will be a text sequence of len 256, then the rest image tokens. # the parts we want to fill in will be <|pad_token|>, which we should assign in the input - batch_dims = inputs.shape.dims[:-1] - length_dim = inputs.shape.dims[-1] + batch_dims = model.dimensions["batch_dim"] + length_dim = model.dimensions["total_seq_dim"] + image_seq_dim = model.dimensions['image_sequence_dim'] padding_id = params.get("padding_id", 0) + image_inputs = inputs['image_inputs'] + text_inputs = inputs['text_inputs'] + # Gets position (in image inputs) where zero padding starts initial_position = mtf.reduce_sum( - mtf.to_int32(mtf.not_equal(inputs, padding_id)), - reduced_dim=length_dim) # Gets position where zero padding starts + mtf.to_int32(mtf.not_equal(image_inputs, padding_id)), + reduced_dim=image_seq_dim) + # initial_position += model.dimensions['text_seq_dim'].size - length_range = mtf.range(inputs.mesh, length_dim, tf.int32) + length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32) # Builds context to pass around internally # The 'first part' context records initial states of k / v / x context_first_part = mtf_transformer.transformer.Context( model=None, - mesh=inputs.mesh, + mesh=image_inputs.mesh, batch_dims=batch_dims, - length_dim=length_dim, + length_dim=image_seq_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, @@ -78,7 +83,7 @@ def sample_autoregressive(inputs, model.context = context_first_part with tf.variable_scope('dall-e'): - logits = model.forward({'tokens': inputs}, return_loss=False, return_logits=True) + logits = model.forward(inputs, return_loss=False, return_logits=True) del logits if not has_partial_sequences: @@ -91,12 +96,12 @@ def sample_autoregressive(inputs, if stop_at_token is not None: partial_sequences_eos_count = mtf.reduce_sum( - mtf.to_int32(mtf.equal(inputs, stop_at_token)), - reduced_dim=length_dim) + mtf.to_int32(mtf.equal(image_inputs, stop_at_token)), + reduced_dim=image_seq_dim) def cond_fn(position, ids, *unused_states): """Should we run another loop iteration?""" - past_end = mtf.greater_equal(position, length_dim.size) + past_end = mtf.greater_equal(position, image_seq_dim.size) if max_steps: past_end = mtf.logical_or( past_end, mtf.greater_equal(position - initial_position, max_steps)) @@ -105,7 +110,7 @@ def cond_fn(position, ids, *unused_states): if stop_at_token is not None: eos_count = mtf.reduce_sum( mtf.to_int32(mtf.equal(ids, stop_at_token)), - reduced_dim=length_dim) + reduced_dim=image_seq_dim) has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) is_done = mtf.logical_or(is_done, has_additional_eos) all_done = mtf.reduce_all(is_done) @@ -117,9 +122,9 @@ def body_fn(position, ids, *states): context = mtf_transformer.transformer.Context( model=None, - mesh=inputs.mesh, + mesh=image_inputs.mesh, batch_dims=batch_dims, - length_dim=length_dim, + length_dim=image_seq_dim, variable_dtype=variable_dtype, mode="incremental", position=position, @@ -133,7 +138,7 @@ def body_fn(position, ids, *states): model.is_incremental_inference = True model.context = context with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE): - logits = model.forward({'tokens': inputs}, return_loss=False, return_logits=True) + logits = model.forward({'image_inputs': image_inputs}, return_loss=False, return_logits=True) # By default, do top_k sampling of 0.9 if sampling_keep_top_k == -2: @@ -151,10 +156,9 @@ def body_fn(position, ids, *states): # temperature sampling ids_this_step = mtf.sample_with_temperature( logits, model.dimensions['final_vocab_dim'], temperature) - # reshape & assign results - ids_this_step = mtf.reshape(ids_this_step, batch_dims) - one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) + ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) + one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id new_position = position + 1 @@ -162,15 +166,15 @@ def body_fn(position, ids, *states): ret += context.new_states return ret - while_loop_inputs = [initial_position, inputs] + initial_states + while_loop_inputs = [initial_position, image_inputs] + initial_states final_position, outputs = mtf.while_loop( cond_fn, body_fn, while_loop_inputs)[:2] del final_position - if has_partial_sequences and remove_partial_sequences: - # Remove partial sequences from outputs - partial_length = mtf.reduce_sum( - mtf.to_int32(mtf.not_equal(inputs, padding_id)), - reduced_dim=length_dim) - outputs = mtf.dynamic_shift( - outputs, -partial_length, length_dim, wrap=False) + # if has_partial_sequences and remove_partial_sequences: + # # Remove partial sequences from outputs + # partial_length = mtf.reduce_sum( + # mtf.to_int32(mtf.not_equal(image_inputs, padding_id)), + # reduced_dim=image_seq_dim) + # outputs = mtf.dynamic_shift( + # outputs, -partial_length, image_seq_dim, wrap=False) return outputs diff --git a/src/input_fns.py b/src/input_fns.py index b07098a..ea19896 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -46,9 +46,9 @@ def pred_input(params, tokenizer, prompt='a cat in a hat'): tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating " "input.") tokens = tokens[len(tokens) - params["text_seq_len"]:] # TODO: left or right truncate here? - if len(tokens) < params["total_seq_len"]: - tokens = tf.pad(tokens, [[0, params["total_seq_len"] - len(tokens)]], constant_values=params["padding_id"]) - t = tf.broadcast_to(tokens, [params["batch_size"], params["total_seq_len"]]) + if len(tokens) < params["text_seq_len"]: + tokens = tf.pad(tokens, [[0, params["text_seq_len"] - len(tokens)]], constant_values=params["padding_id"]) + t = tf.broadcast_to(tokens, [params["batch_size"], params["text_seq_len"]]) dataset = tf.data.Dataset.from_tensors(t) def _dummy_labels(x): @@ -62,8 +62,8 @@ def pred_output(predictions, out_name='test'): for i, p in enumerate(predictions): denormalize = lambda x: (((x + 1) / 2) * 255.0).astype(np.uint8) # to debug: - # with open(f"{out_name}_{i}.txt", 'w') as f: - # f.write(str(p["outputs"].tolist())) + with open(f"{out_name}_{i}.txt", 'w') as f: + f.write(str(p["outputs"].tolist())) imageio.imwrite(f"{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) diff --git a/src/model_fns.py b/src/model_fns.py index 52cf870..4541150 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -110,47 +110,45 @@ def dalle_model_fn(features, labels, mode, params): params=params, ) - # Build mtf_features & seq length dict for getting number of microbatches - # We need to pack inputs into a dict to pass into serialize_training_step if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: - features_dict = {"image_inputs": features, - "text_inputs": labels} + # Build mtf_features & seq length dict for getting number of microbatches + # We need to pack inputs into a dict to pass into serialize_training_step + features_dict = {"image_inputs": img_tokens_reshaped, + "text_inputs": labels} mtf_features = {} for key, x in features_dict.items(): if x is not None: if key == "text_inputs": - text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) - x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) - mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) - + x = tf.reshape(x, [batch_size, params["text_seq_len"]]) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["text_sequence_dim"]]) mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - if key == "image_inputs": mtf_shape = mtf.Shape([ model.dimensions["batch_dim"], - mtf.Dimension("img_height_dim", vae.H), - mtf.Dimension("img_width_dim", vae.W), - mtf.Dimension("img_channel_dim", vae.num_ch), + model.dimensions["image_sequence_dim"], ]) - x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC - mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - denormalize = lambda x: (x + 1) / 2 - scalar_summary("input_image", denormalize(mtf_features["image_inputs"])) + mtf_features[key] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) else: - features_dict = {"text_inputs": labels} + # Build mtf_features & seq length dict for getting number of microbatches + # We need to pack inputs into a dict to pass into serialize_training_step + features_dict = {"text_inputs": labels, 'image_inputs': 'None'} mtf_features = {} for key, x in features_dict.items(): if x is not None: if key == "text_inputs": - text_tokens = tf.reshape(x, [batch_size, params["total_seq_len"]]) - mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) - mtf_features["tokens"] = mtf.import_fully_replicated(mesh, text_tokens, mtf_shape, name=key) + x = tf.reshape(x, [batch_size, params["text_seq_len"]]) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["text_sequence_dim"]]) + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + mtf_features[key] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + if key == "image_inputs": + mtf_shape = mtf.Shape([ + model.dimensions["batch_dim"], + model.dimensions["image_sequence_dim"], + ]) + mtf_features[key] = mtf.zeros(mesh, mtf_shape, tf.int32) + params['padding_id'] - if mode == tf.estimator.ModeKeys.PREDICT: # Set up the model for prediction - inputs = mtf_features["tokens"] - - mtf_samples = sample_autoregressive(inputs, + mtf_samples = sample_autoregressive(mtf_features, model, params, stop_at_token=model.eos_token_id, @@ -163,22 +161,19 @@ def dalle_model_fn(features, labels, mode, params): ) mtf_samples = mtf.anonymize(mtf_samples) - inputs = mtf.anonymize(inputs) lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=params.get('autostack', True)) - inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) initialize_vae_weights(vae_checkpoint_path) - img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size + outputs -= model.text_vocab_size with tf.variable_scope('vae'): - predictions_decoded = vae.decode(img_outputs) + predictions_decoded = vae.decode(outputs) predictions = { - "inputs": inputs, - "outputs": img_outputs, + "outputs": outputs, "predictions_decoded": predictions_decoded } From 67247cfe84de21bec6cae3d2bf219c34c91c770f Mon Sep 17 00:00:00 2001 From: Leo Gao Date: Tue, 19 Jan 2021 16:59:12 -0700 Subject: [PATCH 14/43] Fix masking --- src/dalle_mtf/models.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 20841cd..31b5f77 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -239,15 +239,20 @@ def get_attn_mask(self, mesh, nd, ns): def get_logits_mask(self, mesh): if not exists(self.logits_mask): + + # mask for image section: 1 means *masked* t = mtf.ones(mesh, mtf.Shape([self.dimensions['text_vocab_dim']]), tf.int32) i = mtf.zeros(mesh, mtf.Shape([self.dimensions['image_vocab_dim']]), tf.int32) eos = mtf.ones(mesh, mtf.Shape([mtf.Dimension(self.dimensions['image_vocab_dim'].name, 1)]), tf.int32) logits_mask = mtf.concat([t,i], self.dimensions['image_vocab_dim'].name) logits_mask = mtf.concat([logits_mask, eos], self.dimensions['image_vocab_dim'].name) new_shape = mtf.Shape([self.dimensions['batch_dim'], self.dimensions['total_seq_dim'], logits_mask.shape.dims[-1]]) + + + logits_mask = mtf.broadcast(logits_mask, new_shape) logits_mask = mtf.cast(mtf.equal(logits_mask, 1), tf.float32) * -1e10 - logits_mask += 1 + self.logits_mask = logits_mask return self.logits_mask @@ -437,7 +442,7 @@ def forward(self, features, return_loss=True, return_logits=False): mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) logits = self.to_logits(out) - logits *= self.get_logits_mask(tokens.mesh) + logits += self.get_logits_mask(tokens.mesh) if not return_loss: logits = mtf.cast(logits, self.variable_dtype.master_dtype) return logits From a0a28284114555a3bc9d8943085a0ee2d9dcec1c Mon Sep 17 00:00:00 2001 From: Leo Gao Date: Tue, 19 Jan 2021 17:10:21 -0700 Subject: [PATCH 15/43] Ignore text tokens in loss computation --- src/dalle_mtf/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 31b5f77..e2c1693 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -365,7 +365,7 @@ def transformer(self, x, mask): def _loss(self, logits, labels): with tf.variable_scope("loss_final"): - loss_batch = self.loss_fn(logits=logits, targets=labels, + loss_batch = self.loss_fn(logits=logits[:, self.text_seq_len:], targets=labels[:, self.text_seq_len:], vocab_dim=logits.shape[-1], z_loss=0.0) with tf.variable_scope("reduce_mean_final"): From ca23b855610d77cf9fdc0301ed03023ddc6dea53 Mon Sep 17 00:00:00 2001 From: Leo Gao Date: Tue, 19 Jan 2021 18:00:29 -0700 Subject: [PATCH 16/43] Fix slicing for mtf --- src/dalle_mtf/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index e2c1693..f12c2b7 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -365,7 +365,8 @@ def transformer(self, x, mask): def _loss(self, logits, labels): with tf.variable_scope("loss_final"): - loss_batch = self.loss_fn(logits=logits[:, self.text_seq_len:], targets=labels[:, self.text_seq_len:], + loss_batch = self.loss_fn(logits =mtf.slice(logits, begin=self.text_seq_len, size=self.image_seq_len, slice_dim_name="sequence_dim"), + targets=mtf.slice(labels, begin=self.text_seq_len, size=self.image_seq_len, slice_dim_name="sequence_dim"), vocab_dim=logits.shape[-1], z_loss=0.0) with tf.variable_scope("reduce_mean_final"): From e126b7995b54685410bc7bfba9f6e16e9efe7234 Mon Sep 17 00:00:00 2001 From: Leo Gao Date: Tue, 19 Jan 2021 20:09:49 -0700 Subject: [PATCH 17/43] Fix sampling Because of the mismatch between overall length and image length, it would end up broadcasting into a weird shape. This fixes that by slicing the text off first. --- src/dalle_mtf/sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 331c90f..c7bf624 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -157,7 +157,8 @@ def body_fn(position, ids, *states): ids_this_step = mtf.sample_with_temperature( logits, model.dimensions['final_vocab_dim'], temperature) # reshape & assign results - ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) + ids_this_step = mtf.reshape(ids_this_step, ([batch_dims, ids_this_step.shape[-1]])) + ids_this_step = mtf.slice(logits, begin=model.text_seq_len, size=model.image_seq_len, slice_dim_name="sequence_dim") one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id From ddeb74dca065149b5e23745c40e0e41aceaac5de Mon Sep 17 00:00:00 2001 From: Leo Gao <54557097+leogao2@users.noreply.github.com> Date: Wed, 20 Jan 2021 09:58:17 -0700 Subject: [PATCH 18/43] Implement incremental logits mask --- src/dalle_mtf/models.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index f12c2b7..b96f656 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -247,14 +247,21 @@ def get_logits_mask(self, mesh): logits_mask = mtf.concat([t,i], self.dimensions['image_vocab_dim'].name) logits_mask = mtf.concat([logits_mask, eos], self.dimensions['image_vocab_dim'].name) new_shape = mtf.Shape([self.dimensions['batch_dim'], self.dimensions['total_seq_dim'], logits_mask.shape.dims[-1]]) + new_shape_incremental = mtf.Shape([self.dimensions['batch_dim'], mtf.Dimension(self.dimensions['total_seq_dim'].name, 1), logits_mask.shape.dims[-1]]) + logits_mask_full = mtf.broadcast(logits_mask, new_shape) + logits_mask_full = mtf.cast(mtf.equal(logits_mask, 1), tf.float32) * -1e10 + + logits_mask_incremental = mtf.broadcast(logits_mask, new_shape_incremental) + logits_mask_incremental = mtf.cast(mtf.equal(logits_mask, 1), tf.float32) * -1e10 - - logits_mask = mtf.broadcast(logits_mask, new_shape) - logits_mask = mtf.cast(mtf.equal(logits_mask, 1), tf.float32) * -1e10 - - self.logits_mask = logits_mask - return self.logits_mask + self.logits_mask = logits_mask_full + self.logits_mask_incremental = logits_mask_incremental + + if self.is_incremental_inference: + return self.logits_mask_incremental + else: + return self.logits_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): if not self.is_incremental_inference: From e7eb45957e3344318f8b0653e048e3957bedfd9e Mon Sep 17 00:00:00 2001 From: Leo Gao <54557097+leogao2@users.noreply.github.com> Date: Wed, 20 Jan 2021 10:04:35 -0700 Subject: [PATCH 19/43] Fix typo --- src/dalle_mtf/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index b96f656..7fb5084 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -250,10 +250,10 @@ def get_logits_mask(self, mesh): new_shape_incremental = mtf.Shape([self.dimensions['batch_dim'], mtf.Dimension(self.dimensions['total_seq_dim'].name, 1), logits_mask.shape.dims[-1]]) logits_mask_full = mtf.broadcast(logits_mask, new_shape) - logits_mask_full = mtf.cast(mtf.equal(logits_mask, 1), tf.float32) * -1e10 + logits_mask_full = mtf.cast(mtf.equal(logits_mask_incremental, 1), tf.float32) * -1e10 logits_mask_incremental = mtf.broadcast(logits_mask, new_shape_incremental) - logits_mask_incremental = mtf.cast(mtf.equal(logits_mask, 1), tf.float32) * -1e10 + logits_mask_incremental = mtf.cast(mtf.equal(logits_mask_incremental, 1), tf.float32) * -1e10 self.logits_mask = logits_mask_full self.logits_mask_incremental = logits_mask_incremental From 29f30067ddf57bff6d729c33eb171482f9667003 Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 21 Jan 2021 01:40:41 +0000 Subject: [PATCH 20/43] revert changes to sample.py --- src/dalle_mtf/sample.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index c7bf624..331c90f 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -157,8 +157,7 @@ def body_fn(position, ids, *states): ids_this_step = mtf.sample_with_temperature( logits, model.dimensions['final_vocab_dim'], temperature) # reshape & assign results - ids_this_step = mtf.reshape(ids_this_step, ([batch_dims, ids_this_step.shape[-1]])) - ids_this_step = mtf.slice(logits, begin=model.text_seq_len, size=model.image_seq_len, slice_dim_name="sequence_dim") + ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id From b69ff7203276fb43cb571ac3619cf62db3a794ff Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 21 Jan 2021 03:15:11 +0000 Subject: [PATCH 21/43] add mask to bias op --- src/dalle_mtf/ops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/dalle_mtf/ops.py b/src/dalle_mtf/ops.py index e62eff6..121b7f3 100644 --- a/src/dalle_mtf/ops.py +++ b/src/dalle_mtf/ops.py @@ -86,3 +86,15 @@ def expand_tile(value, newdim, axis=0): new_shape = value.shape.dims new_shape.insert(axis, newdim) return mtf.broadcast(value, new_shape) # shape.dims gets us a list which we need in order to concat + +def mask_to_bias(visible, dtype): + """Convert a boolean visibility mask to an attention bias. + The returned Tensor has large negative values in positions where + visible=False. + Args: + visible: a boolean Tensor + dtype: a dtype + Returns: + a Tensor with the given dtype and the same shape as "visible" + """ + return mtf.cast(mtf.logical_not(visible), dtype) * -1e9 From 7e3b6ff98b81ff0f0c88b8284a8808a5962e056e Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 21 Jan 2021 03:18:08 +0000 Subject: [PATCH 22/43] update mask. (still not working :( ) --- src/dalle_mtf/models.py | 42 +++++++++++++---------------------------- src/model_fns.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 7fb5084..9e3d25a 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -5,7 +5,7 @@ from collections import defaultdict import math -from .ops import pad, exists, get_variable_dtype, expand_tile +from .ops import pad, exists, get_variable_dtype, expand_tile, mask_to_bias from .layers import gumbel_softmax, mse_loss, norm @@ -140,11 +140,12 @@ def forward(self, features, return_recon_loss=False, return_logits=False, hard_g class DALLE: - def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, + def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train", is_incremental_inference=False, context=None, loss_fn=None, params=None, eos_token_id=None, activation_fn=None): + self.mesh = mesh self.n_embd = n_embd self.text_vocab_size = text_vocab_size self.image_vocab_size = image_vocab_size @@ -156,7 +157,7 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.attn_mask = attn_mask self.logits_mask = None self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS - self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id + self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size), "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size), @@ -237,31 +238,10 @@ def get_attn_mask(self, mesh, nd, ns): self.attn_mask = mtf.cast(mtf.less(i, j), self.variable_dtype.activation_dtype) * -1e10 return self.attn_mask - def get_logits_mask(self, mesh): - if not exists(self.logits_mask): - - # mask for image section: 1 means *masked* - t = mtf.ones(mesh, mtf.Shape([self.dimensions['text_vocab_dim']]), tf.int32) - i = mtf.zeros(mesh, mtf.Shape([self.dimensions['image_vocab_dim']]), tf.int32) - eos = mtf.ones(mesh, mtf.Shape([mtf.Dimension(self.dimensions['image_vocab_dim'].name, 1)]), tf.int32) - logits_mask = mtf.concat([t,i], self.dimensions['image_vocab_dim'].name) - logits_mask = mtf.concat([logits_mask, eos], self.dimensions['image_vocab_dim'].name) - new_shape = mtf.Shape([self.dimensions['batch_dim'], self.dimensions['total_seq_dim'], logits_mask.shape.dims[-1]]) - new_shape_incremental = mtf.Shape([self.dimensions['batch_dim'], mtf.Dimension(self.dimensions['total_seq_dim'].name, 1), logits_mask.shape.dims[-1]]) - - logits_mask_full = mtf.broadcast(logits_mask, new_shape) - logits_mask_full = mtf.cast(mtf.equal(logits_mask_incremental, 1), tf.float32) * -1e10 - - logits_mask_incremental = mtf.broadcast(logits_mask, new_shape_incremental) - logits_mask_incremental = mtf.cast(mtf.equal(logits_mask_incremental, 1), tf.float32) * -1e10 - - self.logits_mask = logits_mask_full - self.logits_mask_incremental = logits_mask_incremental - - if self.is_incremental_inference: - return self.logits_mask_incremental - else: - return self.logits_mask + def set_logits_mask(self, tf_mask): + mask_shape = mtf.Shape([self.dimensions['total_seq_dim'], self.dimensions['final_vocab_dim']]) + mtf_mask = mask_to_bias(mtf.import_fully_replicated(self.mesh, tf_mask, mask_shape), tf.float32) + self.logits_mask = mtf_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): if not self.is_incremental_inference: @@ -423,6 +403,7 @@ def to_logits(self, x): logits = expand_tile(logits, mtf.Dimension("sequence_dim", 1), axis=1) return mtf.cast(logits, tf.float32) + def shift_labels(self, labels): print(labels.shape) labels = pad(labels, [0, 1], dim_name="sequence_dim", pad_value=self.eos_token_id) @@ -431,6 +412,7 @@ def shift_labels(self, labels): labels = mtf.rename_dimension(labels, "range", "sequence_dim") return labels + def forward(self, features, return_loss=True, return_logits=False): if features.get('text_inputs') is not None: text = features["text_inputs"] @@ -450,7 +432,9 @@ def forward(self, features, return_loss=True, return_logits=False): mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) logits = self.to_logits(out) - logits += self.get_logits_mask(tokens.mesh) + if not self.is_incremental_inference: + logits += mtf.cast(self.logits_mask, logits.dtype) + if not return_loss: logits = mtf.cast(logits, self.variable_dtype.master_dtype) return logits diff --git a/src/model_fns.py b/src/model_fns.py index 4541150..a77fce1 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -6,8 +6,15 @@ from .utils import mode_to_str, get_graph_info, create_host_call, simd_mesh_setup, scalar_summary from .dalle_mtf import DALLE, sample_autoregressive from .vae_tf import DiscreteVAE +from .dalle_mtf.ops import mask_to_bias from tensorflow.python.ops import resources +def get_tf_logits_mask(text_vocab_size, total_vocab_size, text_seq_len, image_seq_len): + mask_inp = [text_vocab_size for _ in range(text_seq_len)] + x = tf.logical_not(tf.sequence_mask(mask_inp, total_vocab_size)) + mask_inp = [text_vocab_size for _ in range(image_seq_len)] + y = tf.sequence_mask(mask_inp, total_vocab_size) + return tf.concat([x, y], 0) def initialize_vae_weights(checkpoint_path, scope="vae"): """ @@ -97,6 +104,7 @@ def dalle_model_fn(features, labels, mode, params): mesh = mtf.Mesh(graph, "my_mesh", var_placer) model = DALLE( + mesh=mesh, n_embd=params["n_embd"], text_vocab_size=params["text_vocab_size"], image_vocab_size=params["image_vocab_size"], @@ -110,6 +118,10 @@ def dalle_model_fn(features, labels, mode, params): params=params, ) + tf_logits_mask = get_tf_logits_mask(params["text_vocab_size"], params["text_vocab_size"]+params["image_vocab_size"], + params["text_seq_len"], params['image_seq_len']) + model.set_logits_mask(tf_logits_mask) + if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: # Build mtf_features & seq length dict for getting number of microbatches # We need to pack inputs into a dict to pass into serialize_training_step From 34d53269afa48fbac981dfb336f7323a6e0fd54d Mon Sep 17 00:00:00 2001 From: connor Date: Sun, 31 Jan 2021 02:33:59 +0000 Subject: [PATCH 23/43] mask changes --- src/dalle_mtf/models.py | 13 +++++++++---- src/model_fns.py | 22 +++++++++++++++------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 9e3d25a..4a7cd37 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -240,7 +240,9 @@ def get_attn_mask(self, mesh, nd, ns): def set_logits_mask(self, tf_mask): mask_shape = mtf.Shape([self.dimensions['total_seq_dim'], self.dimensions['final_vocab_dim']]) - mtf_mask = mask_to_bias(mtf.import_fully_replicated(self.mesh, tf_mask, mask_shape), tf.float32) + mtf_mask = mtf.import_fully_replicated(self.mesh, tf_mask, mask_shape) + new_shape = mtf.Shape([self.dimensions['batch_dim'], self.dimensions['total_seq_dim'], self.dimensions['final_vocab_dim']]) + mtf_mask = mtf.broadcast(mtf_mask, new_shape) self.logits_mask = mtf_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): @@ -405,7 +407,6 @@ def to_logits(self, x): def shift_labels(self, labels): - print(labels.shape) labels = pad(labels, [0, 1], dim_name="sequence_dim", pad_value=self.eos_token_id) indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1 labels = mtf.gather(labels, indices, dim=labels.shape[1]) @@ -432,8 +433,12 @@ def forward(self, features, return_loss=True, return_logits=False): mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) logits = self.to_logits(out) - if not self.is_incremental_inference: - logits += mtf.cast(self.logits_mask, logits.dtype) + if self.is_incremental_inference: + logits_mask = mtf.gather(self.logits_mask, self.context.position + self.text_seq_len - 1, self.logits_mask.shape[1]) + logits_mask = expand_tile(logits_mask, mtf.Dimension("sequence_dim", 1), axis=1) + else: + logits_mask = self.logits_mask + logits += mtf.cast(logits_mask, logits.dtype) if not return_loss: logits = mtf.cast(logits, self.variable_dtype.master_dtype) diff --git a/src/model_fns.py b/src/model_fns.py index a77fce1..590fd90 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -8,13 +8,21 @@ from .vae_tf import DiscreteVAE from .dalle_mtf.ops import mask_to_bias from tensorflow.python.ops import resources +import numpy as np -def get_tf_logits_mask(text_vocab_size, total_vocab_size, text_seq_len, image_seq_len): - mask_inp = [text_vocab_size for _ in range(text_seq_len)] - x = tf.logical_not(tf.sequence_mask(mask_inp, total_vocab_size)) - mask_inp = [text_vocab_size for _ in range(image_seq_len)] - y = tf.sequence_mask(mask_inp, total_vocab_size) - return tf.concat([x, y], 0) +def get_tf_logits_mask(num_text_tokens, total_tokens, text_seq_len, image_seq_len): + seq_len = text_seq_len + image_seq_len + + seq_range = np.arange(seq_len).reshape((1,-1,1)) + logits_range = np.arange(total_tokens).reshape((1,1,-1)) + + logits_mask = (((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) | + ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) | + ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1))) ) + logits_mask = np.squeeze(logits_mask) + logits_mask = logits_mask * -1e9 + logits_mask = logits_mask.astype(np.float32) + return tf.constant(logits_mask) def initialize_vae_weights(checkpoint_path, scope="vae"): """ @@ -118,7 +126,7 @@ def dalle_model_fn(features, labels, mode, params): params=params, ) - tf_logits_mask = get_tf_logits_mask(params["text_vocab_size"], params["text_vocab_size"]+params["image_vocab_size"], + tf_logits_mask = get_tf_logits_mask(params["text_vocab_size"], model.total_tokens, params["text_seq_len"], params['image_seq_len']) model.set_logits_mask(tf_logits_mask) From 05dec2685e6521ed8219517273a0730b24604941 Mon Sep 17 00:00:00 2001 From: connor Date: Mon, 1 Feb 2021 11:16:13 +0000 Subject: [PATCH 24/43] fix label shifting --- src/dalle_mtf/models.py | 13 +++++-------- src/input_fns.py | 11 +++++------ src/model_fns.py | 8 ++++---- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 4a7cd37..b0265f1 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -151,7 +151,7 @@ def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, te self.image_vocab_size = image_vocab_size self.text_seq_len = text_seq_len self.image_seq_len = image_seq_len - self.total_seq_dim = text_seq_len + image_seq_len + self.total_seq_len = text_seq_len + image_seq_len self.n_layers = n_layers self.n_heads = n_heads self.attn_mask = attn_mask @@ -164,9 +164,9 @@ def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, te "final_vocab_dim": mtf.Dimension("vocab_dim", self.total_tokens), "text_sequence_dim": mtf.Dimension("sequence_dim", text_seq_len), "image_sequence_dim": mtf.Dimension("sequence_dim", image_seq_len), - "total_seq_dim": mtf.Dimension("sequence_dim", self.total_seq_dim), - "embed_seq_dim": mtf.Dimension("embed_seq_dim", self.total_seq_dim), - "memory_len_dim": mtf.Dimension("memory_len_dim", self.total_seq_dim), + "total_seq_dim": mtf.Dimension("sequence_dim", self.total_seq_len), + "embed_seq_dim": mtf.Dimension("embed_seq_dim", self.total_seq_len), + "memory_len_dim": mtf.Dimension("memory_len_dim", self.total_seq_len), "heads_dim": mtf.Dimension("heads", n_heads), "kv_dim": mtf.Dimension("kv_dim", n_embd // n_heads), "batch_dim": mtf.Dimension("batch_dim", batch_size)} @@ -248,7 +248,6 @@ def set_logits_mask(self, tf_mask): def attention(self, x, n_state, mask, attention_type="global", name="attn"): if not self.is_incremental_inference: # x :: [batch, seq, n_embd] - print(x.shape) batch_dim, seq_dim, embd_dim = x_shape = x.shape else: batch_dim, embd_dim = x_shape = x.shape @@ -408,9 +407,7 @@ def to_logits(self, x): def shift_labels(self, labels): labels = pad(labels, [0, 1], dim_name="sequence_dim", pad_value=self.eos_token_id) - indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1 - labels = mtf.gather(labels, indices, dim=labels.shape[1]) - labels = mtf.rename_dimension(labels, "range", "sequence_dim") + labels = mtf.slice(labels, 1, self.total_seq_len, "sequence_dim") return labels diff --git a/src/input_fns.py b/src/input_fns.py index ea19896..b1c07a8 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -1,7 +1,7 @@ import imageio import numpy as np import tensorflow.compat.v1 as tf - +import os def crop_center_and_resize(img, size): s = tf.shape(img) @@ -58,13 +58,12 @@ def _dummy_labels(x): return dataset -def pred_output(predictions, out_name='test'): +def pred_output(predictions, out_name='test', output_dir='outputs'): + if not os.path.isdir(output_dir): + os.makedirs(output_dir) for i, p in enumerate(predictions): denormalize = lambda x: (((x + 1) / 2) * 255.0).astype(np.uint8) - # to debug: - with open(f"{out_name}_{i}.txt", 'w') as f: - f.write(str(p["outputs"].tolist())) - imageio.imwrite(f"{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) + imageio.imwrite(f"outputs/{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) def read_labeled_tfrecord(params): diff --git a/src/model_fns.py b/src/model_fns.py index 590fd90..20d2074 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -1,5 +1,6 @@ import mesh_tensorflow as mtf import tensorflow.compat.v1 as tf + from tensorflow.python.tpu import tpu_estimator import mesh_tensorflow.transformer as mtf_transformer from .optimizers import get_optimizer @@ -177,7 +178,7 @@ def dalle_model_fn(features, labels, mode, params): variable_dtype=model.variable_dtype, has_partial_sequences=True, remove_partial_sequences=True, - sampling_keep_top_k=-1, + sampling_keep_top_k=-2, ) mtf_samples = mtf.anonymize(mtf_samples) @@ -188,7 +189,6 @@ def dalle_model_fn(features, labels, mode, params): initialize_vae_weights(vae_checkpoint_path) outputs -= model.text_vocab_size - with tf.variable_scope('vae'): predictions_decoded = vae.decode(outputs) @@ -196,7 +196,7 @@ def dalle_model_fn(features, labels, mode, params): "outputs": outputs, "predictions_decoded": predictions_decoded } - + denormalize = lambda x: (((x + 1) / 2) * 255.0) def scaffold_fn(): return tf.train.Scaffold( local_init_op=tf.group( @@ -222,7 +222,7 @@ def scaffold_fn(): # Gets number of microbatches per batch for serialized training # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=model.dimensions["batch_dim"], - sequence_length=model.total_seq_dim, + sequence_length=model.total_seq_len, mesh_shape=mesh_shape, layout_rules=layout_rules, tokens_per_microbatch_per_replica= From 6c39bdf01d04fc3e93cc531633f1ddde0f2ddd6a Mon Sep 17 00:00:00 2001 From: connor Date: Mon, 1 Feb 2021 11:16:34 +0000 Subject: [PATCH 25/43] add weight decay Adam --- src/optimizers.py | 166 +++++++++++++++++++++++----------------------- 1 file changed, 83 insertions(+), 83 deletions(-) diff --git a/src/optimizers.py b/src/optimizers.py index 7f77c04..42a627c 100644 --- a/src/optimizers.py +++ b/src/optimizers.py @@ -79,7 +79,7 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): scalar_summary("lr", learning_rate) if optimizer_name.lower() == "adam": - optimizer = mtf.optimize.AdamWeightDecayOptimizer( + optimizer = AdamWeightDecayOptimizer( learning_rate=learning_rate, weight_decay_rate=params.get("weight_decay", 0.0), beta_1=params.get("beta_1", 0.9), @@ -104,85 +104,85 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): return learning_rate, update_ops, var_grads_fp -# class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): -# """A basic Adam optimizer that includes "correct" L2 weight decay.""" - -# def __init__(self, -# learning_rate, -# weight_decay_rate=0.0, -# beta_1=0.9, -# beta_2=0.999, -# epsilon=1e-6, -# exclude_from_weight_decay=None, -# variable_dtype=None): -# """Constructs a AdamWeightDecayOptimizer.""" - -# self.learning_rate = learning_rate -# self.weight_decay_rate = weight_decay_rate -# self.beta_1 = beta_1 -# self.beta_2 = beta_2 -# self.epsilon = epsilon -# self.exclude_from_weight_decay = exclude_from_weight_decay -# self.variable_dtype = variable_dtype - -# def apply_grad(self, grad, var): -# """See base class.""" -# if grad is None: -# tf.logging.warning("Gradient is None for variable %s" % var.name) -# return [] - -# grad = mtf.to_float(grad) - -# assignments = [] - -# m = mtf.get_variable( -# var.mesh, var.name + "/adam_m", var.shape, -# initializer=tf.zeros_initializer(), -# # master_dtype=self.variable_dtype.master_dtype, -# # slice_dtype=self.variable_dtype.slice_dtype, -# # activation_dtype=self.variable_dtype.activation_dtype, -# trainable=False) - -# v = mtf.get_variable( -# var.mesh, var.name + "/adam_v", var.shape, -# initializer=tf.zeros_initializer(), -# # master_dtype=self.variable_dtype.master_dtype, -# # slice_dtype=self.variable_dtype.slice_dtype, -# # activation_dtype=self.variable_dtype.activation_dtype, -# trainable=False) - -# # Standard Adam update. -# next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad -# next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) - -# update = next_m / (mtf.sqrt(next_v) + self.epsilon) - -# # Just adding the square of the weights to the loss function is *not* -# # the correct way of using L2 regularization/weight decay with Adam, -# # since that will interact with the m and v parameters in strange ways. -# # -# # Instead we want to decay the weights in a manner that doesn't interact -# # with the m/v parameters. This is equivalent to adding the square -# # of the weights to the loss with plain (non-momentum) SGD. -# if self._do_use_weight_decay(var.name): -# update += mtf.to_float(var.value) * self.weight_decay_rate - -# update_with_lr = self.learning_rate * update - -# var_update = mtf.assign_sub(var, update_with_lr) - -# assignments.extend( -# [var_update, -# mtf.assign(m, next_m), -# mtf.assign(v, next_v)]) -# return assignments - -# def _do_use_weight_decay(self, param_name): -# """Whether to use L2 weight decay for `param_name`.""" -# if not self.weight_decay_rate: -# return False -# if self.exclude_from_weight_decay: -# for r in self.exclude_from_weight_decay: -# if re.search(r, param_name) is not None: -# return False -# return True +class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): + """A basic Adam optimizer that includes "correct" L2 weight decay.""" + + def __init__(self, + learning_rate, + weight_decay_rate=0.0, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=None, + variable_dtype=None): + """Constructs a AdamWeightDecayOptimizer.""" + + self.learning_rate = learning_rate + self.weight_decay_rate = weight_decay_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.exclude_from_weight_decay = exclude_from_weight_decay + self.variable_dtype = variable_dtype + + def apply_grad(self, grad, var): + """See base class.""" + if grad is None: + tf.logging.warning("Gradient is None for variable %s" % var.name) + return [] + + grad = mtf.to_float(grad) + + assignments = [] + + m = mtf.get_variable( + var.mesh, var.name + "/adam_m", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + v = mtf.get_variable( + var.mesh, var.name + "/adam_v", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + # Standard Adam update. + next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad + next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) + + update = next_m / (mtf.sqrt(next_v) + self.epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + if self._do_use_weight_decay(var.name): + update += mtf.to_float(var.value) * self.weight_decay_rate + + update_with_lr = self.learning_rate * update + + var_update = mtf.assign_sub(var, update_with_lr) + + assignments.extend( + [var_update, + mtf.assign(m, next_m), + mtf.assign(v, next_v)]) + return assignments + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay_rate: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True From 69d7ef9780731f56af5e0bb2a38b5ed71a2528c3 Mon Sep 17 00:00:00 2001 From: connor Date: Mon, 1 Feb 2021 11:16:47 +0000 Subject: [PATCH 26/43] add eval steps --- train_dalle.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/train_dalle.py b/train_dalle.py index 8ad9a8e..1a3ae10 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -19,7 +19,7 @@ def parse_args(): parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " "starts a new training run") parser.add_argument('--predict', action='store_true', help='run model in predict mode') - parser.add_argument('--prompt', type=str, default='a cat in a hat') + parser.add_argument('--prompt', type=str, default='face') args = parser.parse_args() assert args.model is not None, "Model must be set" return args @@ -93,14 +93,15 @@ def main(): if has_predict_or_eval_steps: # Eval and train - stop and predict and/or eval every checkpoint while current_step < params["train_steps"]: - next_checkpoint = min(current_step + args.steps_per_checkpoint, params["train_steps"]) + next_checkpoint = min(current_step + params["steps_per_checkpoint"], params["train_steps"]) estimator.train(input_fn=partial(dalle_input_fn, eval=False), max_steps=next_checkpoint) current_step = next_checkpoint if params["predict_steps"] > 0: raise NotImplementedError if params["eval_steps"] > 0: - raise NotImplementedError + estimator.evaluate(input_fn=partial(dalle_input_fn, eval=True), + steps=params["eval_steps"]) return else: # Else, just train @@ -110,6 +111,7 @@ def main(): max_steps=params["train_steps"]) + if __name__ == "__main__": tf.disable_v2_behavior() main() From 2ec46beba0c43cd262fee2aff787287c1016d6e5 Mon Sep 17 00:00:00 2001 From: sid Date: Sun, 4 Apr 2021 22:36:20 +0200 Subject: [PATCH 27/43] add slow sampling --- src/dalle_mtf/sample.py | 63 +++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 331c90f..1022f38 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -13,6 +13,7 @@ def sample_autoregressive(inputs, has_partial_sequences=True, remove_partial_sequences=False, sampling_keep_top_k=-1, + cached=True ): """Sample randomly one token at a time. @@ -59,37 +60,38 @@ def sample_autoregressive(inputs, initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(image_inputs, padding_id)), reduced_dim=image_seq_dim) - # initial_position += model.dimensions['text_seq_dim'].size length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32) # Builds context to pass around internally # The 'first part' context records initial states of k / v / x - - context_first_part = mtf_transformer.transformer.Context( - model=None, - mesh=image_inputs.mesh, - batch_dims=batch_dims, - length_dim=image_seq_dim, - variable_dtype=variable_dtype, - mode="first_part", - position=length_range, - position_is_default=True, - new_states=[], - initial_position=initial_position, - sequence_id=None, - constant_states=[], - inputs=inputs) - model.context = context_first_part - - with tf.variable_scope('dall-e'): - logits = model.forward(inputs, return_loss=False, return_logits=True) - del logits - - if not has_partial_sequences: - initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] + if cached: + context_first_part = mtf_transformer.transformer.Context( + model=None, + mesh=image_inputs.mesh, + batch_dims=batch_dims, + length_dim=image_seq_dim, + variable_dtype=variable_dtype, + mode="first_part", + position=length_range, + position_is_default=True, + new_states=[], + initial_position=initial_position, + sequence_id=None, + constant_states=[], + inputs=inputs) + model.context = context_first_part + + with tf.variable_scope('dall-e'): + logits = model.forward(inputs, return_loss=False, return_logits=True) + del logits + + if not has_partial_sequences: + initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] + else: + initial_states = context_first_part.new_states else: - initial_states = context_first_part.new_states + initial_states = [] if not has_partial_sequences: partial_sequences_eos_count = 0 @@ -133,9 +135,9 @@ def body_fn(position, ids, *states): new_states=[], initial_position=position, sequence_id=None, - inputs=ids) + inputs=ids) if cached else None - model.is_incremental_inference = True + model.is_incremental_inference = True if cached else False model.context = context with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE): logits = model.forward({'image_inputs': image_inputs}, return_loss=False, return_logits=True) @@ -156,8 +158,13 @@ def body_fn(position, ids, *states): # temperature sampling ids_this_step = mtf.sample_with_temperature( logits, model.dimensions['final_vocab_dim'], temperature) + # reshape & assign results - ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) + if cached: + ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) + else: + ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) + one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id From 321b718dd0b08814abdd944b521907e954c935fb Mon Sep 17 00:00:00 2001 From: sid Date: Sun, 4 Apr 2021 22:37:16 +0200 Subject: [PATCH 28/43] add slow sampling --- src/dalle_mtf/sample.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 1022f38..c39a7eb 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -59,7 +59,7 @@ def sample_autoregressive(inputs, # Gets position (in image inputs) where zero padding starts initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(image_inputs, padding_id)), - reduced_dim=image_seq_dim) + reduced_dim=image_seq_dim) length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32) @@ -163,7 +163,9 @@ def body_fn(position, ids, *states): if cached: ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) else: + print('*' * 100, '\nIDS THIS STEP SLOW') ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) + print('*' * 100) one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot From d564853882060c9ff2cd4c8500720697104d8f4d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 14:13:30 -0700 Subject: [PATCH 29/43] add tests --- .github/workflows/tests.yml | 33 +++++++++++++ src/dalle_mtf/models.py | 6 --- src/dalle_mtf/sample.py | 12 +++-- test.py | 95 +++++++++++++++++++++++++++++++++++++ 4 files changed, 137 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 test.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..d83b8aa --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,33 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Test with pytest + run: | + pytest -s test.py diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index b0265f1..c879cf5 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -430,12 +430,6 @@ def forward(self, features, return_loss=True, return_logits=False): mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) logits = self.to_logits(out) - if self.is_incremental_inference: - logits_mask = mtf.gather(self.logits_mask, self.context.position + self.text_seq_len - 1, self.logits_mask.shape[1]) - logits_mask = expand_tile(logits_mask, mtf.Dimension("sequence_dim", 1), axis=1) - else: - logits_mask = self.logits_mask - logits += mtf.cast(logits_mask, logits.dtype) if not return_loss: logits = mtf.cast(logits, self.variable_dtype.master_dtype) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 331c90f..9ddf6f6 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -5,14 +5,15 @@ def sample_autoregressive(inputs, model, - params, stop_at_token=50256, max_steps=None, temperature=0.9, + padding_id = 0, variable_dtype=mtf.VariableDType(tf.float32), has_partial_sequences=True, remove_partial_sequences=False, sampling_keep_top_k=-1, + min_start_pos=None ): """Sample randomly one token at a time. @@ -28,7 +29,6 @@ def sample_autoregressive(inputs, Args: inputs: an input dictionary containing 'text_inputs' and 'image_inputs', model: DALL-E model - params: model paramers. stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer, the max number of steps to decode. temperature: an optional floating point value between 0.0 and 1.0 0.0 @@ -51,7 +51,7 @@ def sample_autoregressive(inputs, batch_dims = model.dimensions["batch_dim"] length_dim = model.dimensions["total_seq_dim"] image_seq_dim = model.dimensions['image_sequence_dim'] - padding_id = params.get("padding_id", 0) + image_inputs = inputs['image_inputs'] text_inputs = inputs['text_inputs'] @@ -59,6 +59,12 @@ def sample_autoregressive(inputs, initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(image_inputs, padding_id)), reduced_dim=image_seq_dim) + + if min_start_pos is not None: + # force the sampling to never start below a minimum starting position, say the text length. + # this will also be useful for image completion, where you can start sampling from half the image tokens + initial_position = mtf.maximum(initial_position, min_start_pos) + # initial_position += model.dimensions['text_seq_dim'].size length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32) diff --git a/test.py b/test.py new file mode 100644 index 0000000..943d770 --- /dev/null +++ b/test.py @@ -0,0 +1,95 @@ +import pytest +import traceback +import logging +from collections import defaultdict +from contextlib import contextmanager + +import tensorflow as tf +tf.compat.v1.enable_eager_execution() +import mesh_tensorflow as mtf +from mesh_tensorflow import placement_mesh_impl + +from src.dalle_mtf.models import DALLE +from src.dalle_mtf.sample import sample_autoregressive + +# helper functions + +@contextmanager +def not_raises(exception): + try: + yield + except exception: + logging.error(traceback.format_exc()) + raise pytest.fail("DID RAISE {0}".format(exception)) + +# tests + +def test_model(): + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "my_mesh") + + model = DALLE( + mesh = mesh, + batch_size = 1, + n_embd = 16, + n_heads = 2, + bf_16 = False + ) + + batch_dim = model.dimensions["batch_dim"] + sequence_dim = model.dimensions["total_seq_dim"] + + text_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + image_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + + features = { + 'text_inputs': mtf.slice(text_inputs, 0, model.text_seq_len, sequence_dim.name), + 'image_inputs': mtf.slice(image_inputs, 0, model.image_seq_len, sequence_dim.name) + } + + with not_raises(Exception): + loss, loss_batch, logits = model.forward(features, return_loss = True, return_logits = True) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + logits = lowering.export_to_tf_tensor(logits) + +def test_sampling(): + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "my_mesh") + + model = DALLE( + mesh = mesh, + batch_size = 1, + text_seq_len = 1, + image_seq_len = 4, + n_embd = 16, + n_heads = 2, + bf_16 = False + ) + + batch_dim = model.dimensions["batch_dim"] + sequence_dim = model.dimensions["total_seq_dim"] + + text_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + image_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + + inputs = { + 'text_inputs': mtf.slice(text_inputs, 0, model.text_seq_len, sequence_dim.name), + 'image_inputs': mtf.slice(image_inputs, 0, model.image_seq_len, sequence_dim.name) + } + + with not_raises(Exception): + samples = sample_autoregressive( + inputs, + model, + variable_dtype=mtf.VariableDType(), + max_steps = sequence_dim.size, + remove_partial_sequences=False, + stop_at_token=None, + min_start_pos=model.text_seq_len + ) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + samples = lowering.export_to_tf_tensor(samples) \ No newline at end of file From a0d01b99171f01221bd4a1ffadd75c1d4b92140d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 14:41:22 -0700 Subject: [PATCH 30/43] switch to using instead of --- src/dalle_mtf/models.py | 104 +++++++++++++++++++++++++++++++++++----- src/model_fns.py | 4 +- 2 files changed, 95 insertions(+), 13 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index c879cf5..8576945 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -142,7 +142,7 @@ class DALLE: def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train", - is_incremental_inference=False, context=None, loss_fn=None, params=None, eos_token_id=None, + is_incremental_inference=False, context=None, loss_fn=None, params=None, padding_id=None, activation_fn=None): self.mesh = mesh @@ -157,7 +157,7 @@ def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, te self.attn_mask = attn_mask self.logits_mask = None self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS - self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id + self.padding_id = 0 if padding_id is None else padding_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size), "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size), @@ -377,6 +377,58 @@ def linear(self, x, new_dim, w_init_stdev=0.02, params=None, scale=False, name=" kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev), variable_dtype=self.variable_dtype) + def axial_positional_embedding(self, mesh, name): + with tf.variable_scope(name): + axial_dim_side = int(sqrt(self.image_seq_len)) + + embd_dim = self.dimensions["embed_dim"] + axial_dim = mtf.Dimension("axial_dim", self.image_seq_len) + + dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_side, axial_dim_side))] + + axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + + axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + + axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]), + (axial_wpe_1, axial_wpe_2)) + wpe = (axial_wpe_1 + axial_wpe_2) / 2 + + wpe = mtf.reshape(wpe, [axial_dim, embd_dim]) + wpe = pad(wpe, [self.text_seq_len, 0], axial_dim.name) + wpe = mtf.replace_dimensions(wpe, wpe.shape[0], self.dimensions["embed_seq_dim"]) + return wpe + + + def absolute_positional_embedding(self, mesh, name): + with tf.variable_scope(name): + # Positional embedding + wpe = mtf.get_variable(mesh, "wpe", + mtf.Shape([self.dimensions["embed_seq_dim"], self.dimensions["embed_dim"]]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + return wpe + + def apply_positional_embedding(self, x, wpe): + position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \ + self.is_incremental_inference else (self.context.position - 1) + pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0]) + embed_dropout = self.params.get("embed_dropout", 0) + if embed_dropout > 0 and self.mode == "train": + pos_emb = mtf.dropout(pos_emb, rate=embed_dropout, name="wte_dropout") + x += pos_emb + return x + def layer_norm(self, x, name="layer_norm", axis=None, epsilon=1e-5): """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" if axis is None: @@ -404,31 +456,61 @@ def to_logits(self, x): logits = expand_tile(logits, mtf.Dimension("sequence_dim", 1), axis=1) return mtf.cast(logits, tf.float32) - - def shift_labels(self, labels): - labels = pad(labels, [0, 1], dim_name="sequence_dim", pad_value=self.eos_token_id) - labels = mtf.slice(labels, 1, self.total_seq_len, "sequence_dim") - return labels + def to_image_logits(self, x): + with tf.variable_scope("to_logits"): + if not self.is_incremental_inference: + x = mtf.slice(x, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = x.shape[1].name) + + image_logits = self.linear(x, self.dimensions["image_vocab_dim"], name="linear_image_out") + + # Go to full precision for the logits + image_logits = mtf.cast(image_logits, tf.float32) + return image_logits + def to_text_logits(self, x): + with tf.variable_scope("to_logits"): + text_tokens = mtf.slice(x, begin = 0, size = self.text_seq_len, slice_dim_name = x.shape[1].name) + text_logits = self.linear(text_tokens, self.dimensions["text_vocab_dim"], name="linear_text_out") + + # Go to full precision for the logits + text_logits = mtf.cast(text_logits, tf.float32) + return text_logits def forward(self, features, return_loss=True, return_logits=False): if features.get('text_inputs') is not None: text = features["text_inputs"] - text_emb = self.positional_embedding(self.embedding(text, "text_embd"), "text_pos_emb") + text_with_bos = pad(text, [1, 0], dim_name = text.shape[1].name, pad_value = self.padding_id) + text_emb = self.embedding(text_with_bos, "text_embd") else: assert self.is_incremental_inference + image = features.get("image_inputs", None) + if not self.is_incremental_inference: - image_emb = self.positional_embedding(self.embedding(image, "image_embd"), "image_pos_emb") + image_input = mtf.slice(image, 0, self.image_seq_len - 1, image.shape[1].name) + image_emb = self.embedding(image_input, "image_embd") tokens = mtf.concat([text_emb, image_emb], concat_dim_name="sequence_dim") # [batch, seq, n_embd] else: # reshape inputs if in inference mode image = mtf.gather(image, self.context.position - 1, self.dimensions["image_sequence_dim"]) image = mtf.reshape(image, [self.dimensions["batch_dim"]]) - tokens = self.positional_embedding(self.embedding(image, "image_embd"), "image_pos_emb") + tokens = self.embedding(image, "image_embd") + + # positional embedding + + abs_pos_emb = self.absolute_positional_embedding(tokens.mesh, "positional_embedding") + axial_pos_emb = self.axial_positional_embedding(tokens.mesh, "axial_positional_embedding") + + tokens = self.apply_positional_embedding(tokens, abs_pos_emb) + tokens = self.apply_positional_embedding(tokens, axial_pos_emb) + + # attention mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) + + # to logits + logits = self.to_logits(out) if not return_loss: @@ -438,7 +520,7 @@ def forward(self, features, return_loss=True, return_logits=False): assert exists(image), 'when training, image must be supplied' offset_image = image + self.text_vocab_size labels = mtf.concat([text, offset_image], concat_dim_name="sequence_dim") - labels = self.shift_labels(labels) + loss, loss_batch = self._loss(logits, labels) if return_logits and return_loss: # Cast back to checkpoint dtype diff --git a/src/model_fns.py b/src/model_fns.py index 20d2074..da6e233 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -172,8 +172,8 @@ def dalle_model_fn(features, labels, mode, params): mtf_samples = sample_autoregressive(mtf_features, model, params, - stop_at_token=model.eos_token_id, - max_steps=None, + stop_at_token=None, + max_steps=model.total_seq_len, temperature=0.9, variable_dtype=model.variable_dtype, has_partial_sequences=True, From 58c4d4e7a65f0e72920e9312f46bb254736013f5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 14:48:59 -0700 Subject: [PATCH 31/43] separate logits for text and images --- src/dalle_mtf/models.py | 37 +++++++++++++++++++++++++++++++------ src/dalle_mtf/sample.py | 4 ++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 8576945..0c74131 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -143,7 +143,7 @@ class DALLE: def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train", is_incremental_inference=False, context=None, loss_fn=None, params=None, padding_id=None, - activation_fn=None): + activation_fn=None, text_loss_weight=0.15): self.mesh = mesh self.n_embd = n_embd @@ -156,6 +156,7 @@ def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, te self.n_heads = n_heads self.attn_mask = attn_mask self.logits_mask = None + self.text_loss_weight = text_loss_weight self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS self.padding_id = 0 if padding_id is None else padding_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), @@ -475,7 +476,25 @@ def to_text_logits(self, x): # Go to full precision for the logits text_logits = mtf.cast(text_logits, tf.float32) return text_logits - + + def _loss(self, text_logits, image_logits, text_labels, image_labels): + with tf.variable_scope("loss_final"): + text_loss_batch = self.loss_fn(logits=text_logits, targets=text_labels, + vocab_dim=text_logits.shape[-1], z_loss=0.0) + + image_loss_batch = self.loss_fn(logits=image_logits, targets=image_labels, + vocab_dim=image_logits.shape[-1], z_loss=0.0) + + loss_batch = text_loss_batch * self.text_loss_weight + image_loss_batch + + with tf.variable_scope("reduce_mean_final"): + loss = mtf.reduce_mean(loss_batch) + + loss /= self.params.get("num_microbatches", 1) + # Convert to train dtype + loss = mtf.cast(loss, self.variable_dtype.slice_dtype) + return loss, loss_batch # loss batch must be returned for metric fns + def forward(self, features, return_loss=True, return_logits=False): if features.get('text_inputs') is not None: text = features["text_inputs"] @@ -511,19 +530,25 @@ def forward(self, features, return_loss=True, return_logits=False): # to logits - logits = self.to_logits(out) + image_logits = self.to_image_logits(out) if not return_loss: - logits = mtf.cast(logits, self.variable_dtype.master_dtype) + logits = mtf.cast(image_logits, self.variable_dtype.master_dtype) return logits assert exists(image), 'when training, image must be supplied' offset_image = image + self.text_vocab_size labels = mtf.concat([text, offset_image], concat_dim_name="sequence_dim") - loss, loss_batch = self._loss(logits, labels) + text_logits = self.to_text_logits(out) + + text_labels = mtf.slice(labels, begin = 0, size = self.text_seq_len, slice_dim_name = labels.shape[1].name) + image_labels = mtf.slice(labels, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = labels.shape[1].name) + + loss, loss_batch = self._loss(text_logits, image_logits, text_labels, image_labels) + if return_logits and return_loss: # Cast back to checkpoint dtype - logits = mtf.cast(logits, self.variable_dtype.master_dtype) + logits = mtf.cast(image_logits, self.variable_dtype.master_dtype) return loss, loss_batch, logits return loss, loss_batch diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 9ddf6f6..21d272e 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -155,13 +155,13 @@ def body_fn(position, ids, *states): raise ValueError("sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, - reduced_dim=model.dimensions['final_vocab_dim']) + reduced_dim=model.dimensions['image_vocab_dim']) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) # temperature sampling ids_this_step = mtf.sample_with_temperature( - logits, model.dimensions['final_vocab_dim'], temperature) + logits, model.dimensions['image_vocab_dim'], temperature) # reshape & assign results ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) From 9410ba9350e6cf214d70f2dee820978f0411f071 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:02:41 -0700 Subject: [PATCH 32/43] remove offsetted image, since it is no longer needed --- src/dalle_mtf/models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 0c74131..409fbb0 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -537,8 +537,7 @@ def forward(self, features, return_loss=True, return_logits=False): return logits assert exists(image), 'when training, image must be supplied' - offset_image = image + self.text_vocab_size - labels = mtf.concat([text, offset_image], concat_dim_name="sequence_dim") + labels = mtf.concat([text, image], concat_dim_name="sequence_dim") text_logits = self.to_text_logits(out) From f13a567d830a4fde271200c2aed0892134469744 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:12:05 -0700 Subject: [PATCH 33/43] add unique pad tokens feature, hidden behind a feature flag --- src/dalle_mtf/models.py | 15 +++++++++++---- test.py | 3 ++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 409fbb0..40fd197 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -143,11 +143,12 @@ class DALLE: def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train", is_incremental_inference=False, context=None, loss_fn=None, params=None, padding_id=None, - activation_fn=None, text_loss_weight=0.15): + activation_fn=None, text_loss_weight=0.15, unique_pad_tokens = False): self.mesh = mesh self.n_embd = n_embd - self.text_vocab_size = text_vocab_size + self.unique_pad_tokens = unique_pad_tokens + self.text_vocab_size = text_vocab_size + (0 if not unique_pad_tokens else text_seq_len) self.image_vocab_size = image_vocab_size self.text_seq_len = text_seq_len self.image_seq_len = image_seq_len @@ -156,13 +157,12 @@ def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, te self.n_heads = n_heads self.attn_mask = attn_mask self.logits_mask = None + self.text_loss_weight = text_loss_weight - self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS self.padding_id = 0 if padding_id is None else padding_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size), "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size), - "final_vocab_dim": mtf.Dimension("vocab_dim", self.total_tokens), "text_sequence_dim": mtf.Dimension("sequence_dim", text_seq_len), "image_sequence_dim": mtf.Dimension("sequence_dim", image_seq_len), "total_seq_dim": mtf.Dimension("sequence_dim", self.total_seq_len), @@ -498,6 +498,13 @@ def _loss(self, text_logits, image_logits, text_labels, image_labels): def forward(self, features, return_loss=True, return_logits=False): if features.get('text_inputs') is not None: text = features["text_inputs"] + + if self.unique_pad_tokens: + input_range = mtf.range(text.mesh, text.shape[1], tf.int32) + pad_mask = mtf.equal(text, 0) + pad_token_ids = input_range + self.text_seq_len # shift to the range of pad token ids, which come after text token ids, and before image token ids + text = mtf.where(pad_mask, pad_token_ids, text) + text_with_bos = pad(text, [1, 0], dim_name = text.shape[1].name, pad_value = self.padding_id) text_emb = self.embedding(text_with_bos, "text_embd") else: diff --git a/test.py b/test.py index 943d770..856cf8e 100644 --- a/test.py +++ b/test.py @@ -65,7 +65,8 @@ def test_sampling(): image_seq_len = 4, n_embd = 16, n_heads = 2, - bf_16 = False + bf_16 = False, + unique_pad_tokens = True ) batch_dim = model.dimensions["batch_dim"] From 0f774e17a3f19be46a512b03c7a94196229e680b Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:21:06 -0700 Subject: [PATCH 34/43] remove logits mask code --- src/dalle_mtf/models.py | 7 ------- src/model_fns.py | 18 ------------------ 2 files changed, 25 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 40fd197..08d78bd 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -238,13 +238,6 @@ def get_attn_mask(self, mesh, nd, ns): i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j)) self.attn_mask = mtf.cast(mtf.less(i, j), self.variable_dtype.activation_dtype) * -1e10 return self.attn_mask - - def set_logits_mask(self, tf_mask): - mask_shape = mtf.Shape([self.dimensions['total_seq_dim'], self.dimensions['final_vocab_dim']]) - mtf_mask = mtf.import_fully_replicated(self.mesh, tf_mask, mask_shape) - new_shape = mtf.Shape([self.dimensions['batch_dim'], self.dimensions['total_seq_dim'], self.dimensions['final_vocab_dim']]) - mtf_mask = mtf.broadcast(mtf_mask, new_shape) - self.logits_mask = mtf_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): if not self.is_incremental_inference: diff --git a/src/model_fns.py b/src/model_fns.py index da6e233..a1aa42b 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -11,20 +11,6 @@ from tensorflow.python.ops import resources import numpy as np -def get_tf_logits_mask(num_text_tokens, total_tokens, text_seq_len, image_seq_len): - seq_len = text_seq_len + image_seq_len - - seq_range = np.arange(seq_len).reshape((1,-1,1)) - logits_range = np.arange(total_tokens).reshape((1,1,-1)) - - logits_mask = (((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) | - ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) | - ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1))) ) - logits_mask = np.squeeze(logits_mask) - logits_mask = logits_mask * -1e9 - logits_mask = logits_mask.astype(np.float32) - return tf.constant(logits_mask) - def initialize_vae_weights(checkpoint_path, scope="vae"): """ Initialize the vae model from the checkpoint. @@ -127,10 +113,6 @@ def dalle_model_fn(features, labels, mode, params): params=params, ) - tf_logits_mask = get_tf_logits_mask(params["text_vocab_size"], model.total_tokens, - params["text_seq_len"], params['image_seq_len']) - model.set_logits_mask(tf_logits_mask) - if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: # Build mtf_features & seq length dict for getting number of microbatches # We need to pack inputs into a dict to pass into serialize_training_step From 2972665ec9ee93c8f71844014543bd901059d180 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:22:11 -0700 Subject: [PATCH 35/43] fix syntax --- src/dalle_mtf/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 8db423b..9bc58e0 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -13,7 +13,7 @@ def sample_autoregressive(inputs, has_partial_sequences=True, remove_partial_sequences=False, sampling_keep_top_k=-1, - cached=True + cached=True, min_start_pos=None ): """Sample randomly one token at a time. From 2c081f263c69e787053870e920795f409630d8fa Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:47:25 -0700 Subject: [PATCH 36/43] variable scope --- src/dalle_mtf/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 08d78bd..6b8cc9f 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -451,7 +451,7 @@ def to_logits(self, x): return mtf.cast(logits, tf.float32) def to_image_logits(self, x): - with tf.variable_scope("to_logits"): + with tf.variable_scope("to_image_logits"): if not self.is_incremental_inference: x = mtf.slice(x, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = x.shape[1].name) @@ -462,7 +462,7 @@ def to_image_logits(self, x): return image_logits def to_text_logits(self, x): - with tf.variable_scope("to_logits"): + with tf.variable_scope("to_text_logits"): text_tokens = mtf.slice(x, begin = 0, size = self.text_seq_len, slice_dim_name = x.shape[1].name) text_logits = self.linear(text_tokens, self.dimensions["text_vocab_dim"], name="linear_text_out") From 6a37d0f3abbea3bc77457b0e79f3820e3ce54e86 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 16:23:39 -0700 Subject: [PATCH 37/43] remove stop_at_tokens, since it wont be used, to avoid confusion --- src/dalle_mtf/sample.py | 12 ------------ src/model_fns.py | 2 -- test.py | 5 ++--- 3 files changed, 2 insertions(+), 17 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 9bc58e0..7a86ec1 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -5,7 +5,6 @@ def sample_autoregressive(inputs, model, - stop_at_token=50256, max_steps=None, temperature=0.9, padding_id = 0, @@ -103,11 +102,6 @@ def sample_autoregressive(inputs, if not has_partial_sequences: partial_sequences_eos_count = 0 - if stop_at_token is not None: - partial_sequences_eos_count = mtf.reduce_sum( - mtf.to_int32(mtf.equal(image_inputs, stop_at_token)), - reduced_dim=image_seq_dim) - def cond_fn(position, ids, *unused_states): """Should we run another loop iteration?""" past_end = mtf.greater_equal(position, image_seq_dim.size) @@ -116,12 +110,6 @@ def cond_fn(position, ids, *unused_states): past_end, mtf.greater_equal(position - initial_position, max_steps)) is_done = past_end - if stop_at_token is not None: - eos_count = mtf.reduce_sum( - mtf.to_int32(mtf.equal(ids, stop_at_token)), - reduced_dim=image_seq_dim) - has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) - is_done = mtf.logical_or(is_done, has_additional_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) diff --git a/src/model_fns.py b/src/model_fns.py index a1aa42b..383f87a 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -153,8 +153,6 @@ def dalle_model_fn(features, labels, mode, params): # Set up the model for prediction mtf_samples = sample_autoregressive(mtf_features, model, - params, - stop_at_token=None, max_steps=model.total_seq_len, temperature=0.9, variable_dtype=model.variable_dtype, diff --git a/test.py b/test.py index 856cf8e..f4bbf68 100644 --- a/test.py +++ b/test.py @@ -73,7 +73,7 @@ def test_sampling(): sequence_dim = model.dimensions["total_seq_dim"] text_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) - image_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + image_inputs = mtf.zeros(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) inputs = { 'text_inputs': mtf.slice(text_inputs, 0, model.text_seq_len, sequence_dim.name), @@ -87,10 +87,9 @@ def test_sampling(): variable_dtype=mtf.VariableDType(), max_steps = sequence_dim.size, remove_partial_sequences=False, - stop_at_token=None, min_start_pos=model.text_seq_len ) mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) - samples = lowering.export_to_tf_tensor(samples) \ No newline at end of file + samples = lowering.export_to_tf_tensor(samples) From c60bf543ecb3873053beaca7b8f61c132808000e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 16:29:06 -0700 Subject: [PATCH 38/43] more sample cleanup --- src/dalle_mtf/sample.py | 20 +------------------- src/model_fns.py | 2 -- test.py | 1 - 3 files changed, 1 insertion(+), 22 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 7a86ec1..2fbecf1 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -9,8 +9,6 @@ def sample_autoregressive(inputs, temperature=0.9, padding_id = 0, variable_dtype=mtf.VariableDType(tf.float32), - has_partial_sequences=True, - remove_partial_sequences=False, sampling_keep_top_k=-1, cached=True, min_start_pos=None @@ -34,10 +32,7 @@ def sample_autoregressive(inputs, temperature: an optional floating point value between 0.0 and 1.0 0.0 means argmax, 1.0 means sample according to predicted distribution. variable_dtype: a mtf.VariableDType - has_partial_sequences: a boolean decoding, one per each input layer + the embedding layer - remove_partial_sequences: a boolean - whether to remove the partial - sequences from the output sampling_keep_top_k: an integer - if not -1, only sample from the top k logits. @@ -92,16 +87,10 @@ def sample_autoregressive(inputs, logits = model.forward(inputs, return_loss=False, return_logits=True) del logits - if not has_partial_sequences: - initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] - else: - initial_states = context_first_part.new_states + initial_states = context_first_part.new_states else: initial_states = [] - if not has_partial_sequences: - partial_sequences_eos_count = 0 - def cond_fn(position, ids, *unused_states): """Should we run another loop iteration?""" past_end = mtf.greater_equal(position, image_seq_dim.size) @@ -174,11 +163,4 @@ def body_fn(position, ids, *states): final_position, outputs = mtf.while_loop( cond_fn, body_fn, while_loop_inputs)[:2] del final_position - # if has_partial_sequences and remove_partial_sequences: - # # Remove partial sequences from outputs - # partial_length = mtf.reduce_sum( - # mtf.to_int32(mtf.not_equal(image_inputs, padding_id)), - # reduced_dim=image_seq_dim) - # outputs = mtf.dynamic_shift( - # outputs, -partial_length, image_seq_dim, wrap=False) return outputs diff --git a/src/model_fns.py b/src/model_fns.py index 383f87a..618a978 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -156,8 +156,6 @@ def dalle_model_fn(features, labels, mode, params): max_steps=model.total_seq_len, temperature=0.9, variable_dtype=model.variable_dtype, - has_partial_sequences=True, - remove_partial_sequences=True, sampling_keep_top_k=-2, ) diff --git a/test.py b/test.py index f4bbf68..ee66e9f 100644 --- a/test.py +++ b/test.py @@ -86,7 +86,6 @@ def test_sampling(): model, variable_dtype=mtf.VariableDType(), max_steps = sequence_dim.size, - remove_partial_sequences=False, min_start_pos=model.text_seq_len ) From 02a51f92cc2b07708cb42e834c6876def97a7333 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 16:39:45 -0700 Subject: [PATCH 39/43] make sure one can sample the non-cached way --- src/dalle_mtf/sample.py | 11 +++++------ test.py | 20 +++++++++++++++++--- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 2fbecf1..cdb45c1 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -22,7 +22,6 @@ def sample_autoregressive(inputs, If there are no partial sequences (you want to sample from the beginning), then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and - has_partial_sequences=False (so we can skip computation). Args: inputs: an input dictionary containing 'text_inputs' and 'image_inputs', @@ -124,7 +123,7 @@ def body_fn(position, ids, *states): model.is_incremental_inference = True if cached else False model.context = context with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE): - logits = model.forward({'image_inputs': image_inputs}, return_loss=False, return_logits=True) + logits = model.forward({'image_inputs': image_inputs, 'text_inputs': (text_inputs if not cached else None)}, return_loss=False, return_logits=True) # By default, do top_k sampling of 0.9 if sampling_keep_top_k == -2: @@ -147,16 +146,16 @@ def body_fn(position, ids, *states): if cached: ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) else: - print('*' * 100, '\nIDS THIS STEP SLOW') - ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) - print('*' * 100) + ids_this_step = mtf.shift(ids_this_step, offset=1, dim=image_seq_dim, wrap=False) one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id new_position = position + 1 ret = [new_position, new_ids] - ret += context.new_states + + if cached: + ret += context.new_states return ret while_loop_inputs = [initial_position, image_inputs] + initial_states diff --git a/test.py b/test.py index ee66e9f..44fe72d 100644 --- a/test.py +++ b/test.py @@ -81,14 +81,28 @@ def test_sampling(): } with not_raises(Exception): - samples = sample_autoregressive( + cached_samples = sample_autoregressive( inputs, model, variable_dtype=mtf.VariableDType(), max_steps = sequence_dim.size, - min_start_pos=model.text_seq_len + min_start_pos=model.text_seq_len, + cached = True ) mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) - samples = lowering.export_to_tf_tensor(samples) + cached_samples = lowering.export_to_tf_tensor(cached_samples) + + noncached_samples = sample_autoregressive( + inputs, + model, + variable_dtype=mtf.VariableDType(), + max_steps = sequence_dim.size, + min_start_pos=model.text_seq_len, + cached = False + ) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + noncached_samples = lowering.export_to_tf_tensor(noncached_samples) From beb1c609769c46c50d795811282c30cb690aab0a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 16:47:23 -0700 Subject: [PATCH 40/43] fix sampling for cached version (potentially) --- src/dalle_mtf/sample.py | 70 +++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 30 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index cdb45c1..0b3c1f8 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -63,6 +63,39 @@ def sample_autoregressive(inputs, length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32) + # one step of sampling fn + + def sample_step(logits, ids, position, incremental): + nonlocal sampling_keep_top_k + # By default, do top_k sampling of 0.9 + if sampling_keep_top_k == -2: + sampling_keep_top_k = int(logits.shape[-1].size * 0.1) + + if sampling_keep_top_k != -1: + if sampling_keep_top_k <= 0: + raise ValueError("sampling_keep_top_k must either be -1 or positive.") + k_largest = mtf.nth_largest_element( + logits, n=sampling_keep_top_k, + reduced_dim=model.dimensions['image_vocab_dim']) + logits = mtf.where(mtf.less_equal(logits, k_largest), + mtf.ones_like(logits) * -1e6, logits) + + # temperature sampling + ids_this_step = mtf.sample_with_temperature( + logits, model.dimensions['image_vocab_dim'], temperature) + + # reshape & assign results + if incremental: + ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) + else: + ids_this_step = mtf.shift(ids_this_step, offset=1, dim=image_seq_dim, wrap=False) + + one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) + one_new_id = ids_this_step * one_hot + new_ids = (1 - one_hot) * ids + one_new_id + new_position = position + 1 + return [new_position, new_ids] + # Builds context to pass around internally # The 'first part' context records initial states of k / v / x if cached: @@ -84,9 +117,14 @@ def sample_autoregressive(inputs, with tf.variable_scope('dall-e'): logits = model.forward(inputs, return_loss=False, return_logits=True) - del logits initial_states = context_first_part.new_states + + # sample one step to get first image token and then delete logits + + initial_position, image_inputs = sample_step(logits, image_inputs, initial_position, incremental = False) + + del logits else: initial_states = [] @@ -103,7 +141,6 @@ def cond_fn(position, ids, *unused_states): def body_fn(position, ids, *states): """One step in the decode loop.""" - nonlocal sampling_keep_top_k context = mtf_transformer.transformer.Context( model=None, @@ -125,34 +162,7 @@ def body_fn(position, ids, *states): with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE): logits = model.forward({'image_inputs': image_inputs, 'text_inputs': (text_inputs if not cached else None)}, return_loss=False, return_logits=True) - # By default, do top_k sampling of 0.9 - if sampling_keep_top_k == -2: - sampling_keep_top_k = int(logits.shape[-1].size * 0.1) - - if sampling_keep_top_k != -1: - if sampling_keep_top_k <= 0: - raise ValueError("sampling_keep_top_k must either be -1 or positive.") - k_largest = mtf.nth_largest_element( - logits, n=sampling_keep_top_k, - reduced_dim=model.dimensions['image_vocab_dim']) - logits = mtf.where(mtf.less_equal(logits, k_largest), - mtf.ones_like(logits) * -1e6, logits) - - # temperature sampling - ids_this_step = mtf.sample_with_temperature( - logits, model.dimensions['image_vocab_dim'], temperature) - - # reshape & assign results - if cached: - ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) - else: - ids_this_step = mtf.shift(ids_this_step, offset=1, dim=image_seq_dim, wrap=False) - - one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) - one_new_id = ids_this_step * one_hot - new_ids = (1 - one_hot) * ids + one_new_id - new_position = position + 1 - ret = [new_position, new_ids] + ret = sample_step(logits, ids, position, cached) if cached: ret += context.new_states From c085faf345b2fb0a6d7df505019cca71a484e722 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 5 Apr 2021 08:30:03 -0700 Subject: [PATCH 41/43] change max steps to image seq len, since it is counting from the start of the image --- src/model_fns.py | 2 +- test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model_fns.py b/src/model_fns.py index 618a978..a2b2841 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -153,7 +153,7 @@ def dalle_model_fn(features, labels, mode, params): # Set up the model for prediction mtf_samples = sample_autoregressive(mtf_features, model, - max_steps=model.total_seq_len, + max_steps=model.image_seq_len, temperature=0.9, variable_dtype=model.variable_dtype, sampling_keep_top_k=-2, diff --git a/test.py b/test.py index 44fe72d..1651cd4 100644 --- a/test.py +++ b/test.py @@ -98,7 +98,7 @@ def test_sampling(): inputs, model, variable_dtype=mtf.VariableDType(), - max_steps = sequence_dim.size, + max_steps = model.image_seq_len, min_start_pos=model.text_seq_len, cached = False ) From 3ffc5d10183d83e518097b0756786ab4bba1ca1b Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 5 Apr 2021 12:45:52 -0700 Subject: [PATCH 42/43] fix initial position at 0 --- src/dalle_mtf/sample.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 0b3c1f8..3ee7f8f 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -50,9 +50,7 @@ def sample_autoregressive(inputs, text_inputs = inputs['text_inputs'] # Gets position (in image inputs) where zero padding starts - initial_position = mtf.reduce_sum( - mtf.to_int32(mtf.not_equal(image_inputs, padding_id)), - reduced_dim=image_seq_dim) + initial_position = mtf.zeros(text_inputs.mesh, mtf.Shape((batch_dims,)), dtype=tf.int32) if min_start_pos is not None: # force the sampling to never start below a minimum starting position, say the text length. From ef7864ceb5b65c2091cdd1fbe720342336a482e2 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 5 Apr 2021 22:45:30 -0700 Subject: [PATCH 43/43] make sure axial positional embedding is shifted over by one due to --- src/dalle_mtf/models.py | 3 ++- src/dalle_mtf/sample.py | 8 +------- test.py | 2 -- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 6b8cc9f..be258d7 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -397,7 +397,8 @@ def axial_positional_embedding(self, mesh, name): wpe = (axial_wpe_1 + axial_wpe_2) / 2 wpe = mtf.reshape(wpe, [axial_dim, embd_dim]) - wpe = pad(wpe, [self.text_seq_len, 0], axial_dim.name) + wpe = pad(wpe, [self.text_seq_len + 1, 0], axial_dim.name) + wpe = mtf.slice(wpe, 0, self.total_seq_len, axial_dim.name) wpe = mtf.replace_dimensions(wpe, wpe.shape[0], self.dimensions["embed_seq_dim"]) return wpe diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 3ee7f8f..39abc12 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -10,8 +10,7 @@ def sample_autoregressive(inputs, padding_id = 0, variable_dtype=mtf.VariableDType(tf.float32), sampling_keep_top_k=-1, - cached=True, - min_start_pos=None + cached=True ): """Sample randomly one token at a time. @@ -52,11 +51,6 @@ def sample_autoregressive(inputs, # Gets position (in image inputs) where zero padding starts initial_position = mtf.zeros(text_inputs.mesh, mtf.Shape((batch_dims,)), dtype=tf.int32) - if min_start_pos is not None: - # force the sampling to never start below a minimum starting position, say the text length. - # this will also be useful for image completion, where you can start sampling from half the image tokens - initial_position = mtf.maximum(initial_position, min_start_pos) - # initial_position += model.dimensions['text_seq_dim'].size length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32) diff --git a/test.py b/test.py index 1651cd4..90efc0e 100644 --- a/test.py +++ b/test.py @@ -86,7 +86,6 @@ def test_sampling(): model, variable_dtype=mtf.VariableDType(), max_steps = sequence_dim.size, - min_start_pos=model.text_seq_len, cached = True ) @@ -99,7 +98,6 @@ def test_sampling(): model, variable_dtype=mtf.VariableDType(), max_steps = model.image_seq_len, - min_start_pos=model.text_seq_len, cached = False )