diff --git a/tutorials/W2D5_Mysteries/W2D5_Tutorial3.ipynb b/tutorials/W2D5_Mysteries/W2D5_Tutorial3.ipynb index a826d7e75..2dee6504b 100644 --- a/tutorials/W2D5_Mysteries/W2D5_Tutorial3.ipynb +++ b/tutorials/W2D5_Mysteries/W2D5_Tutorial3.ipynb @@ -37,8 +37,918 @@ "metadata": {}, "outputs": [], "source": [ - "# Imports & content variables for the day\n", - "assert 1 == 0, \"Please run this script with the correct day number.\"" + "# @title Install and import feedback gadget\n", + "\n", + "!pip install vibecheck numpy matplotlib Pillow torch torchvision transformers ipywidgets gradio trdg scikit-learn networkx pickleshare seaborn tabulate --quiet\n", + "\n", + "from vibecheck import DatatopsContentReviewContainer\n", + "def content_review(notebook_section: str):\n", + " return DatatopsContentReviewContainer(\n", + " \"\", # No text prompt\n", + " notebook_section,\n", + " {\n", + " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", + " \"name\": \"neuromatch_neuroai\",\n", + " \"user_key\": \"wb2cxze8\",\n", + " },\n", + " ).render()\n", + "\n", + "feedback_prefix = \"W2D5_T3\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c4e3a7d", + "metadata": {}, + "outputs": [], + "source": [ + "# @title Import dependencies\n", + "# @markdown\n", + "\n", + "import contextlib\n", + "import io\n", + "\n", + "with contextlib.redirect_stdout(io.StringIO()):\n", + " # Standard Libraries\n", + " import copy\n", + " import logging\n", + " import os\n", + " import random\n", + " import requests\n", + "\n", + " # Data Handling and Visualization Libraries\n", + " import numpy as np\n", + " import pandas as pd\n", + " import matplotlib.pyplot as plt\n", + " import seaborn as sns\n", + " from sklearn.metrics import precision_score, recall_score, fbeta_score\n", + " from sklearn.linear_model import LinearRegression\n", + " from tabulate import tabulate\n", + "\n", + " # Scientific Computing and Statistical Libraries\n", + " from numpy.linalg import inv\n", + " from scipy.special import logsumexp\n", + " from scipy.stats import multivariate_normal\n", + "\n", + " # Deep Learning Libraries\n", + " import torch\n", + " from torch import nn, optim, save, load\n", + " from torch.nn import functional as F\n", + " from torch.utils.data import DataLoader\n", + " import torch.nn.init as init\n", + " from torch.optim.lr_scheduler import StepLR\n", + "\n", + " # Image Processing Libraries\n", + " from PIL import Image\n", + " from matplotlib.patches import Patch\n", + " from mpl_toolkits.mplot3d import Axes3D\n", + "\n", + " # Interactive Elements and Web Applications\n", + " from IPython.display import IFrame\n", + " from IPython.display import Image as IMG\n", + " import gradio as gr\n", + " import ipywidgets as widgets\n", + " from ipywidgets import interact, IntSlider\n", + "\n", + " # Graph Analysis Libraries\n", + " import networkx as nx\n", + "\n", + " # Progress Monitoring Libraries\n", + " from tqdm import tqdm\n", + "\n", + " # Utilities and Miscellaneous Libraries\n", + " from itertools import product\n", + "\n", + " import math\n", + " !pip install torch_optimizer\n", + " import torch_optimizer as optim2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98ca7c55", + "metadata": {}, + "outputs": [], + "source": [ + "# @title Set device (GPU or CPU)\n", + "\n", + "def set_device():\n", + " \"\"\"\n", + " Determines and sets the computational device for PyTorch operations based on the availability of a CUDA-capable GPU.\n", + "\n", + " 
Outputs:\n", + " - device (str): The device that PyTorch will use for computations ('cuda' or 'cpu'). This string can be directly used\n", + " in PyTorch operations to specify the device.\n", + " \"\"\"\n", + "\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + " if device != \"cuda\":\n", + " print(\"GPU is not enabled in this notebook. \\n\"\n", + " \"If you want to enable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `GPU` from the dropdown menu\")\n", + " else:\n", + " print(\"GPU is enabled in this notebook. \\n\"\n", + " \"If you want to disable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `None` from the dropdown menu\")\n", + "\n", + " return device\n", + "\n", + "device = set_device()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e5761b4", + "metadata": {}, + "outputs": [], + "source": [ + "# @title Plotting functions\n", + "# @markdown\n", + "\n", + "def plot_testing(results_seed, discrimination_seed, seeds, title):\n", + " print(results_seed)\n", + " print(discrimination_seed)\n", + "\n", + " Testing_graph_names = [\"Suprathreshold stimulus\", \"Subthreshold stimulus\", \"Low Vision\"]\n", + "\n", + " fig, ax = plt.subplots(figsize=(14, len(results_seed[0]) * 2 + 2)) # Adjusted for added header space\n", + " ax.axis('off')\n", + " ax.axis('tight')\n", + "\n", + " # Define column labels\n", + " col_labels = [\"Scenario\", \"F1 SCORE\\n(2nd order network)\", \"RECALL\\n(2nd order network)\", \"PRECISION\\n(2nd order network)\", \"Discrimination Performance\\n(1st order network)\", \"ACCURACY\\n(2nd order network)\"]\n", + "\n", + " # Initialize list to hold all rows of data including headers\n", + " full_data = []\n", + "\n", + " # Calculate averages and standard deviations\n", + " for i in range(len(results_seed[0])):\n", + " metrics_list = [result[i][\"metrics\"][0] for result in results_seed] # Collect metrics for each seed\n", + " discrimination_list = [discrimination_seed[j][i] for j in range(seeds)]\n", + "\n", + " # Calculate averages and standard deviations for metrics\n", + " avg_metrics = np.mean(metrics_list, axis=0).tolist()\n", + " std_metrics = np.std(metrics_list, axis=0).tolist()\n", + "\n", + " # Calculate average and standard deviation for discrimination performance\n", + " avg_discrimination = np.mean(discrimination_list)\n", + " std_discrimination = np.std(discrimination_list)\n", + "\n", + " # Format the row with averages and standard deviations\n", + " row = [\n", + " Testing_graph_names[i],\n", + " f\"{avg_metrics[2]:.2f} ± {std_metrics[2]:.2f}\", # F1 SCORE\n", + " f\"{avg_metrics[1]:.2f} ± {std_metrics[1]:.2f}\", # RECALL\n", + " f\"{avg_metrics[0]:.2f} ± {std_metrics[0]:.2f}\", # PRECISION\n", + " f\"{avg_discrimination:.2f} ± {std_discrimination:.2f}\", # Discrimination Performance\n", + " f\"{avg_metrics[3]:.2f} ± {std_metrics[3]:.2f}\" # ACCURACY\n", + " ]\n", + " full_data.append(row)\n", + "\n", + " # Extract metric values for color scaling (excluding the first and last columns which are text)\n", + " metric_values = np.array([[float(x.split(\" ± \")[0]) for x in row[1:]] for row in full_data]) # Convert to float for color scaling\n", + " max_value = np.max(metric_values)\n", + " colors = metric_values / max_value # Normalize for color mapping\n", + "\n", + " # Prepare colors for all cells, defaulting to white for non-metric cells\n", + " cell_colors = [[\"white\"] * len(col_labels) for _ in range(len(full_data))]\n", + 
" for i, row in enumerate(colors):\n", + " cell_colors[i][1] = plt.cm.RdYlGn(row[0])\n", + " cell_colors[i][2] = plt.cm.RdYlGn(row[1])\n", + " cell_colors[i][3] = plt.cm.RdYlGn(row[2])\n", + " cell_colors[i][5] = plt.cm.RdYlGn(row[3]) # Adding color for accuracy\n", + "\n", + " # Adding color for discrimination performance\n", + " discrimination_colors = colors[:, 3]\n", + " for i, dp_color in enumerate(discrimination_colors):\n", + " cell_colors[i][4] = plt.cm.RdYlGn(dp_color)\n", + "\n", + " # Create the main table with cell colors\n", + " table = ax.table(cellText=full_data, colLabels=col_labels, loc='center', cellLoc='center', cellColours=cell_colors)\n", + " table.auto_set_font_size(False)\n", + " table.set_fontsize(10)\n", + " table.scale(1.5, 1.5)\n", + "\n", + " # Set the height of the header row to be double that of the other rows\n", + " for j, col_label in enumerate(col_labels):\n", + " cell = table[(0, j)]\n", + " cell.set_height(cell.get_height() * 2)\n", + "\n", + " # Add chance level table\n", + " chance_level_data = [[\"Chance Level\\nDiscrimination(1st)\", \"Chance Level\\nAccuracy(2nd)\"],\n", + " [\"0.010\", \"0.50\"]]\n", + "\n", + " chance_table = ax.table(cellText=chance_level_data, bbox=[1.0, 0.8, 0.3, 0.1], cellLoc='center', colWidths=[0.1, 0.1])\n", + " chance_table.auto_set_font_size(False)\n", + " chance_table.set_fontsize(10)\n", + " chance_table.scale(1.2, 1.2)\n", + "\n", + " # Set the height of the header row to be double that of the other rows in the chance level table\n", + " for j in range(len(chance_level_data[0])):\n", + " cell = chance_table[(0, j)]\n", + " cell.set_height(cell.get_height() * 2)\n", + "\n", + " plt.title(title, pad=20, fontsize=16)\n", + " plt.show()\n", + " plt.close(fig)\n", + "\n", + "\n", + "def plot_signal_max_and_indicator(patterns_tensor, plot_title=\"Training Signals\"):\n", + " \"\"\"\n", + " Plots the maximum values of signal units and a binary indicator for max values greater than 0.5.\n", + "\n", + " Parameters:\n", + " - patterns_tensor: A tensor containing signals, where each signal is expected to have multiple units.\n", + " \"\"\"\n", + " with plt.xkcd():\n", + "\n", + " # Calculate the maximum value of units for each signal within the patterns tensor\n", + " max_values_of_units = patterns_tensor.max(dim=1).values.cpu().numpy() # Ensure it's on CPU and in NumPy format for plotting\n", + "\n", + " # Determine the binary indicators based on the max value being greater than 0.5\n", + " binary_indicators = (max_values_of_units > 0.5).astype(int)\n", + "\n", + " # Create a figure with 2 subplots (2 rows, 1 column)\n", + " fig, axs = plt.subplots(2, 1, figsize=(8, 8))\n", + "\n", + " fig.suptitle(plot_title, fontsize=16) # Set the overall title for the plot\n", + "\n", + " # First subplot for the maximum values of each signal\n", + " axs[0].plot(range(patterns_tensor.size(0)), max_values_of_units, drawstyle='steps-mid')\n", + " axs[0].set_xlabel('Pattern Number')\n", + " axs[0].set_ylabel('Max Value of Signal Units')\n", + " axs[0].set_ylim(-0.1, 1.1) # Adjust y-axis limits for clarity\n", + " axs[0].grid(True)\n", + "\n", + " # Second subplot for the binary indicators\n", + " axs[1].plot(range(patterns_tensor.size(0)), binary_indicators, drawstyle='steps-mid', color='red')\n", + " axs[1].set_xlabel('Pattern Number')\n", + " axs[1].set_ylabel('Indicator (Max > 0.5) in each signal')\n", + " axs[1].set_ylim(-0.1, 1.1) # Adjust y-axis limits for clarity\n", + " axs[1].grid(True)\n", + "\n", + " plt.tight_layout()\n", + " 
plt.show()\n", + "\n", + "\n", + "def perform_quadratic_regression(epoch_list, values):\n", + " # Perform quadratic regression\n", + " coeffs = np.polyfit(epoch_list, values, 2) # Coefficients of the polynomial\n", + " y_pred = np.polyval(coeffs, epoch_list) # Evaluate the polynomial at the given x values\n", + " return y_pred\n", + "\n", + "\n", + "def pre_train_plots(epoch_1_order, epoch_2_order, title, max_values_indices):\n", + " \"\"\"\n", + " Plots the training progress with regression lines and scatter plots of indices and values of max elements.\n", + "\n", + " Parameters:\n", + " - epoch_list (list): List of epoch numbers.\n", + " - epoch_1_order (list): Loss values for the first-order network over epochs.\n", + " - epoch_2_order (list): Loss values for the second-order network over epochs.\n", + " - title (str): Title for the plots.\n", + " - max_values_indices (tuple): Tuple containing lists of max values and indices for both tensors.\n", + " \"\"\"\n", + " (max_values_output_first_order,\n", + " max_indices_output_first_order,\n", + " max_values_patterns_tensor,\n", + " max_indices_patterns_tensor) = max_values_indices\n", + "\n", + " # Perform quadratic regression for the loss plots\n", + " epoch_list = list(range(len(epoch_1_order)))\n", + " y_pred1 = perform_quadratic_regression(epoch_list, epoch_1_order)\n", + " y_pred2 = perform_quadratic_regression(epoch_list, epoch_2_order)\n", + "\n", + " # Set up the plot with 2 rows and 2 columns\n", + " fig, axs = plt.subplots(2, 2, figsize=(15, 10))\n", + "\n", + " # First graph for 1st Order Network\n", + " axs[0, 0].plot(epoch_list, epoch_1_order, linestyle='--', marker='o', color='g')\n", + " axs[0, 0].plot(epoch_list, y_pred1, linestyle='-', color='r', label='Quadratic Fit')\n", + " axs[0, 0].legend(['1st Order Network', 'Quadratic Fit'])\n", + " axs[0, 0].set_title('1st Order Network Loss')\n", + " axs[0, 0].set_xlabel('Epochs - Pretraining Phase')\n", + " axs[0, 0].set_ylabel('Loss')\n", + "\n", + " # Second graph for 2nd Order Network\n", + " axs[0, 1].plot(epoch_list, epoch_2_order, linestyle='--', marker='o', color='b')\n", + " axs[0, 1].plot(epoch_list, y_pred2, linestyle='-', color='r', label='Quadratic Fit')\n", + " axs[0, 1].legend(['2nd Order Network', 'Quadratic Fit'])\n", + " axs[0, 1].set_title('2nd Order Network Loss')\n", + " axs[0, 1].set_xlabel('Epochs - Pretraining Phase')\n", + " axs[0, 1].set_ylabel('Loss')\n", + "\n", + " # Scatter plot of indices: patterns_tensor vs. output_first_order\n", + " axs[1, 0].scatter(max_indices_patterns_tensor, max_indices_output_first_order, alpha=0.5)\n", + "\n", + " # Add quadratic regression line\n", + " indices_regression = perform_quadratic_regression(max_indices_patterns_tensor, max_indices_output_first_order)\n", + " axs[1, 0].plot(max_indices_patterns_tensor, indices_regression, color='skyblue', linestyle='--', label='Quadratic Fit')\n", + "\n", + " axs[1, 0].set_title('Stimuli location: First Order Input vs. First Order Output')\n", + " axs[1, 0].set_xlabel('First Order Input Indices')\n", + " axs[1, 0].set_ylabel('First Order Output Indices')\n", + " axs[1, 0].legend()\n", + "\n", + " # Scatter plot of values: patterns_tensor vs. 
output_first_order\n", + " axs[1, 1].scatter(max_values_patterns_tensor, max_values_output_first_order, alpha=0.5)\n", + "\n", + " # Add quadratic regression line\n", + " values_regression = perform_quadratic_regression(max_values_patterns_tensor, max_values_output_first_order)\n", + " axs[1, 1].plot(max_values_patterns_tensor, values_regression, color='skyblue', linestyle='--', label='Quadratic Fit')\n", + "\n", + " axs[1, 1].set_title('Stimuli Values: First Order Input vs. First Order Output')\n", + " axs[1, 1].set_xlabel('First Order Input Values')\n", + " axs[1, 1].set_ylabel('First Order Output Values')\n", + " axs[1, 1].legend()\n", + "\n", + " plt.suptitle(title, fontsize=16, y=1.02)\n", + "\n", + " # Display the plots in a 2x2 grid\n", + " plt.tight_layout()\n", + " plt.savefig('Blindsight_Pre_training_Loss_{}.png'.format(title.replace(\" \", \"_\").replace(\"/\", \"_\")), bbox_inches='tight')\n", + " plt.show()\n", + " plt.close(fig)\n", + "\n", + "# Function to configure the training environment and load the models\n", + "def config_training(first_order_network, second_order_network, hidden, factor, gelu):\n", + " \"\"\"\n", + " Configures the training environment by saving the state of the given models and loading them back.\n", + " Initializes testing patterns for evaluation.\n", + "\n", + " Parameters:\n", + " - first_order_network: The first order network instance.\n", + " - second_order_network: The second order network instance.\n", + " - hidden: Number of hidden units in the first order network.\n", + " - factor: Factor influencing the network's architecture.\n", + " - gelu: Activation function to be used in the network.\n", + "\n", + " Returns:\n", + " - Tuple of testing patterns, number of samples in the testing patterns, and the loaded model instances.\n", + " \"\"\"\n", + " # Paths where the models' states will be saved\n", + " PATH = './cnn1.pth'\n", + " PATH_2 = './cnn2.pth'\n", + "\n", + " # Save the weights of the pretrained networks to the specified paths\n", + " torch.save(first_order_network.state_dict(), PATH)\n", + " torch.save(second_order_network.state_dict(), PATH_2)\n", + "\n", + " # Generating testing patterns for three different sets\n", + " First_set, First_set_targets = create_patterns(0,factor)\n", + " Second_set, Second_set_targets = create_patterns(1,factor)\n", + " Third_set, Third_set_targets = create_patterns(2,factor)\n", + "\n", + " # Aggregate testing patterns and their targets for ease of access\n", + " Testing_patterns = [[First_set, First_set_targets], [Second_set, Second_set_targets], [Third_set, Third_set_targets]]\n", + "\n", + " # Determine the number of samples from the first set (assumed consistent across all sets)\n", + " n_samples = len(Testing_patterns[0][0])\n", + "\n", + " # Initialize and load the saved states into model instances\n", + " loaded_model = FirstOrderNetwork(hidden, factor, gelu)\n", + " loaded_model_2 = SecondOrderNetwork(gelu)\n", + "\n", + " loaded_model.load_state_dict(torch.load(PATH))\n", + " loaded_model_2.load_state_dict(torch.load(PATH_2))\n", + "\n", + " # Ensure the models are moved to the appropriate device (CPU/GPU) and set to evaluation mode\n", + " loaded_model.to(device)\n", + " loaded_model_2.to(device)\n", + "\n", + " loaded_model.eval()\n", + " loaded_model_2.eval()\n", + "\n", + " return Testing_patterns, n_samples, loaded_model, loaded_model_2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea1584b0", + "metadata": {}, + "outputs": [], + "source": [ + "# @title Figure 
settings\n",
+    "# @markdown\n",
+    "\n",
+    "logging.getLogger('matplotlib.font_manager').disabled = True\n",
+    "\n",
+    "%matplotlib inline\n",
+    "%config InlineBackend.figure_format = 'retina'  # perform high-definition rendering for images and plots\n",
+    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02ac58d0",
   "metadata": {},
   "outputs": [],
   "source": [
+    "# @title Helper functions\n",
+    "\n",
+    "# Reconstruction loss: summed binary cross-entropy (kept under the historical name mse_loss).\n",
+    "# reduction='sum' replaces the deprecated size_average=False argument.\n",
+    "mse_loss = nn.BCELoss(reduction='sum')\n",
+    "\n",
+    "lam = 1e-4\n",
+    "\n",
+    "from torch.autograd import Variable\n",
+    "\n",
+    "def CAE_loss(W, x, recons_x, h, lam):\n",
+    "    \"\"\"Compute the Contractive AutoEncoder (CAE) loss.\n",
+    "\n",
+    "    Evaluates the CAE loss, which is the sum of a reconstruction error\n",
+    "    (binary cross-entropy in this implementation) and the weighted l2-norm\n",
+    "    of the Jacobian of the hidden units with respect to the inputs.\n",
+    "\n",
+    "    See reference below for an in-depth discussion:\n",
+    "      #1: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder\n",
+    "\n",
+    "    Args:\n",
+    "        `W` (FloatTensor): (N_hidden x N), where N_hidden and N are the\n",
+    "          dimensions of the hidden units and input respectively.\n",
+    "        `x` (Variable): the input to the network, with dims (N_batch x N)\n",
+    "        recons_x (Variable): the reconstruction of the input, with dims\n",
+    "          N_batch x N.\n",
+    "        `h` (Variable): the hidden units of the network, with dims\n",
+    "          batch_size x N_hidden\n",
+    "        `lam` (float): the weight given to the jacobian regulariser term\n",
+    "\n",
+    "    Returns:\n",
+    "        Variable: the (scalar) CAE loss\n",
+    "    \"\"\"\n",
+    "    mse = mse_loss(recons_x, x)\n",
+    "    # W has shape N_hidden x N, so unlike reference #1 it does not need to be transposed\n",
+    "    dh = h * (1 - h)  # Hadamard product produces size N_batch x N_hidden\n",
+    "    # Sum through the input dimension to improve efficiency, as suggested in #1\n",
+    "    w_sum = torch.sum(Variable(W)**2, dim=1)\n",
+    "    # unsqueeze to avoid issues with torch.mv\n",
+    "    w_sum = w_sum.unsqueeze(1)  # shape N_hidden x 1\n",
+    "    contractive_loss = torch.sum(torch.mm(dh**2, w_sum), 0)\n",
+    "    return mse + contractive_loss.mul_(lam)\n",
+    "\n",
+    "class FirstOrderNetwork(nn.Module):\n",
+    "    def __init__(self, hidden_units, data_factor, use_gelu):\n",
+    "        \"\"\"\n",
+    "        Initializes the FirstOrderNetwork with specific configurations.\n",
+    "\n",
+    "        Parameters:\n",
+    "        - hidden_units (int): The number of units in the hidden layer.\n",
+    "        - data_factor (int): Factor to scale the amount of data processed.\n",
+    "                             A factor of 1 indicates the default data amount,\n",
+    "                             while 10 indicates 10 times the default amount.\n",
+    "        - use_gelu (bool): Flag to use GELU (True) or ReLU (False) as the activation function.\n",
+    "        \"\"\"\n",
+    "        super(FirstOrderNetwork, self).__init__()\n",
+    "\n",
+    "        # Define the encoder, hidden, and decoder layers with specified units\n",
+    "        self.fc1 = nn.Linear(100, hidden_units, bias=False)  # Encoder\n",
+    "        self.hidden = nn.Linear(hidden_units, hidden_units, bias=False)  # Hidden\n",
+    "        self.fc2 = nn.Linear(hidden_units, 100, bias=False)  # Decoder\n",
+    "\n",
+    "        self.relu = nn.ReLU()\n",
+    "        self.sigmoid = nn.Sigmoid()\n",
+    "\n",
+    "        # Dropout layer to prevent overfitting\n",
+    "        self.dropout = nn.Dropout(0.1)\n",
+    "\n",
+    "        # Set the data factor\n",
+    "        self.data_factor = data_factor\n",
+    "\n",
+    "        # Other activation functions for various 
purposes\n", + " self.softmax = nn.Softmax()\n", + "\n", + " # Initialize network weights\n", + " self.initialize_weights()\n", + "\n", + " def initialize_weights(self):\n", + " \"\"\"Initializes weights of the encoder, hidden, and decoder layers uniformly.\"\"\"\n", + " init.uniform_(self.fc1.weight, -1.0, 1.0)\n", + " init.uniform_(self.fc2.weight, -1.0, 1.0)\n", + " init.uniform_(self.hidden.weight, -1.0, 1.0)\n", + "\n", + " def encoder(self, x):\n", + " h1 = self.dropout(self.relu(self.fc1(x.view(-1, 100))))\n", + " return h1\n", + "\n", + " def decoder(self,z):\n", + " #h2 = self.relu(self.hidden(z))\n", + " h2 = self.sigmoid(self.fc2(z))\n", + " return h2\n", + "\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " Defines the forward pass through the network.\n", + "\n", + " Parameters:\n", + " - x (Tensor): The input tensor to the network.\n", + "\n", + " Returns:\n", + " - Tensor: The output of the network after passing through the layers and activations.\n", + " \"\"\"\n", + " h1 = self.encoder(x)\n", + " h2 = self.decoder(h1)\n", + "\n", + " return h1 , h2\n", + "\n", + "def initialize_global():\n", + " global Input_Size_1, Hidden_Size_1, Output_Size_1, Input_Size_2\n", + " global num_units, patterns_number\n", + " global learning_rate_2, momentum, temperature , Threshold\n", + " global First_set, Second_set, Third_set\n", + " global First_set_targets, Second_set_targets, Third_set_targets\n", + " global epoch_list, epoch_1_order, epoch_2_order, patterns_matrix1\n", + " global testing_graph_names\n", + "\n", + " global optimizer ,n_epochs , learning_rate_1\n", + " learning_rate_1 = 0.5\n", + " n_epochs = 100\n", + " optimizer=\"ADAMAX\"\n", + "\n", + " # Network sizes\n", + " Input_Size_1 = 100\n", + " Hidden_Size_1 = 60\n", + " Output_Size_1 = 100\n", + " Input_Size_2 = 100\n", + "\n", + " # Patterns\n", + " num_units = 100\n", + " patterns_number = 200\n", + "\n", + " # Pre-training and hyperparameters\n", + " learning_rate_2 = 0.1\n", + " momentum = 0.9\n", + " temperature = 1.0\n", + " Threshold=0.5\n", + "\n", + " # Testing\n", + " First_set = []\n", + " Second_set = []\n", + " Third_set = []\n", + " First_set_targets = []\n", + " Second_set_targets = []\n", + " Third_set_targets = []\n", + "\n", + " # Graphic of pretraining\n", + " epoch_list = list(range(1, n_epochs + 1))\n", + " epoch_1_order = np.zeros(n_epochs)\n", + " epoch_2_order = np.zeros(n_epochs)\n", + " patterns_matrix1 = torch.zeros((n_epochs, patterns_number), device=device) # Initialize patterns_matrix as a PyTorch tensor on the GPU\n", + "\n", + "\n", + "\n", + "def compute_metrics(TP, TN, FP, FN):\n", + " \"\"\"Compute precision, recall, F1 score, and accuracy.\"\"\"\n", + " precision = round(TP / (TP + FP), 2) if (TP + FP) > 0 else 0\n", + " recall = round(TP / (TP + FN), 2) if (TP + FN) > 0 else 0\n", + " f1_score = round(2 * (precision * recall) / (precision + recall), 2) if (precision + recall) > 0 else 0\n", + " accuracy = round((TP + TN) / (TP + TN + FP + FN), 2) if (TP + TN + FP + FN) > 0 else 0\n", + " return precision, recall, f1_score, accuracy\n", + "\n", + "# define the architecture, optimizers, loss functions, and schedulers for pre training\n", + "def prepare_pre_training(hidden,factor,gelu,stepsize, gam):\n", + "\n", + " first_order_network = FirstOrderNetwork(hidden, factor, gelu).to(device)\n", + " second_order_network = SecondOrderNetwork(gelu).to(device)\n", + "\n", + " criterion_1 = CAE_loss\n", + " criterion_2 = nn.BCELoss(size_average = False)\n", + "\n", + "\n", + " if 
optimizer == \"ADAM\":\n",
+    "        optimizer_1 = optim.Adam(first_order_network.parameters(), lr=learning_rate_1)\n",
+    "        optimizer_2 = optim.Adam(second_order_network.parameters(), lr=learning_rate_2)\n",
+    "\n",
+    "    elif optimizer == \"SGD\":\n",
+    "        optimizer_1 = optim.SGD(first_order_network.parameters(), lr=learning_rate_1)\n",
+    "        optimizer_2 = optim.SGD(second_order_network.parameters(), lr=learning_rate_2)\n",
+    "\n",
+    "    elif optimizer == \"SWATS\":\n",
+    "        optimizer_1 = optim2.SWATS(first_order_network.parameters(), lr=learning_rate_1)\n",
+    "        optimizer_2 = optim2.SWATS(second_order_network.parameters(), lr=learning_rate_2)\n",
+    "\n",
+    "    elif optimizer == \"ADAMW\":\n",
+    "        optimizer_1 = optim.AdamW(first_order_network.parameters(), lr=learning_rate_1)\n",
+    "        optimizer_2 = optim.AdamW(second_order_network.parameters(), lr=learning_rate_2)\n",
+    "\n",
+    "    elif optimizer == \"RMS\":\n",
+    "        optimizer_1 = optim.RMSprop(first_order_network.parameters(), lr=learning_rate_1)\n",
+    "        optimizer_2 = optim.RMSprop(second_order_network.parameters(), lr=learning_rate_2)\n",
+    "\n",
+    "    elif optimizer == \"ADAMAX\":\n",
+    "        optimizer_1 = optim.Adamax(first_order_network.parameters(), lr=learning_rate_1)\n",
+    "        optimizer_2 = optim.Adamax(second_order_network.parameters(), lr=learning_rate_2)\n",
+    "\n",
+    "    # Learning rate schedulers\n",
+    "    scheduler_1 = StepLR(optimizer_1, step_size=stepsize, gamma=gam)\n",
+    "    scheduler_2 = StepLR(optimizer_2, step_size=stepsize, gamma=gam)\n",
+    "\n",
+    "    return first_order_network, second_order_network, criterion_1, criterion_2, optimizer_1, optimizer_2, scheduler_1, scheduler_2\n",
+    "\n",
+    "def title(string):\n",
+    "    # Enable XKCD plot styling\n",
+    "    with plt.xkcd():\n",
+    "        # Create a figure and an axes.\n",
+    "        fig, ax = plt.subplots()\n",
+    "\n",
+    "        # Create a rectangle patch with specified dimensions and styles\n",
+    "        # (plt.Rectangle is used because only Patch is imported from matplotlib.patches)\n",
+    "        rectangle = plt.Rectangle((0.05, 0.1), 0.9, 0.4, linewidth=1, edgecolor='r', facecolor='blue', alpha=0.5)\n",
+    "        ax.add_patch(rectangle)\n",
+    "\n",
+    "        # Place text inside the rectangle, centered\n",
+    "        plt.text(0.5, 0.3, string, horizontalalignment='center', verticalalignment='center', fontsize=26, color='white')\n",
+    "\n",
+    "        # Set plot limits\n",
+    "        ax.set_xlim(0, 1)\n",
+    "        ax.set_ylim(0, 1)\n",
+    "\n",
+    "        # Disable axis display\n",
+    "        ax.axis('off')\n",
+    "\n",
+    "        # Display the plot\n",
+    "        plt.show()\n",
+    "\n",
+    "        # Close the figure to free up memory\n",
+    "        plt.close(fig)\n",
+    "\n",
+    "# Function to generate the testing pattern sets\n",
+    "def get_test_patterns(factor):\n",
+    "    \"\"\"\n",
+    "    Generates the three sets of testing patterns (suprathreshold, subthreshold,\n",
+    "    and low vision) used to evaluate the networks.\n",
+    "\n",
+    "    Parameters:\n",
+    "    - factor: Multiplier that scales the number of generated patterns.\n",
+    "\n",
+    "    Returns:\n",
+    "    - Tuple of the testing patterns and the number of samples per testing set.\n",
+    "    \"\"\"\n",
+    "    # Generating testing patterns for three different sets\n",
+    "    first_set, first_set_targets = create_patterns(0, factor)\n",
+    "    second_set, second_set_targets = create_patterns(1, factor)\n",
+    "    third_set, third_set_targets = create_patterns(2, factor)\n",
+    "\n",
+    "    # Aggregate testing patterns and their targets for ease of access\n",
+    "    testing_patterns = [[first_set, first_set_targets], [second_set, second_set_targets], [third_set, third_set_targets]]\n",
+    "\n",
+    "    # Determine the number of samples from the first set (assumed consistent across all sets)\n",
+    "    n_samples = len(testing_patterns[0][0])\n",
+    "\n",
+    "    return 
testing_patterns, n_samples\n", + "\n", + "# Function to test the model using the configured testing patterns\n", + "def plot_input_output(input_data, output_data, index):\n", + " fig, axes = plt.subplots(1, 2, figsize=(10, 6))\n", + "\n", + " # Plot input data\n", + " im1 = axes[0].imshow(input_data.cpu().numpy(), aspect='auto', cmap='viridis')\n", + " axes[0].set_title('Input')\n", + " fig.colorbar(im1, ax=axes[0])\n", + "\n", + " # Plot output data\n", + " im2 = axes[1].imshow(output_data.cpu().numpy(), aspect='auto', cmap='viridis')\n", + " axes[1].set_title('Output')\n", + " fig.colorbar(im2, ax=axes[1])\n", + "\n", + " plt.suptitle(f'Testing Pattern {index+1}')\n", + " plt.show()\n", + "\n", + "# Function to test the model using the configured testing patterns\n", + "# Function to test the model using the configured testing patterns\n", + "def testing(testing_patterns, n_samples, loaded_model, loaded_model_2,factor):\n", + "\n", + " def generate_chance_level(shape):\n", + " chance_level = np.random.rand(*shape).tolist()\n", + " return chance_level\n", + "\n", + " results_for_plotting = []\n", + " max_values_output_first_order = []\n", + " max_indices_output_first_order = []\n", + " max_values_patterns_tensor = []\n", + " max_indices_patterns_tensor = []\n", + " f1_scores_wager = []\n", + "\n", + " mse_losses_indices = []\n", + " mse_losses_values = []\n", + " discrimination_performances = []\n", + "\n", + "\n", + "\n", + " # Iterate through each set of testing patterns and targets\n", + " for i in range(len(testing_patterns)):\n", + " with torch.no_grad(): # Ensure no gradients are computed during testing\n", + "\n", + " #For low vision the stimulus threshold was set to 0.3 as can seen in the generate_patters function\n", + " threshold=0.5\n", + " if i==2:\n", + " threshold=0.15\n", + "\n", + " # Obtain output from the first order model\n", + " input_data = testing_patterns[i][0]\n", + " hidden_representation, output_first_order = loaded_model(input_data)\n", + " output_second_order = loaded_model_2(input_data, output_first_order)\n", + "\n", + " delta=100*factor\n", + "\n", + " print(\"driscriminator\")\n", + " print((output_first_order[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean())\n", + " discrimination_performance = round((output_first_order[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean().item(), 2)\n", + " discrimination_performances.append(discrimination_performance)\n", + "\n", + "\n", + " chance_level = torch.Tensor( generate_chance_level((200*factor,100)))\n", + " discrimination_random= round((chance_level[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean().item(), 2)\n", + " print(\"chance level\" , discrimination_random)\n", + "\n", + "\n", + "\n", + " #count all patterns in the dataset\n", + " wagers = output_second_order[delta:].cpu()\n", + "\n", + " _, targets_2 = torch.max(testing_patterns[i][1], 1)\n", + " targets_2 = targets_2[delta:].cpu()\n", + "\n", + " # Convert targets to binary classification for wagering scenario\n", + " targets_2 = (targets_2 > 0).int()\n", + "\n", + " # Convert tensors to NumPy arrays for metric calculations\n", + " predicted_np = wagers.numpy().flatten()\n", + " targets_2_np = targets_2.numpy()\n", + "\n", + " #print(\"number of targets,\" , len(targets_2_np))\n", + "\n", + " print(predicted_np)\n", + " print(targets_2_np)\n", + "\n", + " # Calculate True Positives, True Negatives, False Positives, and False Negatives\n", + " TP = np.sum((predicted_np > 
threshold) & (targets_2_np > threshold))\n", + " TN = np.sum((predicted_np < threshold ) & (targets_2_np < threshold))\n", + " FP = np.sum((predicted_np > threshold) & (targets_2_np < threshold))\n", + " FN = np.sum((predicted_np < threshold) & (targets_2_np > threshold))\n", + "\n", + " # Compute precision, recall, F1 score, and accuracy for both high and low wager scenarios\n", + " precision_h, recall_h, f1_score_h, accuracy_h = compute_metrics(TP, TN, FP, FN)\n", + "\n", + " f1_scores_wager.append(f1_score_h)\n", + "\n", + " # Collect results for plotting\n", + " results_for_plotting.append({\n", + " \"counts\": [[TP, FP, TP + FP]],\n", + " \"metrics\": [[precision_h, recall_h, f1_score_h, accuracy_h]],\n", + " \"title_results\": f\"Results Table - Set {i+1}\",\n", + " \"title_metrics\": f\"Metrics Table - Set {i+1}\"\n", + " })\n", + "\n", + " # Plot input and output of the first-order network\n", + " plot_input_output(input_data, output_first_order, i)\n", + "\n", + " max_vals_out, max_inds_out = torch.max(output_first_order[100:], dim=1)\n", + " max_inds_out[max_vals_out == 0] = 0\n", + " max_values_output_first_order.append(max_vals_out.tolist())\n", + " max_indices_output_first_order.append(max_inds_out.tolist())\n", + "\n", + " max_vals_pat, max_inds_pat = torch.max(input_data[100:], dim=1)\n", + " max_inds_pat[max_vals_pat == 0] = 0\n", + " max_values_patterns_tensor.append(max_vals_pat.tolist())\n", + " max_indices_patterns_tensor.append(max_inds_pat.tolist())\n", + "\n", + " fig, axs = plt.subplots(1, 2, figsize=(15, 5))\n", + "\n", + " # Scatter plot of indices: patterns_tensor vs. output_first_order\n", + " axs[0].scatter(max_indices_patterns_tensor[i], max_indices_output_first_order[i], alpha=0.5)\n", + " axs[0].set_title(f'Stimuli location: Condition {i+1} - First Order Input vs. First Order Output')\n", + " axs[0].set_xlabel('First Order Input Indices')\n", + " axs[0].set_ylabel('First Order Output Indices')\n", + "\n", + " # Add quadratic fit to scatter plot\n", + " x_indices = max_indices_patterns_tensor[i]\n", + " y_indices = max_indices_output_first_order[i]\n", + " y_pred_indices = perform_quadratic_regression(x_indices, y_indices)\n", + " axs[0].plot(x_indices, y_pred_indices, color='skyblue')\n", + "\n", + "\n", + " # Calculate MSE loss for indices\n", + " mse_loss_indices = np.mean((np.array(x_indices) - np.array(y_indices)) ** 2)\n", + " mse_losses_indices.append(mse_loss_indices)\n", + "\n", + " # Scatter plot of values: patterns_tensor vs. output_first_order\n", + " axs[1].scatter(max_values_patterns_tensor[i], max_values_output_first_order[i], alpha=0.5)\n", + " axs[1].set_title(f'Stimuli Values: Condition {i+1} - First Order Input vs. 
First Order Output')\n", + " axs[1].set_xlabel('First Order Input Values')\n", + " axs[1].set_ylabel('First Order Output Values')\n", + "\n", + " # Add quadratic fit to scatter plot\n", + " x_values = max_values_patterns_tensor[i]\n", + " y_values = max_values_output_first_order[i]\n", + " y_pred_values = perform_quadratic_regression(x_values, y_values)\n", + " axs[1].plot(x_values, y_pred_values, color='skyblue')\n", + "\n", + " # Calculate MSE loss for values\n", + " mse_loss_values = np.mean((np.array(x_values) - np.array(y_values)) ** 2)\n", + " mse_losses_values.append(mse_loss_values)\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + " return f1_scores_wager, mse_losses_indices , mse_losses_values, discrimination_performances, results_for_plotting\n", + "\n", + "def generate_patterns(patterns_number, num_units, factor, condition = 0):\n", + " \"\"\"\n", + " Generates patterns and targets for training the networks\n", + "\n", + " # patterns_number: Number of patterns to generate\n", + " # num_units: Number of units in each pattern\n", + " # pattern: 0: superthreshold, 1: subthreshold, 2: low vision\n", + " # Returns lists of patterns, stimulus present/absent indicators, and second order targets\n", + " \"\"\"\n", + "\n", + " patterns_number= patterns_number*factor\n", + "\n", + " patterns = [] # Store generated patterns\n", + " stim_present = [] # Indicators for when a stimulus is present in the pattern\n", + " stim_absent = [] # Indicators for when no stimulus is present\n", + " order_2_pr = [] # Second order network targets based on the presence or absence of stimulus\n", + "\n", + " if condition == 0:\n", + " random_limit= 0.0\n", + " baseline = 0\n", + " multiplier = 1\n", + "\n", + " if condition == 1:\n", + " random_limit= 0.02\n", + " baseline = 0.0012\n", + " multiplier = 1\n", + "\n", + " if condition == 2:\n", + " random_limit= 0.02\n", + " baseline = 0.0012\n", + " multiplier = 0.3\n", + "\n", + " # Generate patterns, half noise and half potential stimuli\n", + " for i in range(patterns_number):\n", + "\n", + " # First half: Noise patterns\n", + " if i < patterns_number // 2:\n", + "\n", + " pattern = multiplier * np.random.uniform(0.0, random_limit, num_units) + baseline # Generate a noise pattern\n", + " patterns.append(pattern)\n", + " stim_present.append(np.zeros(num_units)) # Stimulus absent\n", + " order_2_pr.append([0.0 , 1.0]) # No stimulus, low wager\n", + "\n", + " # Second half: Stimulus patterns\n", + " else:\n", + " stimulus_number = random.randint(0, num_units - 1) # Choose a unit for potential stimulus\n", + " pattern = np.random.uniform(0.0, random_limit, num_units) + baseline\n", + " pattern[stimulus_number] = np.random.uniform(0.0, 1.0) * multiplier # Set stimulus intensity\n", + "\n", + " patterns.append(pattern)\n", + " present = np.zeros(num_units)\n", + " # Determine if stimulus is above discrimination threshold\n", + " if pattern[stimulus_number] >= multiplier/2:\n", + " order_2_pr.append([1.0 , 0.0]) # Stimulus detected, high wager\n", + " present[stimulus_number] = 1.0\n", + " else:\n", + " order_2_pr.append([0.0 , 1.0]) # Stimulus not detected, low wager\n", + " present[stimulus_number] = 0.0\n", + "\n", + " stim_present.append(present)\n", + "\n", + "\n", + " patterns_tensor = torch.Tensor(patterns).to(device).requires_grad_(True)\n", + " stim_present_tensor = torch.Tensor(stim_present).to(device).requires_grad_(True)\n", + " stim_absent_tensor= torch.Tensor(stim_absent).to(device).requires_grad_(True)\n", + " order_2_tensor = 
torch.Tensor(order_2_pr).to(device).requires_grad_(True)\n",
+    "\n",
+    "    return patterns_tensor, stim_present_tensor, stim_absent_tensor, order_2_tensor\n",
+    "\n",
+    "def create_patterns(stimulus, factor):\n",
+    "    \"\"\"\n",
+    "    Generates neural network input patterns based on specified stimulus conditions.\n",
+    "\n",
+    "    Parameters:\n",
+    "    - stimulus (int): Determines the type of patterns to generate.\n",
+    "      Acceptable values:\n",
+    "      - 0: Suprathreshold stimulus\n",
+    "      - 1: Subthreshold stimulus\n",
+    "      - 2: Low vision condition\n",
+    "    - factor (int): Multiplier that scales the number of generated patterns.\n",
+    "\n",
+    "    Returns:\n",
+    "    - torch.Tensor: Tensor of generated patterns.\n",
+    "    - torch.Tensor: Tensor of target values corresponding to the generated patterns.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # Generate patterns and target tensors for the requested stimulus condition.\n",
+    "    patterns_tensor, stim_present_tensor, _, _ = generate_patterns(patterns_number, num_units, factor, stimulus)\n",
+    "\n",
+    "    # Detach from the autograd graph and keep the tensors on the target device\n",
+    "    # (re-wrapping an existing tensor with torch.Tensor(...) is unnecessary here).\n",
+    "    patterns = patterns_tensor.detach().to(device)\n",
+    "    targets = stim_present_tensor.detach().to(device)\n",
+    "\n",
+    "    return patterns, targets"
   ]
  },
  {
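+   "cell_type": "markdown",
+   "id": "ed0a0001",
+   "metadata": {},
+   "source": [
+    "**Editor's note (illustrative sketch).** The next cell is a small, optional sanity check of the `generate_patterns` helper defined above: for each stimulus condition it prints the strongest stimulus value in a small batch and how many of the stimulus patterns receive a high-wager target. It only assumes names already defined in this notebook (`generate_patterns`, `device`) and is not part of the original training or testing pipeline."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ed0a0002",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check of generate_patterns (editor-added illustration, not part of the original pipeline).\n",
+    "# Assumes generate_patterns and `device` (defined above) are in scope.\n",
+    "for name, cond in [(\"suprathreshold\", 0), (\"subthreshold\", 1), (\"low vision\", 2)]:\n",
+    "    pats, _, _, wagers = generate_patterns(20, 100, 1, condition=cond)\n",
+    "    stim_half = pats[10:]  # the second half of each batch holds the (potential) stimuli\n",
+    "    n_high = int(wagers[10:, 0].sum().item())  # stimuli judged above the wager threshold\n",
+    "    print(f\"{name}: max stimulus value = {stim_half.max().item():.3f}, high wagers = {n_high}/10\")"
+   ]
+  },
+  {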