From 682ec5c6215f3451d46481ff03379e6cb3b3f17f Mon Sep 17 00:00:00 2001
From: 29riyasaxena <29riyasaxena@gmail.com>
Date: Fri, 31 Mar 2023 01:12:09 +0530
Subject: [PATCH] Add GAT in JAX

---
 GAT/GAT_JAX.ipynb | 536 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 536 insertions(+)
 create mode 100644 GAT/GAT_JAX.ipynb

diff --git a/GAT/GAT_JAX.ipynb b/GAT/GAT_JAX.ipynb
new file mode 100644
index 0000000..a8276dc
--- /dev/null
+++ b/GAT/GAT_JAX.ipynb
@@ -0,0 +1,536 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "name": "python3",
+   "display_name": "Python 3"
+  },
+  "language_info": {
+   "name": "python"
+  },
+  "accelerator": "GPU",
+  "gpuClass": "standard"
+ },
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "id": "slj-obYkPXRG"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from typing import List\n",
+    "\n",
+    "import numpy as np                # plain NumPy, used for data loading\n",
+    "import scipy.sparse as sp\n",
+    "from tensorflow import keras\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp           # JAX's NumPy, used for all model code\n",
+    "import jax.nn as nn\n",
+    "from jax import grad, jit, lax, random\n",
+    "from jax.nn.initializers import glorot_uniform\n",
+    "from jax.example_libraries import optimizers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "jax.default_backend()"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 35
+    },
+    "id": "3dbRfzdXP_Ii",
+    "outputId": "88365b72-dc08-4232-fab4-cba6ae0dbde6"
+   },
+   "execution_count": 2,
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "'gpu'"
+      ],
+      "application/vnd.google.colaboratory.intrinsic+json": {
+       "type": "string"
+      }
+     },
+     "metadata": {},
+     "execution_count": 2
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Helper that draws four fresh PRNG keys from a host-side random seed.\n",
+    "# Deliberately not @jit-ed: under jit, np.random.randint would run only\n",
+    "# at trace time, baking a single seed into the compiled function so every\n",
+    "# call would return the same keys. (Fully idiomatic JAX would instead\n",
+    "# thread PRNG keys through the call chain explicitly.)\n",
+    "def create_random():\n",
+    "    return random.split(random.PRNGKey(np.random.randint(0, 100)), 4)"
+   ],
+   "metadata": {
+    "id": "bY_UK0ABQP4R"
+   },
+   "execution_count": 3,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Layer construction function for a dropout layer with the given drop rate.\n",
+    "def Dropout(rate):\n",
+    "    # Constructor: dropout has no parameters and keeps the input shape.\n",
+    "    def init_fun(input_shape):\n",
+    "        return input_shape, ()\n",
+    "    # Inverted dropout: zero entries with probability `rate` and rescale\n",
+    "    # the survivors by 1/(1 - rate) so the expected activation is unchanged.\n",
+    "    def apply_fun(inputs, is_training, **kwargs):\n",
+    "        rng, rng2, rng3, rng4 = create_random()\n",
+    "        # Bernoulli mask of entries to keep (note: when apply_fun is traced\n",
+    "        # inside a jit-ed caller, create_random's host-side seed is fixed\n",
+    "        # at trace time).\n",
+    "        keep = random.bernoulli(rng, 1.0 - rate, inputs.shape)\n",
+    "        outs = keep * inputs / (1.0 - rate)\n",
+    "        # At evaluation time, return the inputs unchanged.\n",
+    "        out = lax.cond(is_training, lambda _: outs, lambda _: inputs, None)\n",
+    "        return out\n",
+    "\n",
+    "    return init_fun, apply_fun"
+   ],
+   "metadata": {
+    "id": "KFoHZqW_cY5p"
+   },
+   "execution_count": 4,
+   "outputs": []
+  },
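+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Sanity check (illustrative only, not part of the model):** in training mode the inverted dropout above zeroes roughly `rate` of the entries and rescales the survivors by `1/(1 - rate)`; in evaluation mode the input passes through unchanged. The names `demo_drop` and `demo_x` exist only for this demo."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sanity check of the Dropout layer on a toy input.\n",
+    "_, demo_drop = Dropout(0.5)\n",
+    "demo_x = jnp.ones((4, 8))\n",
+    "train_out = demo_drop(demo_x, is_training=True)   # ~half zeros, survivors scaled to 2.0\n",
+    "eval_out = demo_drop(demo_x, is_training=False)   # identity: input untouched\n",
+    "print(train_out[0])\n",
+    "print(bool((eval_out == demo_x).all()))           # True\n"
+   ]
+  },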
" k1, k2, k3, k4 = create_random()\n", + " # initialize weight\n", + " W = glorot_uniform()(k1, (input_shape[-1], out_dim))\n", + " # initialize nn weight\n", + " a_init = glorot_uniform()\n", + " a1 = a_init(k2, (out_dim, 1))\n", + " a2 = a_init(k3, (out_dim, 1))\n", + "\n", + " return output_shape, (W, a1, a2)\n", + " \n", + " def apply_fun(params, x, adj, activation=nn.elu, is_training=False, **kwargs):\n", + " #Apply function, compute the attention \n", + " W, a1, a2 = params\n", + " # initial dropout\n", + " x = drop_fun(x, is_training=is_training)\n", + " # weights matmult\n", + " x = np.dot(x, W)\n", + " # neural netw + alignment score\n", + " f_1 = np.dot(x, a1) \n", + " f_2 = np.dot(x, a2)\n", + " logits = f_1 + f_2.T\n", + " # softmax of leakyReLu for e\n", + " coefs = nn.softmax( nn.leaky_relu(logits, negative_slope=0.2) + np.where(adj, 0., -1e9))\n", + " # final dropout\n", + " coefs = drop_fun(coefs, is_training=is_training)\n", + " x = drop_fun(x, is_training=is_training)\n", + "\n", + " ret = np.matmul(coefs, x)\n", + "\n", + " return activation(ret)\n", + " return init_fun, apply_fun\n", + "\n", + "def MultiHeadLayer(nheads: int, nhid: int, dropout: float,last_layer: bool=False):\n", + " #Multi head attention layer\n", + " \n", + " layer_funs, layer_inits = [], []\n", + " # define the heads layers\n", + " for head_i in range(nheads):\n", + " att_init, att_fun = GraphAttentionLayer(nhid, dropout=dropout)\n", + " # initialize layers of attention\n", + " layer_inits.append(att_init)\n", + " # grab the functions for running attentions\n", + " layer_funs.append(att_fun)\n", + " \n", + " def init_fun(input_shape):\n", + " #Initialize each attention head\n", + " params = []\n", + " # for each head initialize parameters\n", + " for att_init_fun in layer_inits:\n", + " #rng, layer_rng = random.split(rng)\n", + " layer_shape, param = att_init_fun(input_shape)\n", + " params.append(param)\n", + "\n", + " input_shape = layer_shape\n", + " if not last_layer:\n", + " # multiply by the number of heads\n", + " input_shape = input_shape[:-1] + (input_shape[-1]*len(layer_inits),)\n", + " return input_shape, params\n", + " \n", + " def apply_fun(params, x, adj, is_training=False, **kwargs):\n", + " #Function to apply parameters to head \n", + "\n", + " layer_outs = []\n", + " assert len(params) == nheads\n", + " for head_i in range(nheads):\n", + " layer_params = params[head_i]\n", + " layer_outs.append(layer_funs[head_i](layer_params, x, adj, is_training=is_training))\n", + " # concatenate or average\n", + " if not last_layer:\n", + " x = np.concatenate(layer_outs, axis=1)\n", + " else:\n", + " # average last layer heads\n", + " x = np.mean(np.stack(layer_outs), axis=0)\n", + "\n", + " return x\n", + "\n", + " return init_fun, apply_fun" + ], + "metadata": { + "id": "1MinpDwdcdY1" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def GAT(nheads: List[int], nhid: List[int], nclass: int, dropout: float):\n", + " # Graph Attention Network model definition.\n", + " init_funs = []\n", + " attn_funs = []\n", + "\n", + " nhid += [nclass]\n", + " for layer_i in range(len(nhid)):\n", + " last = layer_i == len(nhid) - 1\n", + " layer_init, layer_fun = MultiHeadLayer(nheads[layer_i], nhid[layer_i],dropout=dropout,last_layer=last)\n", + " attn_funs.append(layer_fun)\n", + " init_funs.append(layer_init)\n", + "\n", + " def init_fun(input_shape):\n", + " params = []\n", + " for i, init_fun in enumerate(init_funs):\n", + " layer_shape, param = 
+  {
+   "cell_type": "code",
+   "source": [
+    "def GAT(nheads: List[int], nhid: List[int], nclass: int, dropout: float):\n",
+    "    # Graph Attention Network: a stack of multi-head attention layers.\n",
+    "    init_funs = []\n",
+    "    attn_funs = []\n",
+    "\n",
+    "    # append the output layer; avoid `nhid += [nclass]`, which would\n",
+    "    # mutate the caller's list\n",
+    "    nhid = nhid + [nclass]\n",
+    "    for layer_i in range(len(nhid)):\n",
+    "        last = layer_i == len(nhid) - 1\n",
+    "        layer_init, layer_fun = MultiHeadLayer(nheads[layer_i], nhid[layer_i], dropout=dropout, last_layer=last)\n",
+    "        attn_funs.append(layer_fun)\n",
+    "        init_funs.append(layer_init)\n",
+    "\n",
+    "    def init_fun(input_shape):\n",
+    "        # Initialize layer by layer, feeding each output shape forward.\n",
+    "        params = []\n",
+    "        for i, init_fun in enumerate(init_funs):\n",
+    "            layer_shape, param = init_fun(input_shape)\n",
+    "            params.append(param)\n",
+    "            input_shape = layer_shape\n",
+    "        return input_shape, params\n",
+    "\n",
+    "    def apply_fun(params, x, adj, is_training=False, **kwargs):\n",
+    "        for i, layer_fun in enumerate(attn_funs):\n",
+    "            x = layer_fun(params[i], x, adj, is_training=is_training)\n",
+    "\n",
+    "        # log-probabilities over the nclass output classes\n",
+    "        return nn.log_softmax(x)\n",
+    "\n",
+    "    return init_fun, apply_fun"
+   ],
+   "metadata": {
+    "id": "MXjbliaQcqVX"
+   },
+   "execution_count": 6,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "def encode_onehot(labels):\n",
+    "    # Transform string labels into one-hot encoded vectors.\n",
+    "    classes = set(labels)\n",
+    "    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in\n",
+    "                    enumerate(classes)}\n",
+    "    print(classes_dict)\n",
+    "    labels_onehot = np.array(list(map(classes_dict.get, labels)),\n",
+    "                             dtype=np.int32)\n",
+    "    return labels_onehot\n",
+    "\n",
+    "def normalize(mx):\n",
+    "    # Row-normalize a sparse matrix: divide every row by its row sum.\n",
+    "    rowsum = np.array(mx.sum(1))\n",
+    "    r_inv = np.power(rowsum, -1).flatten()\n",
+    "    r_inv[np.isinf(r_inv)] = 0.  # rows that sum to zero stay zero\n",
+    "    r_mat_inv = sp.diags(r_inv)\n",
+    "    mx = r_mat_inv.dot(mx)\n",
+    "    return mx\n",
+    "\n",
+    "def load_data():\n",
+    "    # Load and preprocess the Cora citation dataset.\n",
+    "    zip_file = keras.utils.get_file(\n",
+    "        fname=\"cora.tgz\",\n",
+    "        origin=\"https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz\",\n",
+    "        extract=True,\n",
+    "    )\n",
+    "    data_dir = os.path.join(os.path.dirname(zip_file), \"cora\")\n",
+    "\n",
+    "    # read the content file into a string array\n",
+    "    idx_features_labels = np.genfromtxt(f\"{data_dir}/cora.content\", dtype=np.dtype(str))\n",
+    "\n",
+    "    # bag-of-words vector of each paper, stored as a sparse matrix\n",
+    "    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)\n",
+    "\n",
+    "    # the subject of each paper becomes a one-hot label vector\n",
+    "    labels = encode_onehot(idx_features_labels[:, -1])\n",
+    "\n",
+    "    # map each paper id to a row index\n",
+    "    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)\n",
+    "    idx_map = {j: i for i, j in enumerate(idx)}\n",
+    "\n",
+    "    # read the citation edges\n",
+    "    edges_unordered = np.genfromtxt(f\"{data_dir}/cora.cites\", dtype=np.int32)\n",
+    "\n",
+    "    # map the paper ids in the cites data to row indices in [0, 2707]\n",
+    "    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),\n",
+    "                     dtype=np.int32).reshape(edges_unordered.shape)\n",
+    "\n",
+    "    # citation relationships as a sparse adjacency matrix\n",
+    "    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),\n",
+    "                        shape=(labels.shape[0], labels.shape[0]),\n",
+    "                        dtype=np.float32)\n",
+    "\n",
+    "    # build a symmetric adjacency matrix\n",
+    "    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)\n",
+    "\n",
+    "    # row-normalize the features and the adjacency matrix (with self-loops)\n",
+    "    features = normalize(features)\n",
+    "    adj = normalize(adj + sp.eye(adj.shape[0]))\n",
+    "    # train/validation/test node splits\n",
+    "    idx_train = range(140)\n",
+    "    idx_val = range(200, 500)\n",
+    "    idx_test = range(500, 1500)\n",
+    "\n",
+    "    features = jnp.array(features.todense())\n",
+    "\n",
+    "    # JAX doesn't support sparse matrices yet, so densify the adjacency\n",
+    "    adj = jnp.asarray(adj.todense())\n",
+    "\n",
+    "    return adj, features, labels, jnp.array(idx_train), jnp.array(idx_val), jnp.array(idx_test)"
+   ],
+   "metadata": {
+    "id": "68r99bZccdgs"
+   },
+   "execution_count": 7,
+   "outputs": []
+  },
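+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Preprocessing on a toy graph (illustrative, a 3-node graph assumed):** the cell below mirrors the steps in `load_data` -- symmetrize the directed citation matrix, add self-loops, row-normalize -- so each step is easy to inspect."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative only: the same preprocessing on a 3-node directed graph.\n",
+    "toy = sp.coo_matrix(np.array([[0., 1., 0.],\n",
+    "                              [0., 0., 0.],\n",
+    "                              [1., 0., 0.]], dtype=np.float32))\n",
+    "# symmetrize: keep an edge (i, j) whenever either direction exists\n",
+    "sym = toy + toy.T.multiply(toy.T > toy) - toy.multiply(toy.T > toy)\n",
+    "# add self-loops, then row-normalize\n",
+    "norm = normalize(sym + sp.eye(sym.shape[0]))\n",
+    "print(norm.todense())  # every row sums to 1\n"
+   ]
+  },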
"metadata": { + "id": "68r99bZccdgs" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "adj, features, labels, idx_train, idx_val, idx_test = load_data()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FhklgkROc2SK", + "outputId": "1d63eb67-0d33-4241-b6f5-09d7769c2445" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading data from https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz\n", + "168052/168052 [==============================] - 0s 2us/step\n", + "{'Probabilistic_Methods': array([1., 0., 0., 0., 0., 0., 0.]), 'Rule_Learning': array([0., 1., 0., 0., 0., 0., 0.]), 'Reinforcement_Learning': array([0., 0., 1., 0., 0., 0., 0.]), 'Neural_Networks': array([0., 0., 0., 1., 0., 0., 0.]), 'Theory': array([0., 0., 0., 0., 1., 0., 0.]), 'Case_Based': array([0., 0., 0., 0., 0., 1., 0.]), 'Genetic_Algorithms': array([0., 0., 0., 0., 0., 0., 1.])}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "\n", + "\n", + "@jit\n", + "def loss(params, batch):\n", + " # The indexes of the batch indicate which nodes are used to compute the loss.\n", + " inputs, targets, adj, is_training, idx = batch\n", + " preds = predict_fun(params, inputs, adj, is_training=is_training)\n", + " ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))\n", + " l2_loss = 5e-4 * optimizers.l2_norm(params)**2 \n", + " return ce_loss + l2_loss\n", + "\n", + "\n", + "@jit\n", + "def accuracy(params, batch):\n", + " inputs, targets, adj, is_training, idx = batch\n", + " target_class = np.argmax(targets, axis=1)\n", + " predicted_class = np.argmax(predict_fun(params, inputs, adj, \n", + " is_training=is_training), axis=1)\n", + " return np.mean(predicted_class[idx] == target_class[idx])\n", + "\n", + "\n", + "@jit\n", + "def loss_accuracy(params, batch):\n", + " inputs, targets, adj, is_training, idx = batch\n", + " preds = predict_fun(params, inputs, adj, is_training=is_training)\n", + " target_class = np.argmax(targets, axis=1)\n", + " predicted_class = np.argmax(preds, axis=1)\n", + " ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))\n", + " acc = np.mean(predicted_class[idx] == target_class[idx])\n", + " return ce_loss, acc" + ], + "metadata": { + "id": "vRmJxzGNc2YK" + }, + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%%time\n", + "lr = 0.05\n", + "num_epochs = 100\n", + "n_nodes = adj.shape[0]\n", + "n_feats = features.shape[1]\n", + "\n", + "# GAT params\n", + "nheads = [8, 1]\n", + "nhid = [8]\n", + "dropout = 0.6 # probability of keeping\n", + "residual = False\n", + "\n", + "init_fun, predict_fun = GAT(nheads=nheads,\n", + " nhid=nhid,\n", + " nclass=7,\n", + " dropout=dropout,\n", + " )\n", + "\n", + "input_shape = (-1, n_nodes, n_feats)\n", + "_, init_params = init_fun(input_shape)\n", + "\n", + "opt_init, opt_update, get_params = optimizers.sgd(lr)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JrDZoo8Lc2fs", + "outputId": "eb0d8b75-2428-432d-9666-03426dddc56a" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 2.54 s, sys: 1.02 s, total: 3.57 s\n", + "Wall time: 11.2 s\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "@jit\n", + "def update(i, opt_state, batch):\n", + " params = get_params(opt_state)\n", + " return opt_update(i, grad(loss)(params, batch), 
opt_state)\n", + "\n", + "opt_state = opt_init(init_params)" + ], + "metadata": { + "id": "vfl4hUJKdY87" + }, + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"\\nStarting training...\")\n", + "for epoch in range(num_epochs):\n", + " \n", + " batch = (features, labels, adj, True, idx_train)\n", + " opt_state = update(epoch, opt_state, batch)\n", + "\n", + " params = get_params(opt_state)\n", + " eval_batch = (features, labels, adj, False, idx_val)\n", + " train_batch = (features, labels, adj, False, idx_train)\n", + " # additional step, everything can be loaded onto the GPU:\n", + " train_batch = jax.device_put(train_batch)\n", + " eval_batch = jax.device_put(eval_batch)\n", + " # without that we take about 1 min\n", + " train_loss, train_acc = loss_accuracy(params, train_batch)\n", + " val_loss, val_acc = loss_accuracy(params, eval_batch)\n", + " if epoch%10==0:\n", + " print((f\"Iter {epoch}/{num_epochs} train_loss:\"+\n", + " f\"{train_loss:.4f}, train_acc: {train_acc:.4f}, val_loss:\"+\n", + " f\"{val_loss:.4f}, val_acc: {val_acc:.4f}\"))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5eSYEYLVddpe", + "outputId": "5bc08a58-c80c-4e4e-803a-75a2bb97780c" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Starting training...\n", + "Iter 0/100 train_loss:1.9459, train_acc: 0.1143, val_loss:1.9459, val_acc: 0.1500\n", + "Iter 10/100 train_loss:1.9456, train_acc: 0.1214, val_loss:1.9456, val_acc: 0.1567\n", + "Iter 20/100 train_loss:1.9453, train_acc: 0.1357, val_loss:1.9453, val_acc: 0.1567\n", + "Iter 30/100 train_loss:1.9450, train_acc: 0.1286, val_loss:1.9451, val_acc: 0.1567\n", + "Iter 40/100 train_loss:1.9447, train_acc: 0.1357, val_loss:1.9448, val_acc: 0.1800\n", + "Iter 50/100 train_loss:1.9444, train_acc: 0.1571, val_loss:1.9446, val_acc: 0.1867\n", + "Iter 60/100 train_loss:1.9441, train_acc: 0.1714, val_loss:1.9443, val_acc: 0.2000\n", + "Iter 70/100 train_loss:1.9438, train_acc: 0.1786, val_loss:1.9441, val_acc: 0.2133\n", + "Iter 80/100 train_loss:1.9435, train_acc: 0.2143, val_loss:1.9439, val_acc: 0.2367\n", + "Iter 90/100 train_loss:1.9432, train_acc: 0.2357, val_loss:1.9436, val_acc: 0.2567\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# now run on the test set\n", + "test_batch = (features, labels, adj, False, idx_test)\n", + "test_acc = accuracy(params, test_batch)\n", + "print(f'Test set acc: {test_acc}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TxPdf3KBds77", + "outputId": "3f7bea19-7901-452a-8718-5581b47635f8" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test set acc: 0.19700001180171967\n" + ] + } + ] + } + ] +} \ No newline at end of file