diff --git a/torch-neuronx/inference/hf_pretrained_bart_inference_on_inf2.ipynb b/torch-neuronx/inference/hf_pretrained_bart_inference_on_inf2.ipynb
new file mode 100644
index 0000000..239350e
--- /dev/null
+++ b/torch-neuronx/inference/hf_pretrained_bart_inference_on_inf2.ipynb
@@ -0,0 +1,437 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 0. Import libraries / Setup"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This notebook was tested with the Neuron **SDK 2.21.1** (Python 3.10.12).\n",
+    "It requires the following packages:\n",
+    "```\n",
+    "torch==2.5.1\n",
+    "torch-neuronx==2.5.1.2.4.0\n",
+    "torch-xla==2.5.1\n",
+    "torchvision==0.20.1\n",
+    "libneuronxla==2.1.714.0\n",
+    "neuronx-cc==2.16.372.0+4a9b2326\n",
+    "```\n",
+    "These should already be installed when you set up the system for that SDK version.\n",
+    "In addition, you need to install the following:\n",
+    "```\n",
+    "huggingface-hub==0.28.1\n",
+    "transformers==4.48.2\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "editable": true,
+    "scrolled": true,
+    "slideshow": {
+     "slide_type": ""
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "\n",
+    "!{sys.executable} -m pip install transformers==4.48.2 huggingface-hub==0.28.1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import transformers\n",
+    "import torch_neuronx\n",
+    "import os\n",
+    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 1. Load model pretrained on MNLI"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import BartForSequenceClassification, BartTokenizer\n",
+    "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
+    "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n",
+    "model_cpu = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n",
+    "model_dir = \"Bart\"\n",
+    "os.makedirs(model_dir, exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1.1 Test loaded model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# pose the sequence as an NLI premise and the label (sports) as the hypothesis\n",
+    "premise = 'What is your favorite team, Madrid or Barca?'\n",
+    "hypothesis = 'This text is about sports.'\n",
+    "max_length = 128\n",
+    "\n",
+    "# run through the model pre-trained on MNLI\n",
+    "encoded_input = tokenizer.encode_plus(premise, hypothesis, return_tensors='pt', truncation='only_first', padding=\"max_length\", max_length=max_length)\n",
+    "logits = model(encoded_input[\"input_ids\"], encoded_input[\"attention_mask\"], use_cache=False)[0]\n",
+    "\n",
+    "# we throw away \"neutral\" (dim 1) and take the probability of\n",
+    "# \"entailment\" (2) as the probability of the label being true\n",
+    "entail_contradiction_logits = logits[:,[0,2]]\n",
+    "probs = entail_contradiction_logits.softmax(dim=1)\n",
+    "true_prob = probs[:,1].item() * 100\n",
+    "print(f'Probability that the label is true: {true_prob:0.2f}%')"
+   ]
+  },
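+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an aside, this premise/hypothesis trick is exactly what the `zero-shot-classification` pipeline in `transformers` automates. A minimal sketch for cross-checking the result above (it loads the same checkpoint and runs on CPU; not required for the rest of the notebook):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional cross-check: the zero-shot-classification pipeline wraps the same\n",
+    "# NLI trick used above (this runs on CPU and is not needed for tracing).\n",
+    "from transformers import pipeline\n",
+    "\n",
+    "classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli')\n",
+    "print(classifier(premise, candidate_labels=['sports', 'politics', 'cooking']))"
+   ]
+  },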
"metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "neuron_encoder = torch_neuronx.trace(\n", + " model, \n", + " encoded_input[\"input_ids\"],\n", + " compiler_args='--target inf2 --model-type transformer --auto-cast all',\n", + " compiler_workdir='./enc_dir')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The step above fails because between the encoder and decoder, the arguments are passes as a dictionary with tuples as values. The compiler doesn't work well with this setup, so the idea is to split the model in two parts, encoder and decoder compile them independently and then put them back into the original model structure.\n", + "\n", + "Given this model is around 400M params (1.5GB), it fits into just 1 core when quantized to bf16. After that, both encoder and decoder will be accelerated on inf2." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Prepare model for compilation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dim_enc=model.config.max_position_embeddings\n", + "dim_dec=model.config.d_model\n", + "print(f'Dim enc: {dim_enc}; Dim dec: {dim_dec}')\n", + "max_dec_len = 1024" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions\n", + "\n", + "# Define one function for the encoder part\n", + "def enc_f(self, input_ids, attention_mask, **kwargs):\n", + " if hasattr(self, 'forward_neuron'):\n", + " out = self.forward_neuron(input_ids, attention_mask)\n", + " else:\n", + " out = self.forward_(input_ids, attention_mask=attention_mask, return_dict=True)\n", + " return BaseModelOutput(**out)\n", + "\n", + "\n", + "# Define one function for the decoder part\n", + "def dec_f(self, input_ids, encoder_hidden_states, encoder_attention_mask, **kwargs): \n", + " out = None\n", + " \n", + " if input_ids.shape[1] > self.max_length:\n", + " raise Exception(f\"The decoded sequence is not supported. 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2.1 Trace Encoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import torch\n",
+    "\n",
+    "model_filename = f\"{model_dir}/BART-large-nli-encoder.pt\"\n",
+    "\n",
+    "if not os.path.isfile(model_filename):\n",
+    "    # Make sure tracing goes through the original forward, not a stale Neuron one\n",
+    "    if hasattr(model.model.encoder, 'forward_neuron'):\n",
+    "        del model.model.encoder.forward_neuron\n",
+    "    neuron_encoder = torch_neuronx.trace(\n",
+    "        model.model.encoder,\n",
+    "        encoder_inputs,\n",
+    "        compiler_args='--target inf2 --model-type transformer --auto-cast all',\n",
+    "        compiler_workdir='./enc_dir')\n",
+    "    # neuron_encoder_dynamic_batch = torch_neuronx.dynamic_batch(neuron_encoder)\n",
+    "    neuron_encoder.save(model_filename)\n",
+    "    model.model.encoder.forward_neuron = neuron_encoder\n",
+    "else:\n",
+    "    model.model.encoder.forward_neuron = torch.jit.load(model_filename)"
+   ]
+  },
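+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before tracing the decoder, it is worth sanity-checking the traced encoder against the original CPU forward pass. A minimal sketch; since `--auto-cast all` computes in bf16, small numerical differences are expected, and what counts as an acceptable difference is a judgement call rather than a hard rule:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compare the Neuron-traced encoder (dispatched via the patched forward)\n",
+    "# against the original CPU implementation kept in forward_.\n",
+    "# bf16 auto-casting means the outputs will not match bit-for-bit.\n",
+    "with torch.no_grad():\n",
+    "    cpu_out = model.model.encoder.forward_(encoded_input[\"input_ids\"],\n",
+    "                                           attention_mask=encoded_input[\"attention_mask\"],\n",
+    "                                           return_dict=True)\n",
+    "neuron_out = model.model.encoder(encoded_input[\"input_ids\"], encoded_input[\"attention_mask\"])\n",
+    "diff = (cpu_out.last_hidden_state - neuron_out.last_hidden_state).abs().max().item()\n",
+    "print(f'Max abs difference (encoder): {diff:.4f}')"
+   ]
+  },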
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2.2 Trace Decoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_filename = f\"{model_dir}/BART-large-nli-decoder.pt\"\n",
+    "\n",
+    "if not os.path.isfile(model_filename):\n",
+    "    inp = encoded_input[\"input_ids\"], encoder_outputs[0], encoded_input[\"attention_mask\"]\n",
+    "    # Make sure tracing goes through the original forward, not a stale Neuron one\n",
+    "    if hasattr(model.model.decoder, 'forward_neuron'):\n",
+    "        del model.model.decoder.forward_neuron\n",
+    "    neuron_decoder = torch_neuronx.trace(\n",
+    "        model.model.decoder,\n",
+    "        inp,\n",
+    "        compiler_args='--target inf2 --model-type transformer --auto-cast all',\n",
+    "        compiler_workdir='./dec_dir')\n",
+    "    # neuron_decoder_dynamic_batch = torch_neuronx.dynamic_batch(neuron_decoder)\n",
+    "    neuron_decoder.save(model_filename)\n",
+    "    model.model.decoder.forward_neuron = neuron_decoder\n",
+    "else:\n",
+    "    model.model.decoder.forward_neuron = torch.jit.load(model_filename)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "scrolled": true
+   },
+   "source": [
+    "# 3. Test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# pose the sequence as an NLI premise and the label (cooking) as the hypothesis\n",
+    "premise = 'how do you like the potatoes?'\n",
+    "hypothesis = 'This text is about cooking.'\n",
+    "\n",
+    "# run through the model pre-trained on MNLI\n",
+    "max_length = 128\n",
+    "x = tokenizer.encode_plus(premise, hypothesis, return_tensors='pt', truncation='only_first', padding=\"max_length\", max_length=max_length, return_attention_mask=True)\n",
+    "y = model(x[\"input_ids\"], x[\"attention_mask\"])\n",
+    "logits = y[0]\n",
+    "\n",
+    "# we throw away \"neutral\" (dim 1) and take the probability of\n",
+    "# \"entailment\" (2) as the probability of the label being true\n",
+    "entail_contradiction_logits = logits[:,[0,2]]\n",
+    "probs = entail_contradiction_logits.softmax(dim=1)\n",
+    "true_prob = probs[:,1].item() * 100\n",
+    "print(f'Probability that the label is true: {true_prob:0.2f}%')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Now we can test the inference latency on the Inf2 chip:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%timeit -r 10\n",
+    "\n",
+    "model(x[\"input_ids\"], x[\"attention_mask\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### And compare it with the model hosted on the CPU:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%timeit -r 10\n",
+    "model_cpu(x[\"input_ids\"], x[\"attention_mask\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Finally, we can compare the output of the CPU model vs. the Inf2 one"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y = model_cpu(x[\"input_ids\"], x[\"attention_mask\"])\n",
+    "logits = y[0]\n",
+    "# we throw away \"neutral\" (dim 1) and take the probability of\n",
+    "# \"entailment\" (2) as the probability of the label being true\n",
+    "entail_contradiction_logits = logits[:,[0,2]]\n",
+    "probs = entail_contradiction_logits.softmax(dim=1)\n",
+    "true_prob = probs[:,1].item() * 100\n",
+    "print(f'Probability that the label is true: {true_prob:0.2f}%')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The value should be very similar to the one three cells above."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}