diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial4.ipynb b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial4.ipynb index f2a33e3b5..064782b50 100644 --- a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial4.ipynb +++ b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial4.ipynb @@ -17,7 +17,7 @@ "execution": {} }, "source": [ - "# Tutorial 4: Representational geometry & noise\n", + "# (Bonus) Tutorial 4: Representational geometry & noise\n", "\n", "**Week 1, Day 3: Comparing Artificial And Biological Networks**\n", "\n", @@ -25,9 +25,9 @@ "\n", "__Content creators:__ Wenxuan Guo, Heiko Schütt\n", "\n", - "__Content reviewers:__ Alish Dipani, Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk\n", + "__Content reviewers:__ Alish Dipani, Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk, Alex Murphy\n", "\n", - "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n", + "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault, Alex Murphy\n", "\n", "Acknowledgments: the tutorial outline was written by Heiko Schütt. The content was greatly improved by discussions with Heiko, Hlib, and Alish, and the insightful illustrations presented in the paper by Walther et al. (2016)\n" ] @@ -61,7 +61,7 @@ "\n", "5. Using random projections to estimate distances. This section introduces the Johnson–Lindenstrauss Lemma, which states that random projections maintain the integrity of distance estimates in a lower-dimensional space. This concept is crucial for reducing dimensionality while preserving the relational structure of the data.\n", "\n", - "We will adhere to the notational conventions established by [Walther et al. (2016)](https://pubmed.ncbi.nlm.nih.gov/26707889/) for all discussed distance measures. " + "We will adhere to the notational conventions established by [Walther et al. (2016)](https://pubmed.ncbi.nlm.nih.gov/26707889/) for all discussed distance measures." ] }, { @@ -644,6 +644,72 @@ "display(tabs)" ] }, + { + "cell_type": "markdown", + "id": "b64eaea5", + "metadata": { + "execution": {} + }, + "source": [ + "The video below is additional information in more detail which was previously part of the introductory video for this course day. It provides some useful further information on the technical details mentioned during these tutorials. Please feel free to check it out and use it as a resource if you want to learn more or if you want to get a deeper understanding on some of the important details." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "200235dc", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Video 2 (BONUS): Extended Intro Video\n", + "\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "#assert 1 == 0, \"Upload this video\"\n", + "video_ids = [('Youtube', 'm9srqTx5ci0'), ('Bilibili', 'BV1meVjz3Eeh')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1503,6 +1569,20 @@ "3. Cross-validated distance estimators (cross-validated Euclidean or Mahalanobis distance) can remove the positive bias introduced by noise.\n", "4. The Johnson–Lindenstrauss Lemma shows that random projections preserve the Euclidean distance with some distortions. Crucially, the distortion does not depend on the dimensionality of the original space." ] + }, + { + "cell_type": "markdown", + "id": "40936ec4", + "metadata": { + "execution": {} + }, + "source": [ + "# The Big Picture\n", + "\n", + "The goal of this tutorial is to provide you with some mathematical tools for your NeuroAI researcher toolkit. What happens when you pull out the Euclidean metric from your toolkit and, while this has worked well in the past, suddenly in different scenarios it doesn't seem to perform so well. Aha, you spot the potential for correlated noise and you reach deeper into your toolkit and pull out the Mahalanobis metric, which implicitly undoes the correlated noise in the model. Perhaps you can't even tell if there is any correlated noise in your data and you try with both metrics, and Mahalanobis works well but Euclidean does not, that can be a hunch that leads you to confirm the presence of correlated noise. \n", + "\n", + "Sometimes you might be faced with dimensionalities that are just too high to practically deal with in your use case. Then, why not recall what you learned about how random projections can reduce the dimensionality of a feature space and be largely resistant to corrupting the applicability of distance metrics. These metrics also might work better in this lower dimensional space. If you apply this idea and need to justify it, just reach into your NeuroAI toolkit and pull out the Johnson-Lindenstrauss Lemma as your justification." + ] } ], "metadata": { @@ -1533,7 +1613,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.9.22" } }, "nbformat": 4, diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial5.ipynb b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial5.ipynb index 77480d80e..cf185d05d 100644 --- a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial5.ipynb +++ b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/W1D3_Tutorial5.ipynb @@ -17,17 +17,17 @@ "execution": {} }, "source": [ - "# Bonus Material: Dynamical similarity analysis (DSA)\n", + "# Tutorial 5: Dynamical Similarity Analysis (DSA)\n", "\n", "**Week 1, Day 3: Comparing Artificial And Biological Networks**\n", "\n", "**By Neuromatch Academy**\n", "\n", - "__Content creators:__ Mitchell Ostrow\n", + "__Content creators:__ Mitchell Ostrow, Alex Murphy\n", "\n", - "__Content reviewers:__ Xaq Pitkow, Hlib Solodzhuk\n", + "__Content reviewers:__ Xaq Pitkow, Hlib Solodzhuk, Alex Murphy\n", "\n", - "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n" + "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault, Alex Murphy\n" ] }, { @@ -52,7 +52,7 @@ "source": [ "# @title Install and import feedback gadget\n", "\n", - "!pip install vibecheck --quiet\n", + "!pip install vibecheck rsatoolbox --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", @@ -67,7 +67,2087 @@ " ).render()\n", "\n", "\n", - "feedback_prefix = \"W1D3_Bonus\"" + "feedback_prefix = \"W1D5_DSA\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef9abaa3", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Helper functions\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "def generate_2d_random_process(A, B, T=1000):\n", + " \"\"\"\n", + " Generates a 2D random process with the equation x(t+1) = A.x(t) + B.noise.\n", + "\n", + " Args:\n", + " A: 2x2 transition matrix.\n", + " B: 2x2 noise scaling matrix.\n", + " T: Number of time steps.\n", + "\n", + " Returns:\n", + " A NumPy array of shape (T+1, 2) representing the trajectory.\n", + " \"\"\"\n", + " # Assuming equilibrium distribution is zero mean and identity covariance for simplicity.\n", + " # You may adjust this according to your actual equilibrium distribution\n", + " x = np.zeros(2)\n", + "\n", + " trajectory = [x.copy()] # Initialize with x(0)\n", + " for t in range(T):\n", + " noise = np.random.normal(size=2) # Standard normal noise\n", + " x = np.dot(A, x) + np.dot(B, noise)\n", + " trajectory.append(x.copy())\n", + " return np.array(trajectory)\n", + "\n", + "\"\"\"This module computes the Havok DMD model for a given dataset.\"\"\"\n", + "import torch\n", + "\n", + "def embed_signal_torch(data, n_delays, delay_interval=1):\n", + " \"\"\"\n", + " Create a delay embedding from the provided tensor data.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : torch.tensor\n", + " The data from which to create the delay embedding. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step.\n", + " \"\"\"\n", + " if isinstance(data, np.ndarray):\n", + " data = torch.from_numpy(data)\n", + " device = data.device\n", + "\n", + " if data.shape[int(data.ndim==3)] - (n_delays - 1)*delay_interval < 1:\n", + " raise ValueError(\"The number of delays is too large for the number of time points in the data!\")\n", + "\n", + " # initialize the embedding\n", + " if data.ndim == 3:\n", + " embedding = torch.zeros((data.shape[0], data.shape[1] - (n_delays - 1)*delay_interval, data.shape[2]*n_delays)).to(device)\n", + " else:\n", + " embedding = torch.zeros((data.shape[0] - (n_delays - 1)*delay_interval, data.shape[1]*n_delays)).to(device)\n", + "\n", + " for d in range(n_delays):\n", + " index = (n_delays - 1 - d)*delay_interval\n", + " ddelay = d*delay_interval\n", + "\n", + " if data.ndim == 3:\n", + " ddata = d*data.shape[2]\n", + " embedding[:,:, ddata: ddata + data.shape[2]] = data[:,index:data.shape[1] - ddelay]\n", + " else:\n", + " ddata = d*data.shape[1]\n", + " embedding[:, ddata:ddata + data.shape[1]] = data[index:data.shape[0] - ddelay]\n", + "\n", + " return embedding\n", + "\n", + "class DMD:\n", + " \"\"\"DMD class for computing and predicting with DMD models.\n", + " \"\"\"\n", + " def __init__(\n", + " self,\n", + " data,\n", + " n_delays,\n", + " delay_interval=1,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance=None,\n", + " reduced_rank_reg=False,\n", + " lamb=0,\n", + " device='cpu',\n", + " verbose=False,\n", + " send_to_cpu=False,\n", + " steps_ahead=1\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step.\n", + "\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used.\n", + "\n", + " rank_thresh : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None.\n", + "\n", + " reduced_rank_reg : bool\n", + " Determines whether to use reduced rank regression (True) or principal component regression (False)\n", + "\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to 0.\n", + "\n", + " device: string, int, or torch.device\n", + " A string, int or torch.device object to indicate the device to torch.\n", + "\n", + " verbose: bool\n", + " If True, print statements will be provided about the progress of the fitting procedure.\n", + "\n", + " send_to_cpu: bool\n", + " If True, will send all tensors in the object back to the cpu after everything is computed.\n", + " This is implemented to prevent gpu memory overload when computing multiple DMDs.\n", + "\n", + " steps_ahead: int\n", + " The number of time steps ahead to predict. Defaults to 1.\n", + " \"\"\"\n", + "\n", + " self.device = device\n", + " self._init_data(data)\n", + "\n", + " self.n_delays = n_delays\n", + " self.delay_interval = delay_interval\n", + " self.rank = rank\n", + " self.rank_thresh = rank_thresh\n", + " self.rank_explained_variance = rank_explained_variance\n", + " self.reduced_rank_reg = reduced_rank_reg\n", + " self.lamb = lamb\n", + " self.verbose = verbose\n", + " self.send_to_cpu = send_to_cpu\n", + " self.steps_ahead = steps_ahead\n", + "\n", + " # Hankel matrix\n", + " self.H = None\n", + "\n", + " # SVD attributes\n", + " self.U = None\n", + " self.S = None\n", + " self.V = None\n", + " self.S_mat = None\n", + " self.S_mat_inv = None\n", + "\n", + " # DMD attributes\n", + " self.A_v = None\n", + " self.A_havok_dmd = None\n", + "\n", + " def _init_data(self, data):\n", + " # check if the data is an np.ndarry - if so, convert it to Torch\n", + " if isinstance(data, np.ndarray):\n", + " data = torch.from_numpy(data)\n", + " self.data = data\n", + " # create attributes for the data dimensions\n", + " if self.data.ndim == 3:\n", + " self.ntrials = self.data.shape[0]\n", + " self.window = self.data.shape[1]\n", + " self.n = self.data.shape[2]\n", + " else:\n", + " self.window = self.data.shape[0]\n", + " self.n = self.data.shape[1]\n", + " self.ntrials = 1\n", + "\n", + " def compute_hankel(\n", + " self,\n", + " data=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " ):\n", + " \"\"\"\n", + " Computes the Hankel matrix from the provided data.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include. Defaults to None - provide only if you want\n", + " to override the value of n_delays from the init.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step. Defaults to None - provide only if you want\n", + " to override the value of n_delays from the init.\n", + " \"\"\"\n", + " if self.verbose:\n", + " print(\"Computing Hankel matrix ...\")\n", + "\n", + " # if parameters are provided, overwrite them from the init\n", + " self.data = self.data if data is None else self._init_data(data)\n", + " self.n_delays = self.n_delays if n_delays is None else n_delays\n", + " self.delay_interval = self.delay_interval if delay_interval is None else delay_interval\n", + " self.data = self.data.to(self.device)\n", + "\n", + " self.H = embed_signal_torch(self.data, self.n_delays, self.delay_interval)\n", + "\n", + " if self.verbose:\n", + " print(\"Hankel matrix computed!\")\n", + "\n", + " def compute_svd(self):\n", + " \"\"\"\n", + " Computes the SVD of the Hankel matrix.\n", + " \"\"\"\n", + "\n", + " if self.verbose:\n", + " print(\"Computing SVD on Hankel matrix ...\")\n", + " if self.H.ndim == 3: #flatten across trials for 3d\n", + " H = self.H.reshape(self.H.shape[0] * self.H.shape[1], self.H.shape[2])\n", + " else:\n", + " H = self.H\n", + " # compute the SVD\n", + " U, S, Vh = torch.linalg.svd(H.T, full_matrices=False)\n", + "\n", + " # update attributes\n", + " V = Vh.T\n", + " self.U = U\n", + " self.S = S\n", + " self.V = V\n", + "\n", + " # construct the singuar value matrix and its inverse\n", + " # dim = self.n_delays * self.n\n", + " # s = len(S)\n", + " # self.S_mat = torch.zeros(dim, dim,dtype=torch.float32).to(self.device)\n", + " # self.S_mat_inv = torch.zeros(dim, dim,dtype=torch.float32).to(self.device)\n", + " self.S_mat = torch.diag(S).to(self.device)\n", + " self.S_mat_inv= torch.diag(1 / S).to(self.device)\n", + "\n", + " # compute explained variance\n", + " exp_variance_inds = self.S**2 / ((self.S**2).sum())\n", + " cumulative_explained = torch.cumsum(exp_variance_inds, 0)\n", + " self.cumulative_explained_variance = cumulative_explained\n", + "\n", + " #make the X and Y components of the regression by staggering the hankel eigen-time delay coordinates by time\n", + " if self.reduced_rank_reg:\n", + " V = self.V\n", + " else:\n", + " V = self.V\n", + "\n", + " if self.ntrials > 1:\n", + " if V.numel() < self.H.numel():\n", + " raise ValueError(\"The dimension of the SVD of the Hankel matrix is smaller than the dimension of the Hankel matrix itself. \\n \\\n", + " This is likely due to the number of time points being smaller than the number of dimensions. \\n \\\n", + " Please reduce the number of delays.\")\n", + "\n", + " V = V.reshape(self.H.shape)\n", + "\n", + " #first reshape back into Hankel shape, separated by trials\n", + " newshape = (self.H.shape[0]*(self.H.shape[1]-self.steps_ahead),self.H.shape[2])\n", + " self.Vt_minus = V[:,:-self.steps_ahead].reshape(newshape)\n", + " self.Vt_plus = V[:,self.steps_ahead:].reshape(newshape)\n", + " else:\n", + " self.Vt_minus = V[:-self.steps_ahead]\n", + " self.Vt_plus = V[self.steps_ahead:]\n", + "\n", + "\n", + " if self.verbose:\n", + " print(\"SVD complete!\")\n", + "\n", + " def recalc_rank(self,rank,rank_thresh,rank_explained_variance):\n", + " '''\n", + " Parameters\n", + " ----------\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used. Provide only if you want to override the value from the init.\n", + "\n", + " rank_thresh : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None - provide only if you want\n", + " to override the value from the init.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None -\n", + " provide only if you want to overried the value from the init.\n", + " '''\n", + " # if an argument was provided, overwrite the stored rank information\n", + " none_vars = (rank is None) + (rank_thresh is None) + (rank_explained_variance is None)\n", + " if none_vars != 3:\n", + " self.rank = None\n", + " self.rank_thresh = None\n", + " self.rank_explained_variance = None\n", + "\n", + " self.rank = self.rank if rank is None else rank\n", + " self.rank_thresh = self.rank_thresh if rank_thresh is None else rank_thresh\n", + " self.rank_explained_variance = self.rank_explained_variance if rank_explained_variance is None else rank_explained_variance\n", + "\n", + " none_vars = (self.rank is None) + (self.rank_thresh is None) + (self.rank_explained_variance is None)\n", + " if none_vars < 2:\n", + " raise ValueError(\"More than one value was provided between rank, rank_thresh, and rank_explained_variance. Please provide only one of these, and ensure the others are None!\")\n", + " elif none_vars == 3:\n", + " self.rank = len(self.S)\n", + "\n", + " if self.reduced_rank_reg:\n", + " S = self.proj_mat_S\n", + " else:\n", + " S = self.S\n", + "\n", + " if rank_thresh is not None:\n", + " if S[-1] > rank_thresh:\n", + " self.rank = len(S)\n", + " else:\n", + " self.rank = torch.argmax(torch.arange(len(S), 0, -1).to(self.device)*(S < rank_thresh))\n", + "\n", + " if rank_explained_variance is not None:\n", + " self.rank = int(torch.argmax((self.cumulative_explained_variance > rank_explained_variance).type(torch.int)).cpu().numpy())\n", + "\n", + " if self.rank > self.H.shape[-1]:\n", + " self.rank = self.H.shape[-1]\n", + "\n", + " if self.rank is None:\n", + " if S[-1] > self.rank_thresh:\n", + " self.rank = len(S)\n", + " else:\n", + " self.rank = torch.argmax(torch.arange(len(S), 0, -1).to(self.device)*(S < self.rank_thresh))\n", + "\n", + " def compute_havok_dmd(self,lamb=None):\n", + " \"\"\"\n", + " Computes the Havok DMD matrix (Principal Component Regression)\n", + "\n", + " Parameters\n", + " ----------\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to 0 - provide only if you want\n", + " to override the value of n_delays from the init.\n", + "\n", + " \"\"\"\n", + " if self.verbose:\n", + " print(\"Computing least squares fits to HAVOK DMD ...\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + "\n", + " A_v = (torch.linalg.inv(self.Vt_minus[:, :self.rank].T @ self.Vt_minus[:, :self.rank] + self.lamb*torch.eye(self.rank).to(self.device)) \\\n", + " @ self.Vt_minus[:, :self.rank].T @ self.Vt_plus[:, :self.rank]).T\n", + " self.A_v = A_v\n", + " self.A_havok_dmd = self.U @ self.S_mat[:self.U.shape[1], :self.rank] @ self.A_v @ self.S_mat_inv[:self.rank, :self.U.shape[1]] @ self.U.T\n", + "\n", + " if self.verbose:\n", + " print(\"Least squares complete! \\n\")\n", + "\n", + " def compute_proj_mat(self,lamb=None):\n", + " if self.verbose:\n", + " print(\"Computing Projector Matrix for Reduced Rank Regression\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + "\n", + " self.proj_mat = self.Vt_plus.T @ self.Vt_minus @ torch.linalg.inv(self.Vt_minus.T @ self.Vt_minus +\n", + " self.lamb*torch.eye(self.Vt_minus.shape[1]).to(self.device)) @ \\\n", + " self.Vt_minus.T @ self.Vt_plus\n", + "\n", + " self.proj_mat_S, self.proj_mat_V = torch.linalg.eigh(self.proj_mat)\n", + " #todo: more efficient to flip ranks (negative index) in compute_reduced_rank_regression but also less interpretable\n", + " self.proj_mat_S = torch.flip(self.proj_mat_S, dims=(0,))\n", + " self.proj_mat_V = torch.flip(self.proj_mat_V, dims=(1,))\n", + "\n", + " if self.verbose:\n", + " print(\"Projector Matrix computed! \\n\")\n", + "\n", + " def compute_reduced_rank_regression(self,lamb=None):\n", + " if self.verbose:\n", + " print(\"Computing Reduced Rank Regression ...\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + " proj_mat = self.proj_mat_V[:,:self.rank] @ self.proj_mat_V[:,:self.rank].T\n", + " B_ols = torch.linalg.inv(self.Vt_minus.T @ self.Vt_minus + self.lamb*torch.eye(self.Vt_minus.shape[1]).to(self.device)) @ self.Vt_minus.T @ self.Vt_plus\n", + "\n", + " self.A_v = B_ols @ proj_mat\n", + " self.A_havok_dmd = self.U @ self.S_mat[:self.U.shape[1],:self.A_v.shape[1]] @ self.A_v.T @ self.S_mat_inv[:self.A_v.shape[0], :self.U.shape[1]] @ self.U.T\n", + "\n", + "\n", + " if self.verbose:\n", + " print(\"Reduced Rank Regression complete! \\n\")\n", + "\n", + " def fit(\n", + " self,\n", + " data=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance=None,\n", + " lamb=None,\n", + " device=None,\n", + " verbose=None,\n", + " steps_ahead=None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults to None -\n", + " provide only if you want to override the value from the init.\n", + "\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " rank_thresh : int\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None -\n", + " provide only if you want to overried the value from the init.\n", + "\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " device: string or int\n", + " A string or int to indicate the device to torch. For example, can be 'cpu' or 'cuda',\n", + " or alternatively 0 if the intenion is to use GPU device 0. Defaults to None - provide only\n", + " if you want to override the value from the init.\n", + "\n", + " verbose: bool\n", + " If True, print statements will be provided about the progress of the fitting procedure.\n", + " Defaults to None - provide only if you want to override the value from the init.\n", + "\n", + " steps_ahead: int\n", + " The number of time steps ahead to predict. Defaults to 1.\n", + "\n", + " \"\"\"\n", + " # if parameters are provided, overwrite them from the init\n", + " self.steps_ahead = self.steps_ahead if steps_ahead is None else steps_ahead\n", + " self.device = self.device if device is None else device\n", + " self.verbose = self.verbose if verbose is None else verbose\n", + "\n", + " self.compute_hankel(data, n_delays, delay_interval)\n", + " self.compute_svd()\n", + "\n", + " if self.reduced_rank_reg:\n", + " self.compute_proj_mat(lamb)\n", + " self.recalc_rank(rank,rank_thresh,rank_explained_variance)\n", + " self.compute_reduced_rank_regression(lamb)\n", + " else:\n", + " self.recalc_rank(rank,rank_thresh,rank_explained_variance)\n", + " self.compute_havok_dmd(lamb)\n", + "\n", + " if self.send_to_cpu:\n", + " self.all_to_device('cpu') #send back to the cpu to save memory\n", + "\n", + " def predict(\n", + " self,\n", + " test_data=None,\n", + " reseed=None,\n", + " full_return=False\n", + " ):\n", + " \"\"\"\n", + " Returns\n", + " -------\n", + " pred_data : torch.tensor\n", + " The predictions generated by the HAVOK model. Of the same shape as test_data. Note that the first\n", + " (self.n_delays - 1)*self.delay_interval + 1 time steps of the generated predictions are by construction\n", + " identical to the test_data.\n", + "\n", + " H_test_havok_dmd : torch.tensor (Optional)\n", + " Returned if full_return=True. The predicted Hankel matrix generated by the HAVOK model.\n", + " H_test : torch.tensor (Optional)\n", + " Returned if full_return=True. The true Hankel matrix\n", + " \"\"\"\n", + " # initialize test_data\n", + " if test_data is None:\n", + " test_data = self.data\n", + " if isinstance(test_data, np.ndarray):\n", + " test_data = torch.from_numpy(test_data).to(self.device)\n", + " ndim = test_data.ndim\n", + " if ndim == 2:\n", + " test_data = test_data.unsqueeze(0)\n", + " H_test = embed_signal_torch(test_data, self.n_delays, self.delay_interval)\n", + " steps_ahead = self.steps_ahead if self.steps_ahead is not None else 1\n", + "\n", + " if reseed is None:\n", + " reseed = 1\n", + "\n", + " H_test_havok_dmd = torch.zeros(H_test.shape).to(self.device)\n", + " H_test_havok_dmd[:, :steps_ahead] = H_test[:, :steps_ahead]\n", + "\n", + " A = self.A_havok_dmd.unsqueeze(0)\n", + " for t in range(steps_ahead, H_test.shape[1]):\n", + " if t % reseed == 0:\n", + " H_test_havok_dmd[:, t] = (A @ H_test[:, t - steps_ahead].transpose(-2, -1)).transpose(-2, -1)\n", + " else:\n", + " H_test_havok_dmd[:, t] = (A @ H_test_havok_dmd[:, t - steps_ahead].transpose(-2, -1)).transpose(-2, -1)\n", + " pred_data = torch.hstack([test_data[:, :(self.n_delays - 1)*self.delay_interval + steps_ahead], H_test_havok_dmd[:, steps_ahead:, :self.n]])\n", + "\n", + " if ndim == 2:\n", + " pred_data = pred_data[0]\n", + "\n", + " if full_return:\n", + " return pred_data, H_test_havok_dmd, H_test\n", + " else:\n", + " return pred_data\n", + "\n", + " def all_to_device(self,device='cpu'):\n", + " for k,v in self.__dict__.items():\n", + " if isinstance(v, torch.Tensor):\n", + " self.__dict__[k] = v.to(device)\n", + "\n", + "from typing import Literal\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from typing import Literal\n", + "import torch.nn.utils.parametrize as parametrize\n", + "from scipy.stats import wasserstein_distance\n", + "\n", + "def pad_zeros(A,B,device):\n", + "\n", + " with torch.no_grad():\n", + " dim = max(A.shape[0],B.shape[0])\n", + " A1 = torch.zeros((dim,dim)).float()\n", + " A1[:A.shape[0],:A.shape[1]] += A\n", + " A = A1.float().to(device)\n", + "\n", + " B1 = torch.zeros((dim,dim)).float()\n", + " B1[:B.shape[0],:B.shape[1]] += B\n", + " B = B1.float().to(device)\n", + "\n", + " return A,B\n", + "\n", + "class LearnableSimilarityTransform(nn.Module):\n", + " \"\"\"\n", + " Computes the similarity transform for a learnable orthonormal matrix C\n", + " \"\"\"\n", + " def __init__(self, n,orthog=True):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + " n : int\n", + " dimension of the C matrix\n", + " \"\"\"\n", + " super(LearnableSimilarityTransform, self).__init__()\n", + " #initialize orthogonal matrix as identity\n", + " self.C = nn.Parameter(torch.eye(n).float())\n", + " self.orthog = orthog\n", + "\n", + " def forward(self, B):\n", + " if self.orthog:\n", + " return self.C @ B @ self.C.transpose(-1, -2)\n", + " else:\n", + " return self.C @ B @ torch.linalg.inv(self.C)\n", + "\n", + "class Skew(nn.Module):\n", + " def __init__(self,n,device):\n", + " \"\"\"\n", + " Computes a skew-symmetric matrix X from some parameters (also called X)\n", + "\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.L1 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L2 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L3 = nn.Linear(n,n,bias = False, device = device)\n", + "\n", + " def forward(self, X):\n", + " X = torch.tanh(self.L1(X))\n", + " X = torch.tanh(self.L2(X))\n", + " X = self.L3(X)\n", + " return X - X.transpose(-1, -2)\n", + "\n", + "class Matrix(nn.Module):\n", + " def __init__(self,n,device):\n", + " \"\"\"\n", + " Computes a matrix X from some parameters (also called X)\n", + "\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.L1 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L2 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L3 = nn.Linear(n,n,bias = False, device = device)\n", + "\n", + " def forward(self, X):\n", + " X = torch.tanh(self.L1(X))\n", + " X = torch.tanh(self.L2(X))\n", + " X = self.L3(X)\n", + " return X\n", + "\n", + "class CayleyMap(nn.Module):\n", + " \"\"\"\n", + " Maps a skew-symmetric matrix to an orthogonal matrix in O(n)\n", + " \"\"\"\n", + " def __init__(self, n, device):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + "\n", + " n : int\n", + " dimension of the matrix we want to map\n", + "\n", + " device : {'cpu','cuda'} or int\n", + " hardware device on which to send the matrix\n", + " \"\"\"\n", + " super().__init__()\n", + " self.register_buffer(\"Id\", torch.eye(n,device = device))\n", + "\n", + " def forward(self, X):\n", + " # (I + X)(I - X)^{-1}\n", + " return torch.linalg.solve(self.Id + X, self.Id - X)\n", + "\n", + "class SimilarityTransformDist:\n", + " \"\"\"\n", + " Computes the Procrustes Analysis over Vector Fields\n", + " \"\"\"\n", + " def __init__(self,\n", + " iters = 200,\n", + " score_method: Literal[\"angular\", \"euclidean\",\"wasserstein\"] = \"angular\",\n", + " lr = 0.01,\n", + " device: Literal[\"cpu\",\"cuda\"] = 'cpu',\n", + " verbose = False,\n", + " group: Literal[\"O(n)\",\"SO(n)\",\"GL(n)\"] = \"O(n)\",\n", + " wasserstein_compare = None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " _________\n", + " iters : int\n", + " number of iterations to perform gradient descent\n", + "\n", + " score_method : {\"angular\",\"euclidean\",\"wasserstein\"}\n", + " specifies the type of metric to use\n", + " \"wasserstein\" will compare the singular values or eigenvalues\n", + " of the two matrices as in Redman et al., (2023)\n", + "\n", + " lr : float\n", + " learning rate\n", + "\n", + " device : {'cpu','cuda'} or int\n", + "\n", + " verbose : bool\n", + " prints when finished optimizing\n", + "\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " wasserstein_compare : {'sv','eig',None}\n", + " specifies whether to compare the singular values or eigenvalues\n", + " if score_method is \"wasserstein\", or the shapes are different\n", + " \"\"\"\n", + "\n", + " self.iters = iters\n", + " self.score_method = score_method\n", + " self.lr = lr\n", + " self.verbose = verbose\n", + " self.device = device\n", + " self.C_star = None\n", + " self.A = None\n", + " self.B = None\n", + " self.group = group\n", + " self.wasserstein_compare = wasserstein_compare\n", + "\n", + " def fit(self,\n", + " A,\n", + " B,\n", + " iters = None,\n", + " lr = None,\n", + " group = None,\n", + " ):\n", + " \"\"\"\n", + " Computes the optimal matrix C over specified group\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor\n", + " first data matrix\n", + " B : np.array or torch.tensor\n", + " second data matrix\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " Returns\n", + " _______\n", + " None\n", + " \"\"\"\n", + " assert A.shape[0] == A.shape[1]\n", + " assert B.shape[0] == B.shape[1]\n", + "\n", + " A = A.to(self.device)\n", + " B = B.to(self.device)\n", + " self.A,self.B = A,B\n", + " lr = self.lr if lr is None else lr\n", + " iters = self.iters if iters is None else iters\n", + " group = self.group if group is None else group\n", + "\n", + " if group in {\"SO(n)\", \"O(n)\"}:\n", + " self.losses, self.C_star, self.sim_net = self.optimize_C(A,\n", + " B,\n", + " lr,iters,\n", + " orthog=True,\n", + " verbose=self.verbose)\n", + " if group == \"O(n)\":\n", + " #permute the first row and column of B then rerun the optimization\n", + " P = torch.eye(B.shape[0],device=self.device)\n", + " if P.shape[0] > 1:\n", + " P[[0, 1], :] = P[[1, 0], :]\n", + " losses, C_star, sim_net = self.optimize_C(A,\n", + " P @ B @ P.T,\n", + " lr,iters,\n", + " orthog=True,\n", + " verbose=self.verbose)\n", + " if losses[-1] < self.losses[-1]:\n", + " self.losses = losses\n", + " self.C_star = C_star @ P\n", + " self.sim_net = sim_net\n", + " if group == \"GL(n)\":\n", + " self.losses, self.C_star, self.sim_net = self.optimize_C(A,\n", + " B,\n", + " lr,iters,\n", + " orthog=False,\n", + " verbose=self.verbose)\n", + "\n", + " def optimize_C(self,A,B,lr,iters,orthog,verbose):\n", + " #parameterize mapping to be orthogonal\n", + " n = A.shape[0]\n", + " sim_net = LearnableSimilarityTransform(n,orthog=orthog).to(self.device)\n", + " if orthog:\n", + " parametrize.register_parametrization(sim_net, \"C\", Skew(n,self.device))\n", + " parametrize.register_parametrization(sim_net, \"C\", CayleyMap(n,self.device))\n", + " else:\n", + " parametrize.register_parametrization(sim_net, \"C\", Matrix(n,self.device))\n", + "\n", + " simdist_loss = nn.MSELoss(reduction = 'sum')\n", + "\n", + " optimizer = optim.Adam(sim_net.parameters(), lr=lr)\n", + " # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n", + "\n", + " losses = []\n", + " A /= torch.linalg.norm(A)\n", + " B /= torch.linalg.norm(B)\n", + " for _ in range(iters):\n", + " # Zero the gradients of the optimizer.\n", + " optimizer.zero_grad()\n", + " # Compute the Frobenius norm between A and the product.\n", + " loss = simdist_loss(A, sim_net(B))\n", + "\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + " # if _ % 99:\n", + " # scheduler.step()\n", + " losses.append(loss.item())\n", + "\n", + " if verbose:\n", + " print(\"Finished optimizing C\")\n", + "\n", + " C_star = sim_net.C.detach()\n", + " return losses, C_star,sim_net\n", + "\n", + " def score(self,A=None,B=None,score_method=None,group=None):\n", + " \"\"\"\n", + " Given an optimal C already computed, calculate the metric\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor or None\n", + " first data matrix, if None defaults to the saved matrix in fit\n", + " B : np.array or torch.tensor or None\n", + " second data matrix if None, defaults to the savec matrix in fit\n", + " score_method : None or {'angular','euclidean'}\n", + " overwrites the score method in the object for this application\n", + " Returns\n", + " _______\n", + "\n", + " score : float\n", + " similarity of the data under the similarity transform w.r.t C\n", + " \"\"\"\n", + " assert self.C_star is not None\n", + " A = self.A if A is None else A\n", + " B = self.B if B is None else B\n", + " assert A is not None\n", + " assert B is not None\n", + " assert A.shape == self.C_star.shape\n", + " assert B.shape == self.C_star.shape\n", + " score_method = self.score_method if score_method is None else score_method\n", + " group = self.group if group is None else group\n", + " with torch.no_grad():\n", + " if not isinstance(A,torch.Tensor):\n", + " A = torch.from_numpy(A).float().to(self.device)\n", + " if not isinstance(B,torch.Tensor):\n", + " B = torch.from_numpy(B).float().to(self.device)\n", + " C = self.C_star.to(self.device)\n", + "\n", + " if group in {\"SO(n)\", \"O(n)\"}:\n", + " Cinv = C.T\n", + " elif group in {\"GL(n)\"}:\n", + " Cinv = torch.linalg.inv(C)\n", + " else:\n", + " raise AssertionError(\"Need proper group name\")\n", + " if score_method == 'angular':\n", + " num = torch.trace(A.T @ C @ B @ Cinv)\n", + " den = torch.norm(A,p = 'fro')*torch.norm(B,p = 'fro')\n", + " score = torch.arccos(num/den).cpu().numpy()\n", + " if np.isnan(score): #around -1 and 1, we sometimes get NaNs due to arccos\n", + " if num/den < 0:\n", + " score = np.pi\n", + " else:\n", + " score = 0\n", + " else:\n", + " score = torch.norm(A - C @ B @ Cinv,p='fro').cpu().numpy().item() #/ A.numpy().size\n", + "\n", + " return score\n", + "\n", + " def fit_score(self,\n", + " A,\n", + " B,\n", + " iters = None,\n", + " lr = None,\n", + " score_method = None,\n", + " zero_pad = True,\n", + " group = None):\n", + " \"\"\"\n", + " for efficiency, computes the optimal matrix and returns the score\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor\n", + " first data matrix\n", + " B : np.array or torch.tensor\n", + " second data matrix\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " score_method : {'angular','euclidean'} or None\n", + " overwrites parameter in the class\n", + " zero_pad : bool\n", + " if True, then the smaller matrix will be zero padded so its the same size\n", + " Returns\n", + " _______\n", + "\n", + " score : float\n", + " similarity of the data under the similarity transform w.r.t C\n", + "\n", + " \"\"\"\n", + " score_method = self.score_method if score_method is None else score_method\n", + " group = self.group if group is None else group\n", + "\n", + " if isinstance(A,np.ndarray):\n", + " A = torch.from_numpy(A).float()\n", + " if isinstance(B,np.ndarray):\n", + " B = torch.from_numpy(B).float()\n", + "\n", + " assert A.shape[0] == B.shape[1] or self.wasserstein_compare is not None\n", + " if A.shape[0] != B.shape[0]:\n", + " if self.wasserstein_compare is None:\n", + " raise AssertionError(\"Matrices must be the same size unless using wasserstein distance\")\n", + " else: #otherwise resort to L2 Wasserstein over singular or eigenvalues\n", + " print(f\"resorting to wasserstein distance over {self.wasserstein_compare}\")\n", + "\n", + " if self.score_method == \"wasserstein\":\n", + " assert self.wasserstein_compare in {\"sv\",\"eig\"}\n", + " if self.wasserstein_compare == \"sv\":\n", + " a = torch.svd(A).S.view(-1,1)\n", + " b = torch.svd(B).S.view(-1,1)\n", + " elif self.wasserstein_compare == \"eig\":\n", + " a = torch.linalg.eig(A).eigenvalues\n", + " a = torch.vstack([a.real,a.imag]).T\n", + "\n", + " b = torch.linalg.eig(B).eigenvalues\n", + " b = torch.vstack([b.real,b.imag]).T\n", + " else:\n", + " raise AssertionError(\"wasserstein_compare must be 'sv' or 'eig'\")\n", + " device = a.device\n", + " a = a#.cpu()\n", + " b = b#.cpu()\n", + " M = ot.dist(a,b)#.numpy()\n", + " a,b = torch.ones(a.shape[0])/a.shape[0],torch.ones(b.shape[0])/b.shape[0]\n", + " a,b = a.to(device),b.to(device)\n", + "\n", + " score_star = ot.emd2(a,b,M)\n", + " #wasserstein_distance(A.cpu().numpy(),B.cpu().numpy())\n", + "\n", + " else:\n", + "\n", + " self.fit(A, B,iters,lr,group)\n", + " score_star = self.score(self.A,self.B,score_method=score_method,group=group)\n", + "\n", + " return score_star\n", + "\n", + "class DSA:\n", + " \"\"\"\n", + " Computes the Dynamical Similarity Analysis (DSA) for two data matrices\n", + " \"\"\"\n", + " def __init__(self,\n", + " X,\n", + " Y=None,\n", + " n_delays=1,\n", + " delay_interval=1,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance = None,\n", + " lamb = 0.0,\n", + " send_to_cpu = True,\n", + " iters = 1500,\n", + " score_method: Literal[\"angular\", \"euclidean\",\"wasserstein\"] = \"angular\",\n", + " lr = 5e-3,\n", + " group: Literal[\"GL(n)\", \"O(n)\", \"SO(n)\"] = \"O(n)\",\n", + " zero_pad = False,\n", + " device = 'cpu',\n", + " verbose = False,\n", + " reduced_rank_reg = False,\n", + " kernel=None,\n", + " num_centers=0.1,\n", + " svd_solver='arnoldi',\n", + " wasserstein_compare: Literal['sv','eig',None] = None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + "\n", + " X : np.array or torch.tensor or list of np.arrays or torch.tensors\n", + " first data matrix/matrices\n", + "\n", + " Y : None or np.array or torch.tensor or list of np.arrays or torch.tensors\n", + " second data matrix/matrices.\n", + " * If Y is None, X is compared to itself pairwise\n", + " (must be a list)\n", + " * If Y is a single matrix, all matrices in X are compared to Y\n", + " * If Y is a list, all matrices in X are compared to all matrices in Y\n", + "\n", + " DMD parameters:\n", + "\n", + " n_delays : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " number of delays to use in constructing the Hankel matrix\n", + "\n", + " delay_interval : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " interval between samples taken in constructing Hankel matrix\n", + "\n", + " rank : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " rank of DMD matrix fit in reduced-rank regression\n", + "\n", + " rank_thresh : float or list or tuple/list: (float,float), (list,list),(list,float),(float,list)\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None.\n", + "\n", + " rank_explained_variance : float or list or tuple: (float,float), (list,list),(list,float),(float,list)\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None.\n", + "\n", + " lamb : float\n", + " L-1 regularization parameter in DMD fit\n", + "\n", + " send_to_cpu: bool\n", + " If True, will send all tensors in the object back to the cpu after everything is computed.\n", + " This is implemented to prevent gpu memory overload when computing multiple DMDs.\n", + "\n", + " NOTE: for all of these above, they can be single values or lists or tuples,\n", + " depending on the corresponding dimensions of the data\n", + " If at least one of X and Y are lists, then if they are a single value\n", + " it will default to the rank of all DMD matrices.\n", + " If they are (int,int), then they will correspond to an individual dmd matrix\n", + " OR to X and Y respectively across all matrices\n", + " If it is (list,list), then each element will correspond to an individual\n", + " dmd matrix indexed at the same position\n", + "\n", + " SimDist parameters:\n", + "\n", + " iters : int\n", + " number of optimization iterations in Procrustes over vector fields\n", + "\n", + " score_method : {'angular','euclidean'}\n", + " type of metric to compute, angular vs euclidean distance\n", + "\n", + " lr : float\n", + " learning rate of the Procrustes over vector fields optimization\n", + "\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " zero_pad : bool\n", + " whether or not to zero-pad if the dimensions are different\n", + "\n", + " device : 'cpu' or 'cuda' or int\n", + " hardware to use in both DMD and PoVF\n", + "\n", + " verbose : bool\n", + " whether or not print when sections of the analysis is completed\n", + "\n", + " wasserstein_compare : {'sv','eig',None}\n", + " specifies whether to compare the singular values or eigenvalues\n", + " if score_method is \"wasserstein\", or the shapes are different\n", + " \"\"\"\n", + " self.X = X\n", + " self.Y = Y\n", + " if self.X is None and isinstance(self.Y,list):\n", + " self.X, self.Y = self.Y, self.X #swap so code is easy\n", + "\n", + " self.check_method()\n", + " if self.method == 'self-pairwise':\n", + " self.data = [self.X]\n", + " else:\n", + " self.data = [self.X, self.Y]\n", + "\n", + " self.n_delays = self.broadcast_params(n_delays,cast=int)\n", + " self.delay_interval = self.broadcast_params(delay_interval,cast=int)\n", + " self.rank = self.broadcast_params(rank,cast=int)\n", + " self.rank_thresh = self.broadcast_params(rank_thresh)\n", + " self.rank_explained_variance = self.broadcast_params(rank_explained_variance)\n", + " self.lamb = self.broadcast_params(lamb)\n", + " self.send_to_cpu = send_to_cpu\n", + " self.iters = iters\n", + " self.score_method = score_method\n", + " self.lr = lr\n", + " self.device = device\n", + " self.verbose = verbose\n", + " self.zero_pad = zero_pad\n", + " self.group = group\n", + " self.reduced_rank_reg = reduced_rank_reg\n", + " self.kernel = kernel\n", + " self.wasserstein_compare = wasserstein_compare\n", + "\n", + " if kernel is None:\n", + " #get a list of all DMDs here\n", + " self.dmds = [[DMD(Xi,\n", + " self.n_delays[i][j],\n", + " delay_interval=self.delay_interval[i][j],\n", + " rank=self.rank[i][j],\n", + " rank_thresh=self.rank_thresh[i][j],\n", + " rank_explained_variance=self.rank_explained_variance[i][j],\n", + " reduced_rank_reg=self.reduced_rank_reg,\n", + " lamb=self.lamb[i][j],\n", + " device=self.device,\n", + " verbose=self.verbose,\n", + " send_to_cpu=self.send_to_cpu) for j,Xi in enumerate(dat)] for i,dat in enumerate(self.data)]\n", + " else:\n", + " #get a list of all DMDs here\n", + " self.dmds = [[KernelDMD(Xi,\n", + " self.n_delays[i][j],\n", + " kernel=self.kernel,\n", + " num_centers=num_centers,\n", + " delay_interval=self.delay_interval[i][j],\n", + " rank=self.rank[i][j],\n", + " reduced_rank_reg=self.reduced_rank_reg,\n", + " lamb=self.lamb[i][j],\n", + " verbose=self.verbose,\n", + " svd_solver=svd_solver,\n", + " ) for j,Xi in enumerate(dat)] for i,dat in enumerate(self.data)]\n", + "\n", + " self.simdist = SimilarityTransformDist(iters,score_method,lr,device,verbose,group,wasserstein_compare)\n", + "\n", + " def check_method(self):\n", + " '''\n", + " helper function to identify what type of dsa we're running\n", + " '''\n", + " tensor_or_np = lambda x: isinstance(x,(np.ndarray,torch.Tensor))\n", + "\n", + " if isinstance(self.X,list):\n", + " if self.Y is None:\n", + " self.method = 'self-pairwise'\n", + " elif isinstance(self.Y,list):\n", + " self.method = 'bipartite-pairwise'\n", + " elif tensor_or_np(self.Y):\n", + " self.method = 'list-to-one'\n", + " self.Y = [self.Y] #wrap in a list for iteration\n", + " else:\n", + " raise ValueError('unknown type of Y')\n", + " elif tensor_or_np(self.X):\n", + " self.X = [self.X]\n", + " if self.Y is None:\n", + " raise ValueError('only one element provided')\n", + " elif isinstance(self.Y,list):\n", + " self.method = 'one-to-list'\n", + " elif tensor_or_np(self.Y):\n", + " self.method = 'default'\n", + " self.Y = [self.Y]\n", + " else:\n", + " raise ValueError('unknown type of Y')\n", + " else:\n", + " raise ValueError('unknown type of X')\n", + "\n", + " def broadcast_params(self,param,cast=None):\n", + " '''\n", + " aligns the dimensionality of the parameters with the data so it's one-to-one\n", + " '''\n", + " out = []\n", + " if isinstance(param,(int,float,np.integer)) or param is None: #self.X has already been mapped to [self.X]\n", + " out.append([param] * len(self.X))\n", + " if self.Y is not None:\n", + " out.append([param] * len(self.Y))\n", + " elif isinstance(param,(tuple,list,np.ndarray)):\n", + " if self.method == 'self-pairwise' and len(param) >= len(self.X):\n", + " out = [param]\n", + " else:\n", + " assert len(param) <= 2 #only 2 elements max\n", + "\n", + " #if the inner terms are singly valued, we broadcast, otherwise needs to be the same dimensions\n", + " for i,data in enumerate([self.X,self.Y]):\n", + " if data is None:\n", + " continue\n", + " if isinstance(param[i],(int,float)):\n", + " out.append([param[i]] * len(data))\n", + " elif isinstance(param[i],(list,np.ndarray,tuple)):\n", + " assert len(param[i]) >= len(data)\n", + " out.append(param[i][:len(data)])\n", + " else:\n", + " raise ValueError(\"unknown type entered for parameter\")\n", + "\n", + " if cast is not None and param is not None:\n", + " out = [[cast(x) for x in dat] for dat in out]\n", + "\n", + " return out\n", + "\n", + " def fit_dmds(self,\n", + " X=None,\n", + " Y=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " rank=None,\n", + " rank_thresh = None,\n", + " rank_explained_variance=None,\n", + " reduced_rank_reg=None,\n", + " lamb = None,\n", + " device='cpu',\n", + " verbose=False,\n", + " send_to_cpu=True\n", + " ):\n", + " \"\"\"\n", + " Recomputes only the DMDs with a single set of hyperparameters. This will not compare, that will need to be done with the full procedure\n", + " \"\"\"\n", + " X = self.X if X is None else X\n", + " Y = self.Y if Y is None else Y\n", + " n_delays = self.n_delays if n_delays is None else n_delays\n", + " delay_interval = self.delay_interval if delay_interval is None else delay_interval\n", + " rank = self.rank if rank is None else rank\n", + " lamb = self.lamb if lamb is None else lamb\n", + " data = []\n", + " if isinstance(X,list):\n", + " data.append(X)\n", + " else:\n", + " data.append([X])\n", + " if Y is not None:\n", + " if isinstance(Y,list):\n", + " data.append(Y)\n", + " else:\n", + " data.append([Y])\n", + "\n", + " dmds = [[DMD(Xi,n_delays,delay_interval,\n", + " rank,rank_thresh,rank_explained_variance,reduced_rank_reg,\n", + " lamb,device,verbose,send_to_cpu) for Xi in dat] for dat in data]\n", + "\n", + " for dmd_sets in dmds:\n", + " for dmd in dmd_sets:\n", + " dmd.fit()\n", + "\n", + " return dmds\n", + "\n", + " def fit_score(self):\n", + " \"\"\"\n", + " Standard fitting function for both DMDs and PoVF\n", + "\n", + " Parameters\n", + " __________\n", + "\n", + " Returns\n", + " _______\n", + "\n", + " sims : np.array\n", + " data matrix of the similarity scores between the specific sets of data\n", + " \"\"\"\n", + " for dmd_sets in self.dmds:\n", + " for dmd in dmd_sets:\n", + " dmd.fit()\n", + "\n", + " return self.score()\n", + "\n", + " def score(self,iters=None,lr=None,score_method=None):\n", + " \"\"\"\n", + " Rescore DSA with precomputed dmds if you want to try again\n", + "\n", + " Parameters\n", + " __________\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " score_method : None or {'angular','euclidean'}\n", + " overwrites the score method in the object for this application\n", + "\n", + " Returns\n", + " ________\n", + " score : float\n", + " similarity score of the two precomputed DMDs\n", + " \"\"\"\n", + "\n", + " iters = self.iters if iters is None else iters\n", + " lr = self.lr if lr is None else lr\n", + " score_method = self.score_method if score_method is None else score_method\n", + "\n", + " ind2 = 1 - int(self.method == 'self-pairwise')\n", + " # 0 if self.pairwise (want to compare the set to itself)\n", + "\n", + " self.sims = np.zeros((len(self.dmds[0]),len(self.dmds[ind2])))\n", + " for i,dmd1 in enumerate(self.dmds[0]):\n", + " for j,dmd2 in enumerate(self.dmds[ind2]):\n", + " if self.method == 'self-pairwise':\n", + " if j >= i:\n", + " continue\n", + " if self.verbose:\n", + " print(f'computing similarity between DMDs {i} and {j}')\n", + "\n", + " self.sims[i,j] = self.simdist.fit_score(dmd1.A_v,dmd2.A_v,iters,lr,score_method,zero_pad=self.zero_pad)\n", + "\n", + " if self.method == 'self-pairwise':\n", + " self.sims[j,i] = self.sims[i,j]\n", + "\n", + "\n", + " if self.method == 'default':\n", + " return self.sims[0,0]\n", + "\n", + " return self.sims" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eced3162", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Helper functions (Bonus Section)\n", + "\n", + "import contextlib\n", + "import io\n", + "import argparse\n", + "# Standard library imports\n", + "from collections import OrderedDict\n", + "import logging\n", + "\n", + "# External libraries: General utilities\n", + "import argparse\n", + "import numpy as np\n", + "\n", + "# PyTorch related imports\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torch.optim.lr_scheduler import StepLR\n", + "from torchvision import datasets, transforms\n", + "from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names\n", + "from torchvision.utils import make_grid\n", + "\n", + "# Matplotlib for plotting\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "# SciPy for statistical functions\n", + "from scipy import stats\n", + "\n", + "# Scikit-Learn for machine learning utilities\n", + "from sklearn.decomposition import PCA\n", + "from sklearn import manifold\n", + "\n", + "# RSA toolbox specific imports\n", + "import rsatoolbox\n", + "from rsatoolbox.data import Dataset\n", + "from rsatoolbox.rdm.calc import calc_rdm\n", + "\n", + "class Net(nn.Module):\n", + " \"\"\"\n", + " A neural network model for image classification, consisting of two convolutional layers,\n", + " followed by two fully connected layers with dropout regularization.\n", + "\n", + " Methods:\n", + " - forward(input): Defines the forward pass of the network.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + " \"\"\"\n", + " Initializes the network layers.\n", + "\n", + " Layers:\n", + " - conv1: First convolutional layer with 1 input channel, 32 output channels, and a 3x3 kernel.\n", + " - conv2: Second convolutional layer with 32 input channels, 64 output channels, and a 3x3 kernel.\n", + " - dropout1: Dropout layer with a dropout probability of 0.25.\n", + " - dropout2: Dropout layer with a dropout probability of 0.5.\n", + " - fc1: First fully connected layer with 9216 input features and 128 output features.\n", + " - fc2: Second fully connected layer with 128 input features and 10 output features.\n", + " \"\"\"\n", + " super(Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", + " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", + " self.dropout1 = nn.Dropout(0.25)\n", + " self.dropout2 = nn.Dropout(0.5)\n", + " self.fc1 = nn.Linear(9216, 128)\n", + " self.fc2 = nn.Linear(128, 10)\n", + "\n", + " def forward(self, input):\n", + " \"\"\"\n", + " Defines the forward pass of the network.\n", + "\n", + " Inputs:\n", + " - input (torch.Tensor): Input tensor of shape (batch_size, 1, height, width).\n", + "\n", + " Outputs:\n", + " - output (torch.Tensor): Output tensor of shape (batch_size, 10) representing the class probabilities for each input sample.\n", + " \"\"\"\n", + " x = self.conv1(input)\n", + " x = F.relu(x)\n", + " x = self.conv2(x)\n", + " x = F.relu(x)\n", + " x = F.max_pool2d(x, 2)\n", + " x = self.dropout1(x)\n", + " x = torch.flatten(x, 1)\n", + " x = self.fc1(x)\n", + " x = F.relu(x)\n", + " x = self.dropout2(x)\n", + " x = self.fc2(x)\n", + " output = F.softmax(x, dim=1)\n", + " return output\n", + "\n", + "class recurrent_Net(nn.Module):\n", + " \"\"\"\n", + " A recurrent neural network model for image classification, consisting of two convolutional layers\n", + " with recurrent connections and a readout layer.\n", + "\n", + " Methods:\n", + " - __init__(time_steps=5): Initializes the network layers and sets the number of time steps for recurrence.\n", + " - forward(input): Defines the forward pass of the network.\n", + " \"\"\"\n", + "\n", + " def __init__(self, time_steps=5):\n", + " \"\"\"\n", + " Initializes the network layers and sets the number of time steps for recurrence.\n", + "\n", + " Layers:\n", + " - conv1: First convolutional layer with 1 input channel, 16 output channels, and a 3x3 kernel with a stride of 3.\n", + " - conv2: Second convolutional layer with 16 input channels, 16 output channels, and a 3x3 kernel with padding of 1.\n", + " - readout: A sequential layer containing:\n", + " - dropout: Dropout layer with a dropout probability of 0.25.\n", + " - avgpool: Adaptive average pooling layer to reduce spatial dimensions to 1x1.\n", + " - flatten: Flatten layer to convert the 2D pooled output to 1D.\n", + " - linear: Fully connected layer with 16 input features and 10 output features.\n", + " - time_steps (int): Number of time steps for the recurrent connection.\n", + " \"\"\"\n", + " super(recurrent_Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 16, 3, 3)\n", + " self.conv2 = nn.Conv2d(16, 16, 3, 1, padding=1)\n", + " self.readout = nn.Sequential(OrderedDict([\n", + " ('dropout', nn.Dropout(0.25)),\n", + " ('avgpool', nn.AdaptiveAvgPool2d(1)),\n", + " ('flatten', nn.Flatten()),\n", + " ('linear', nn.Linear(16, 10))\n", + " ]))\n", + " self.time_steps = time_steps\n", + "\n", + " def forward(self, input):\n", + " \"\"\"\n", + " Defines the forward pass of the network.\n", + "\n", + " Inputs:\n", + " - input (torch.Tensor): Input tensor of shape (batch_size, 1, height, width).\n", + "\n", + " Outputs:\n", + " - output (torch.Tensor): Output tensor of shape (batch_size, 10) representing the class probabilities for each input sample.\n", + " \"\"\"\n", + " input = self.conv1(input)\n", + " x = input\n", + " for t in range(0, self.time_steps):\n", + " x = input + self.conv2(x)\n", + " x = F.relu(x)\n", + "\n", + " x = self.readout(x)\n", + " output = F.softmax(x, dim=1)\n", + " return output\n", + "\n", + "\n", + "def train_one_epoch(args, model, device, train_loader, optimizer, epoch):\n", + " \"\"\"\n", + " Trains the model for one epoch.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Arguments for training configuration.\n", + " - model (torch.nn.Module): The model to be trained.\n", + " - device (torch.device): The device to use for training (CPU/GPU).\n", + " - train_loader (torch.utils.data.DataLoader): DataLoader for the training data.\n", + " - optimizer (torch.optim.Optimizer): Optimizer for updating the model parameters.\n", + " - epoch (int): The current epoch number.\n", + " \"\"\"\n", + " model.train()\n", + " for batch_idx, (data, target) in enumerate(train_loader):\n", + " data, target = data.to(device), target.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " output = torch.log(output) # to make it a log_softmax\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " if batch_idx % args.log_interval == 0:\n", + " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " epoch, batch_idx * len(data), len(train_loader.dataset),\n", + " 100. * batch_idx / len(train_loader), loss.item()))\n", + " if args.dry_run:\n", + " break\n", + "\n", + "def test(model, device, test_loader, return_features=False):\n", + " \"\"\"\n", + " Evaluates the model on the test dataset.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to be evaluated.\n", + " - device (torch.device): The device to use for evaluation (CPU/GPU).\n", + " - test_loader (torch.utils.data.DataLoader): DataLoader for the test data.\n", + " - return_features (bool): If True, returns the features from the model. Default is False.\n", + " \"\"\"\n", + " model.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " output = model(data)\n", + " output = torch.log(output)\n", + " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", + " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + " test_loss /= len(test_loader.dataset)\n", + "\n", + " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, correct, len(test_loader.dataset),\n", + " 100. * correct / len(test_loader.dataset)))\n", + "\n", + "def build_args():\n", + " \"\"\"\n", + " Builds and parses command-line arguments for training.\n", + " \"\"\"\n", + " parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", + " parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n", + " help='input batch size for training (default: 64)')\n", + " parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", + " help='input batch size for testing (default: 1000)')\n", + " parser.add_argument('--epochs', type=int, default=2, metavar='N',\n", + " help='number of epochs to train (default: 14)')\n", + " parser.add_argument('--lr', type=float, default=1.0, metavar='LR',\n", + " help='learning rate (default: 1.0)')\n", + " parser.add_argument('--gamma', type=float, default=0.7, metavar='M',\n", + " help='Learning rate step gamma (default: 0.7)')\n", + " parser.add_argument('--no-cuda', action='store_true', default=False,\n", + " help='disables CUDA training')\n", + " parser.add_argument('--no-mps', action='store_true', default=False,\n", + " help='disables macOS GPU training')\n", + " parser.add_argument('--dry-run', action='store_true', default=False,\n", + " help='quickly check a single pass')\n", + " parser.add_argument('--seed', type=int, default=1, metavar='S',\n", + " help='random seed (default: 1)')\n", + " parser.add_argument('--log-interval', type=int, default=50, metavar='N',\n", + " help='how many batches to wait before logging training status')\n", + " parser.add_argument('--save-model', action='store_true', default=False,\n", + " help='For Saving the current Model')\n", + " args = parser.parse_args('')\n", + "\n", + " use_cuda = torch.cuda.is_available() #not args.no_cuda and\n", + "\n", + " if use_cuda:\n", + " device = torch.device(\"cuda\")\n", + " else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + " args.use_cuda = use_cuda\n", + " args.device = device\n", + " return args\n", + "\n", + "def fetch_dataloaders(args):\n", + " \"\"\"\n", + " Fetches the data loaders for training and testing datasets.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Parsed arguments with training configuration.\n", + "\n", + " Outputs:\n", + " - train_loader (torch.utils.data.DataLoader): DataLoader for the training data.\n", + " - test_loader (torch.utils.data.DataLoader): DataLoader for the test data.\n", + " \"\"\"\n", + " train_kwargs = {'batch_size': args.batch_size}\n", + " test_kwargs = {'batch_size': args.test_batch_size}\n", + " if args.use_cuda:\n", + " cuda_kwargs = {'num_workers': 1,\n", + " 'pin_memory': True,\n", + " 'shuffle': True}\n", + " train_kwargs.update(cuda_kwargs)\n", + " test_kwargs.update(cuda_kwargs)\n", + "\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + " with contextlib.redirect_stdout(io.StringIO()): #to suppress output\n", + " dataset1 = datasets.MNIST('../data', train=True, download=True,\n", + " transform=transform)\n", + " dataset2 = datasets.MNIST('../data', train=False,\n", + " transform=transform)\n", + " train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)\n", + " test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n", + " return train_loader, test_loader\n", + "\n", + "def train_model(args, model, optimizer):\n", + " \"\"\"\n", + " Trains the model using the specified arguments and optimizer.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Parsed arguments with training configuration.\n", + " - model (torch.nn.Module): The model to be trained.\n", + " - optimizer (torch.optim.Optimizer): Optimizer for updating the model parameters.\n", + "\n", + " Outputs:\n", + " - None: The function trains the model and optionally saves it.\n", + " \"\"\"\n", + " scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n", + " for epoch in range(1, args.epochs + 1):\n", + " train_one_epoch(args, model, args.device, train_loader, optimizer, epoch)\n", + " test(model, args.device, test_loader)\n", + " scheduler.step()\n", + "\n", + " if args.save_model:\n", + " torch.save(model.state_dict(), \"mnist_cnn.pt\")\n", + "\n", + "\n", + "def calc_rdms(model_features, method='correlation'):\n", + " \"\"\"\n", + " Calculates representational dissimilarity matrices (RDMs) for model features.\n", + "\n", + " Inputs:\n", + " - model_features (dict): A dictionary where keys are layer names and values are features of the layers.\n", + " - method (str): The method to calculate RDMs, e.g., 'correlation'. Default is 'correlation'.\n", + "\n", + " Outputs:\n", + " - rdms (pyrsa.rdm.RDMs): RDMs object containing dissimilarity matrices.\n", + " - rdms_dict (dict): A dictionary with layer names as keys and their corresponding RDMs as values.\n", + " \"\"\"\n", + " ds_list = []\n", + " for l in range(len(model_features)):\n", + " layer = list(model_features.keys())[l]\n", + " feats = model_features[layer]\n", + "\n", + " if type(feats) is list:\n", + " feats = feats[-1]\n", + "\n", + " if args.use_cuda:\n", + " feats = feats.cpu()\n", + "\n", + " if len(feats.shape) > 2:\n", + " feats = feats.flatten(1)\n", + "\n", + " feats = feats.detach().numpy()\n", + " ds = Dataset(feats, descriptors=dict(layer=layer))\n", + " ds_list.append(ds)\n", + "\n", + " rdms = calc_rdm(ds_list, method=method)\n", + " rdms_dict = {list(model_features.keys())[i]: rdms.get_matrices()[i] for i in range(len(model_features))}\n", + "\n", + " return rdms, rdms_dict\n", + "\n", + "def fgsm_attack(image, epsilon, data_grad):\n", + " \"\"\"\n", + " Performs FGSM attack on an image.\n", + "\n", + " Inputs:\n", + " - image (torch.Tensor): Original image.\n", + " - epsilon (float): Perturbation magnitude.\n", + " - data_grad (torch.Tensor): Gradient of the data.\n", + "\n", + " Outputs:\n", + " - perturbed_image (torch.Tensor): Perturbed image after FGSM attack.\n", + " \"\"\"\n", + " sign_data_grad = data_grad.sign()\n", + " perturbed_image = image + epsilon * sign_data_grad\n", + " perturbed_image = torch.clamp(perturbed_image, 0, 1)\n", + " return perturbed_image\n", + "\n", + "def denorm(batch, mean=[0.1307], std=[0.3081]):\n", + " \"\"\"\n", + " Converts a batch of normalized tensors to their original scale.\n", + "\n", + " Inputs:\n", + " - batch (torch.Tensor): Batch of normalized tensors.\n", + " - mean (torch.Tensor or list): Mean used for normalization.\n", + " - std (torch.Tensor or list): Standard deviation used for normalization.\n", + "\n", + " Outputs:\n", + " - torch.Tensor: Batch of tensors without normalization applied to them.\n", + " \"\"\"\n", + " if isinstance(mean, list):\n", + " mean = torch.tensor(mean).to(batch.device)\n", + " if isinstance(std, list):\n", + " std = torch.tensor(std).to(batch.device)\n", + "\n", + " return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)\n", + "\n", + "def generate_adversarial(model, imgs, targets, epsilon):\n", + " \"\"\"\n", + " Generates adversarial examples using FGSM attack.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to attack.\n", + " - imgs (torch.Tensor): Batch of images.\n", + " - targets (torch.Tensor): Batch of target labels.\n", + " - epsilon (float): Perturbation magnitude.\n", + "\n", + " Outputs:\n", + " - adv_imgs (torch.Tensor): Batch of adversarial images.\n", + " \"\"\"\n", + " adv_imgs = []\n", + "\n", + " for img, target in zip(imgs, targets):\n", + " img = img.unsqueeze(0)\n", + " target = target.unsqueeze(0)\n", + " img.requires_grad = True\n", + "\n", + " output = model(img)\n", + " output = torch.log(output)\n", + " loss = F.nll_loss(output, target)\n", + "\n", + " model.zero_grad()\n", + " loss.backward()\n", + "\n", + " data_grad = img.grad.data\n", + " data_denorm = denorm(img)\n", + " perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)\n", + " perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)\n", + "\n", + " adv_imgs.append(perturbed_data_normalized.detach())\n", + "\n", + " return torch.cat(adv_imgs)\n", + "\n", + "def test_adversarial(model, imgs, targets):\n", + " \"\"\"\n", + " Tests the model on adversarial examples and prints the accuracy.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to be tested.\n", + " - imgs (torch.Tensor): Batch of adversarial images.\n", + " - targets (torch.Tensor): Batch of target labels.\n", + " \"\"\"\n", + " correct = 0\n", + " output = model(imgs)\n", + " output = torch.log(output)\n", + " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", + " correct += pred.eq(targets.view_as(pred)).sum().item()\n", + "\n", + " final_acc = correct / float(len(imgs))\n", + " print(f\"adversarial test accuracy = {correct} / {len(imgs)} = {final_acc}\")\n", + "\n", + "def extract_features(model, imgs, return_layers, plot='none'):\n", + " \"\"\"\n", + " Extracts features from specified layers of the model.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model from which to extract features.\n", + " - imgs (torch.Tensor): Batch of input images.\n", + " - return_layers (list): List of layer names from which to extract features.\n", + " - plot (str): Option to plot the features. Default is 'none'.\n", + "\n", + " Outputs:\n", + " - model_features (dict): A dictionary with layer names as keys and extracted features as values.\n", + " \"\"\"\n", + " if return_layers == 'all':\n", + " return_layers, _ = get_graph_node_names(model)\n", + " elif return_layers == 'layers':\n", + " layers, _ = get_graph_node_names(model)\n", + " return_layers = [l for l in layers if 'input' in l or 'conv' in l or 'fc' in l]\n", + "\n", + " feature_extractor = create_feature_extractor(model, return_nodes=return_layers)\n", + " model_features = feature_extractor(imgs)\n", + "\n", + " return model_features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be4a4946", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Plotting functions (Bonus)\n", + "\n", + "def sample_images(data_loader, n=5, plot=False):\n", + " \"\"\"\n", + " Samples a specified number of images from a data loader.\n", + "\n", + " Inputs:\n", + " - data_loader (torch.utils.data.DataLoader): Data loader containing images and labels.\n", + " - n (int): Number of images to sample per class.\n", + " - plot (bool): Whether to plot the sampled images using matplotlib.\n", + "\n", + " Outputs:\n", + " - imgs (torch.Tensor): Sampled images.\n", + " - labels (torch.Tensor): Corresponding labels for the sampled images.\n", + " \"\"\"\n", + "\n", + " with plt.xkcd():\n", + " imgs, targets = next(iter(data_loader))\n", + "\n", + " imgs_o = []\n", + " labels = []\n", + " for value in range(10):\n", + " cat_imgs = imgs[np.where(targets == value)][0:n]\n", + " imgs_o.append(cat_imgs)\n", + " labels.append([value]*len(cat_imgs))\n", + "\n", + " imgs = torch.cat(imgs_o, dim=0)\n", + " labels = torch.tensor(labels).flatten()\n", + "\n", + " if plot:\n", + " plt.imshow(torch.moveaxis(make_grid(imgs, nrow=5, padding=0, normalize=False, pad_value=0), 0,-1))\n", + " plt.axis('off')\n", + "\n", + " return imgs, labels\n", + "\n", + "\n", + "def plot_rdms(model_rdms):\n", + " \"\"\"\n", + " Plots the Representational Dissimilarity Matrices (RDMs) for each layer of a model.\n", + "\n", + " Inputs:\n", + " - model_rdms (dict): A dictionary where keys are layer names and values are the corresponding RDMs.\n", + " \"\"\"\n", + "\n", + " with plt.xkcd():\n", + " fig = plt.figure(figsize=(8, 4))\n", + " gs = fig.add_gridspec(1, len(model_rdms))\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " for l in range(len(model_rdms)):\n", + "\n", + " layer = list(model_rdms.keys())[l]\n", + " rdm = np.squeeze(model_rdms[layer])\n", + "\n", + " if len(rdm.shape) < 2:\n", + " rdm = rdm.reshape( (int(np.sqrt(rdm.shape[0])), int(np.sqrt(rdm.shape[0]))) )\n", + "\n", + " rdm = rdm / np.max(rdm)\n", + "\n", + " ax = plt.subplot(gs[0,l])\n", + " ax_ = ax.imshow(rdm, cmap='magma_r')\n", + " ax.set_title(f'{layer}')\n", + "\n", + " fig.subplots_adjust(right=0.9)\n", + " cbar_ax = fig.add_axes([1.01, 0.18, 0.01, 0.53])\n", + " cbar_ax.text(-2.3, 0.05, 'Normalized euclidean distance', size=10, rotation=90)\n", + " fig.colorbar(ax_, cax=cbar_ax)\n", + "\n", + " plt.show()\n", + "\n", + "def rep_path(model_features, model_colors, labels=None, rdm_calc_method='euclidean', rdm_comp_method='cosine'):\n", + " \"\"\"\n", + " Represents paths of model features in a reduced-dimensional space.\n", + "\n", + " Inputs:\n", + " - model_features (dict): Dictionary containing model features for each model.\n", + " - model_colors (dict): Dictionary mapping model names to colors for visualization.\n", + " - labels (array-like, optional): Array of labels corresponding to the model features.\n", + " - rdm_calc_method (str, optional): Method for calculating RDMS ('euclidean' or 'correlation').\n", + " - rdm_comp_method (str, optional): Method for comparing RDMS ('cosine' or 'corr').\n", + " \"\"\"\n", + " with plt.xkcd():\n", + " path_len = []\n", + " path_colors = []\n", + " rdms_list = []\n", + " ax_ticks = []\n", + " tick_colors = []\n", + " model_names = list(model_features.keys())\n", + " for m in range(len(model_names)):\n", + " model_name = model_names[m]\n", + " features = model_features[model_name]\n", + " path_colors.append(model_colors[model_name])\n", + " path_len.append(len(features))\n", + " ax_ticks.append(list(features.keys()))\n", + " tick_colors.append([model_colors[model_name]]*len(features))\n", + " rdms, _ = calc_rdms(features, method=rdm_calc_method)\n", + " rdms_list.append(rdms)\n", + "\n", + " path_len = np.insert(np.cumsum(path_len),0,0)\n", + "\n", + " if labels is not None:\n", + " rdms, _ = calc_rdms({'labels' : F.one_hot(labels).float().to(device)}, method=rdm_calc_method)\n", + " rdms_list.append(rdms)\n", + " ax_ticks.append(['labels'])\n", + " tick_colors.append(['m'])\n", + " idx_labels = -1\n", + "\n", + " rdms = rsatoolbox.rdm.concat(rdms_list)\n", + "\n", + " #Flatten the list\n", + " ax_ticks = [l for model_layers in ax_ticks for l in model_layers]\n", + " tick_colors = [l for model_layers in tick_colors for l in model_layers]\n", + " tick_colors = ['k' if tick == 'input' else color for tick, color in zip(ax_ticks, tick_colors)]\n", + "\n", + " rdms_comp = rsatoolbox.rdm.compare(rdms, rdms, method=rdm_comp_method)\n", + " if rdm_comp_method == 'cosine':\n", + " rdms_comp = np.arccos(rdms_comp)\n", + " rdms_comp = np.nan_to_num(rdms_comp, nan=0.0)\n", + "\n", + " # Symmetrize\n", + " rdms_comp = (rdms_comp + rdms_comp.T) / 2.0\n", + "\n", + " # reduce dim to 2\n", + " transformer = manifold.MDS(n_components = 2, max_iter=1000, n_init=10, normalized_stress='auto', dissimilarity=\"precomputed\")\n", + " dims= transformer.fit_transform(rdms_comp)\n", + "\n", + " # remove duplicates of the input layer from multiple models\n", + " remove_duplicates = np.where(np.array(ax_ticks) == 'input')[0][1:]\n", + " for index in remove_duplicates:\n", + " del ax_ticks[index]\n", + " del tick_colors[index]\n", + " rdms_comp = np.delete(np.delete(rdms_comp, index, axis=0), index, axis=1)\n", + "\n", + " fig = plt.figure(figsize=(8, 4))\n", + " gs = fig.add_gridspec(1, 2)\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " ax = plt.subplot(gs[0,0])\n", + " ax_ = ax.imshow(rdms_comp, cmap='viridis_r')\n", + " fig.subplots_adjust(left=0.2)\n", + " cbar_ax = fig.add_axes([-0.01, 0.2, 0.01, 0.5])\n", + " #cbar_ax.text(-7, 0.05, 'dissimilarity between rdms', size=10, rotation=90)\n", + " fig.colorbar(ax_, cax=cbar_ax,location='left')\n", + " ax.set_title('Dissimilarity between layer rdms', fontdict = {'fontsize': 14})\n", + " ax.set_xticks(np.arange(len(ax_ticks)), labels=ax_ticks, fontsize=7, rotation=83)\n", + " ax.set_yticks(np.arange(len(ax_ticks)), labels=ax_ticks, fontsize=7)\n", + " [t.set_color(i) for (i,t) in zip(tick_colors, ax.xaxis.get_ticklabels())]\n", + " [t.set_color(i) for (i,t) in zip(tick_colors, ax.yaxis.get_ticklabels())]\n", + "\n", + " ax = plt.subplot(gs[0,1])\n", + " amin, amax = dims.min(), dims.max()\n", + " amin, amax = (amin + amax) / 2 - (amax - amin) * 5/8, (amin + amax) / 2 + (amax - amin) * 5/8\n", + "\n", + " for i in range(len(rdms_list)-1):\n", + "\n", + " path_indices = np.arange(path_len[i], path_len[i+1])\n", + " ax.plot(dims[path_indices, 0], dims[path_indices, 1], color=path_colors[i], marker='.')\n", + " ax.set_title('Representational geometry path', fontdict = {'fontsize': 14})\n", + " ax.set_xlim([amin, amax])\n", + " ax.set_ylim([amin, amax])\n", + " ax.set_xlabel(f\"dim 1\")\n", + " ax.set_ylabel(f\"dim 2\")\n", + "\n", + " # if idx_input is not None:\n", + " idx_input = 0\n", + " ax.plot(dims[idx_input, 0], dims[idx_input, 1], color='k', marker='s')\n", + "\n", + " if labels is not None:\n", + " ax.plot(dims[idx_labels, 0], dims[idx_labels, 1], color='m', marker='*')\n", + "\n", + " ax.legend(model_names, fontsize=8)\n", + " fig.tight_layout()\n", + "\n", + "def plot_dim_reduction(model_features, labels, transformer_funcs):\n", + " \"\"\"\n", + " Plots the dimensionality reduction results for model features using various transformers.\n", + "\n", + " Inputs:\n", + " - model_features (dict): Dictionary containing model features for each layer.\n", + " - labels (array-like): Array of labels corresponding to the model features.\n", + " - transformer_funcs (list): List of dimensionality reduction techniques to apply ('PCA', 'MDS', 't-SNE').\n", + " \"\"\"\n", + " with plt.xkcd():\n", + "\n", + " transformers = []\n", + " for t in transformer_funcs:\n", + " if t == 'PCA': transformers.append(PCA(n_components=2))\n", + " if t == 'MDS': transformers.append(manifold.MDS(n_components = 2, normalized_stress='auto'))\n", + " if t == 't-SNE': transformers.append(manifold.TSNE(n_components = 2, perplexity=40, verbose=0))\n", + "\n", + " fig = plt.figure(figsize=(8, 2.5*len(transformers)))\n", + " # and we add one plot per reference point\n", + " gs = fig.add_gridspec(len(transformers), len(model_features))\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " return_layers = list(model_features.keys())\n", + "\n", + " for f in range(len(transformer_funcs)):\n", + "\n", + " for l in range(len(return_layers)):\n", + " layer = return_layers[l]\n", + " feats = model_features[layer].detach().cpu().flatten(1)\n", + " feats_transformed= transformers[f].fit_transform(feats)\n", + "\n", + " amin, amax = feats_transformed.min(), feats_transformed.max()\n", + " amin, amax = (amin + amax) / 2 - (amax - amin) * 5/8, (amin + amax) / 2 + (amax - amin) * 5/8\n", + " ax = plt.subplot(gs[f,l])\n", + " ax.set_xlim([amin, amax])\n", + " ax.set_ylim([amin, amax])\n", + " ax.axis(\"off\")\n", + " #ax.set_title(f'{layer}')\n", + " if f == 0: ax.text(0.5, 1.12, f'{layer}', size=16, ha=\"center\", transform=ax.transAxes)\n", + " if l == 0: ax.text(-0.3, 0.5, transformer_funcs[f], size=16, ha=\"center\", transform=ax.transAxes)\n", + " # Create a discrete color map based on unique labels\n", + " num_colors = len(np.unique(labels))\n", + " cmap = plt.get_cmap('viridis_r', num_colors) # 10 discrete colors\n", + " norm = mpl.colors.BoundaryNorm(np.arange(-0.5,num_colors), cmap.N)\n", + " ax_ = ax.scatter(feats_transformed[:, 0], feats_transformed[:, 1], c=labels, cmap=cmap, norm=norm)\n", + "\n", + " fig.subplots_adjust(right=0.9)\n", + " cbar_ax = fig.add_axes([1.01, 0.18, 0.01, 0.53])\n", + " fig.colorbar(ax_, cax=cbar_ax, ticks=np.linspace(0,9,10))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21f68945", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Data retrieval\n", + "\n", + "import os\n", + "import requests\n", + "import hashlib\n", + "\n", + "# Variables for file and download URL\n", + "fnames = [\"standard_model.pth\", \"adversarial_model.pth\", \"recurrent_model.pth\"] # The names of the files to be downloaded\n", + "urls = [\"https://osf.io/s5rt6/download\", \"https://osf.io/qv5eb/download\", \"https://osf.io/6hnwk/download\"] # URLs from where the files will be downloaded\n", + "expected_md5s = [\"2e63c2cd77bc9f1fa67673d956ec910d\", \"25fb34497377921b54368317f68a7aa7\", \"ee5cea3baa264cb78300102fa6ed66e8\"] # MD5 hashes for verifying files integrity\n", + "\n", + "for fname, url, expected_md5 in zip(fnames, urls, expected_md5s):\n", + " if not os.path.isfile(fname):\n", + " try:\n", + " # Attempt to download the file\n", + " r = requests.get(url) # Make a GET request to the specified URL\n", + " except requests.ConnectionError:\n", + " # Handle connection errors during the download\n", + " print(\"!!! Failed to download data !!!\")\n", + " else:\n", + " # No connection errors, proceed to check the response\n", + " if r.status_code != requests.codes.ok:\n", + " # Check if the HTTP response status code indicates a successful download\n", + " print(\"!!! Failed to download data !!!\")\n", + " elif hashlib.md5(r.content).hexdigest() != expected_md5:\n", + " # Verify the integrity of the downloaded file using MD5 checksum\n", + " print(\"!!! Data download appears corrupted !!!\")\n", + " else:\n", + " # If download is successful and data is not corrupted, save the file\n", + " with open(fname, \"wb\") as fid:\n", + " fid.write(r.content) # Write the downloaded content to a file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93aeca0a", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Figure settings\n", + "\n", + "logging.getLogger('matplotlib.font_manager').disabled = True\n", + "\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'retina' # perfrom high definition rendering for images and plots\n", + "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd8052d5", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Set device (GPU or CPU)\n", + "\n", + "# inform the user if the notebook uses GPU or CPU.\n", + "\n", + "def set_device():\n", + " \"\"\"\n", + " Determines and sets the computational device for PyTorch operations based on the availability of a CUDA-capable GPU.\n", + "\n", + " Outputs:\n", + " - device (str): The device that PyTorch will use for computations ('cuda' or 'cpu'). This string can be directly used\n", + " in PyTorch operations to specify the device.\n", + " \"\"\"\n", + "\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + " if device != \"cuda\":\n", + " print(\"GPU is not enabled in this notebook. \\n\"\n", + " \"If you want to enable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `GPU` from the dropdown menu\")\n", + " else:\n", + " print(\"GPU is enabled in this notebook. \\n\"\n", + " \"If you want to disable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `None` from the dropdown menu\")\n", + "\n", + " return device\n", + "\n", + "device = set_device()" ] }, { @@ -75,84 +2155,541 @@ "execution_count": null, "id": "c28a92e7-e76c-48de-b574-15a1272717cf", "metadata": { - "cellView": "form", + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Load Slides\n", + "\n", + "from IPython.display import IFrame\n", + "from ipywidgets import widgets\n", + "out = widgets.Output()\n", + "\n", + "link_id = \"8fx23\"\n", + "\n", + "with out:\n", + " print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", + " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", + "display(out)" + ] + }, + { + "cell_type": "markdown", + "id": "407ace26", + "metadata": { + "execution": {} + }, + "source": [ + "---\n", + "\n", + "# Intro\n", + "\n", + "Welcome to Tutorial 5 of Day 3 (W1D3) of the NeuroAI course. In this tutorial we are going to look at an exciting method that measures similarity from a slightly different perspective, a temporal one. The prior methods we have looked at were centeed around geometry and spatial representations, where we looked at metrics such as the Euclidean and Mahalanobis distance metrics. However, one thing we often want to study in neuroscience and in AI separately - is the temporal domain. Even more so in our own field of NeuroAI, we often deal with time series of neuronal / biological recordings. One thing you should already have a broad level of awareness of is that end structures can end up looking the same even though the paths taken to arrive at those end structures were very different.\n", + "\n", + "In NeuroAI, we're often confronted with systems that seem to have some sort of overlap and we want to study whether this implies there is a shared computation pairs up with the shared task (we looked at this in detail yesterday in our *Comparing Tasks* day). Today, we will begin by watching a short intro video by Mitchell Ostrow, who will describe his method to compare representations over temporal sequences (the method is called Dynamic Similarity Analysis). Then we are going to introduce three simple dynamical systems and we will explore them from the perspective of Dynamic Similarity Analysis and also describe the conceptual relationship to Representational Similarity Analysis. You will have a short coding exercise on the topic of temporal similarity analysis on three different types of trajectories. \n", + "\n", + "At the end of the tutorial, we will finally look at a further aspect of temporal sequences using RNNs. This is an adaptation of the ideas introduced in Tutorial 2 but now based around recurrent representations from RNNs. We hope you enjoy this tutorial today and that it gets you thinking not just what similarity values mean, but which ones are appropriate (here, from a spatial or temporal perspective). We aim to continually expand the tools necessary in the NeuroAI researcher's toolkit. Complementary tools, when applicable, can often tell a far richer story than just using a single method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5d6178f-ddf5-41ae-b676-15e452dc8b78", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Video 1: Dynamical Similarity Analysis\n", + "\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "video_ids = [('Youtube', 'FHikIsQFQvM'), ('Bilibili', 'BV1qm421g7hV')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2ce83bc-7e86-44d3-a40a-4ad46fd5a6df", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_DSA_video\")" + ] + }, + { + "cell_type": "markdown", + "id": "937041e9", + "metadata": { + "execution": {} + }, + "source": [ + "## Section 1: Visualization of Three Temporal Sequences\n", + "\n", + "We are going to be working with the analysis of three temporal sequences today:\n", + "\n", + "* The circular time series (`Circle`)\n", + "* The oval time series (`Oval`)\n", + "* The random walk (`R-Walk`)\n", + "\n", + "The random walk is going to be broadly *oval shaped*. Now, what do you think, from a geometric perspective, might result from a spatial analysis of these three different *representations*? You will probably assume because the random walk has an oval shape and there is also an oval time series (that's not a random walk) that these would result in a higher spatial similarity. You'd be right to assume this. However, what we're going to do with the `Circle` and `Oval` time series is to include an oscillator at a specific frequency, shared amongst these two time series. In effect, this means that although when plotted in totality the shapes are different, during the dynamic (temporal) evolution of these time series, a very similar shared pattern is emerging. We want methods that are sensitive to these changes to give higher scores for time series sharing similar temporal patterns (e.g. both containing oscillating patterns at similar frequences) rather than just a random walk that resembles (geometrically) one of the other shapes (`R-Walk`). Before we continue, we'll just define this random walk in a little more detail. A random walk at a specific location / timepoint takes a random step of fixed length in a specific direction, but this can be broadly controlled to resemble geometric shapes. We've taken a random walk and then reframed it to be similar in shape to `Oval`. \n", + "\n", + "Let's now visualize these three temporal sequences, to make the previous paragraph a little clearer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b57dfe1a", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# Circle\n", + "r = .1; # rotation\n", + "A = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])\n", + "B = np.array([[1, 0], [0, 1]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_circle = trajectory\n", + "\n", + "# Oval\n", + "r = .1; # rotation\n", + "s = 4; # scaling\n", + "S = np.array([[1, 0], [0, s]])\n", + "Si = np.array([[1, 0], [0, 1/s]])\n", + "V = np.array([[1, 1], [-1, 1]])/np.sqrt(2)\n", + "Vi = np.array([[1, -1], [1, 1]])/np.sqrt(2)\n", + "R = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])\n", + "A = np.linalg.multi_dot([V,Si,R,S,Vi])\n", + "B = np.array([[1, 0], [0, 1]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_oval = trajectory\n", + "\n", + "# R-Walk (random walk)\n", + "r = .1; # rotation\n", + "A = np.array([[.9, 0], [0, .9]])\n", + "c = -.95; # correlation coefficient\n", + "B = np.array([[1, c], [0, np.sqrt(1-c*c)]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_walk = trajectory" + ] + }, + { + "cell_type": "markdown", + "id": "113a0dee", + "metadata": { + "execution": {} + }, + "source": [ + "Can you see how the spatial / geometric similarity of `R-Walk` and `Oval` are more similar, but the oscillations during the temporal sequence are shared between `Circle` and `Oval`? Let's run Dynamic Similarity Analysis on these temporal sequences and see what scores are returned.\n", + "\n", + "We calcularted `trajectory_oval` and `trajectory_circle` above, so let's plug these into the `DSA` function imported earlier (in the helper function cell) and see what the similarity score is." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3e36d59", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# Define the DSA computation class\n", + "dsa = DSA(X=trajectory_oval, Y=trajectory_circle, n_delays=1)\n", + "\n", + "# Call the fit method and save the result\n", + "similarities_oval_circle = dsa.fit_score()\n", + "\n", + "print(f\"DSA similarity between Oval and Circle: {similarities_oval_circle:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9f1fb622", + "metadata": { + "execution": {} + }, + "source": [ + "## Multi-way Comparison\n", + "\n", + "We're now going to run DSA on our three trajectories and fit the model, returning the scores which we can investigate by plotting a confusion matrix with a heatmap to show the DSA scores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ee9e8e8", + "metadata": { "execution": {} }, "outputs": [], "source": [ - "# @title Bonus material slides\n", + "n_delays = 1\n", + "delay_interval = 1\n", "\n", - "from IPython.display import IFrame\n", - "from ipywidgets import widgets\n", - "out = widgets.Output()\n", + "models = [trajectory_circle, trajectory_oval, trajectory_walk]\n", + "dsa = DSA(models, n_delays=n_delays, delay_interval=delay_interval)\n", + "similarities = dsa.fit_score()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18318ddb", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "labels = ['Circle', 'Oval', 'Walk']\n", + "data = np.random.rand(len(labels), len(labels))\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");" + ] + }, + { + "cell_type": "markdown", + "id": "ffd49b4b", + "metadata": { + "execution": {} + }, + "source": [ + "This heatmap across the three model comparisons shows that the DSA scores between (`Walk` and `Circle`) and (`Walk` and `Oval`) to be (relatively) high, while the comparison between (`Circle` and `Oval`) is very low. Please note that this confusion matrix is symmetrical, meaning that the analysis between `trajectory_A` and `trajectory_B` returns the same dynamic similarity score as `trajectory_B` and `trajectory_A`. This is a common feature we have also seen in comparison metrics in standard RSA. One thing to note in the calculation of DSA is that comparisons among identical trajectories is `0`. This is unlike in RSA where we expect the correlation among the same stimuli to be `1.0`. This is why we see black squares along the diagonal.\n", "\n", - "link_id = \"8fx23\"\n", + "Let's put our thinking caps on for a moment: This isn't really the result we would have expected, right? What do you think might be going on here? Have a look back at the *hyperparameters* and try to make an educated guess!" + ] + }, + { + "cell_type": "markdown", + "id": "d0ff5faa", + "metadata": { + "execution": {} + }, + "source": [ + "## DSA Hyperparameters (`n_delays` and `delay_interval`)\n", "\n", - "with out:\n", - " print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", - " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", - "display(out)" + "We'll now give you a hint as to why the setting of these hyperparameters is important when considering dynamic similarity analysis. The oscillators we have placed in the trajectories of `Circle` and `Oval` are not immediately apparent if you study only the previous time step for each element. It's only when considering the recurring pattern across a few different temporal delays and at what delay interval you want those to be, that we would expect to be able to detect recurring oscillations that provide us with the information we need to conclude that `Oval` and `Circle` are actually *dynamically* similar.\n", + "\n", + "You should change the values below to be more sensible hyperparameter settings and re-run the model and plot the new confusion matrix. Try using `n_delays` equal to `20` and `delay_interval` equal to `10`. Don't forget to define `models` (see above example if you get stuck)." ] }, { "cell_type": "code", "execution_count": null, - "id": "b5d6178f-ddf5-41ae-b676-15e452dc8b78", + "id": "9d8d4c03", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "#################################################\n", + "## TODO for students: fill in the missing parts ##\n", + "raise NotImplementedError(\"Student exercise\")\n", + "#################################################\n", + "\n", + "n_delays = ...\n", + "delay_interval = ...\n", + "\n", + "models = ...\n", + "dsa = DSA(...)\n", + "similarities = ...\n", + "\n", + "labels = ['Circle', 'Oval', 'Walk']\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6377c65", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# to_remove solution\n", + "\n", + "n_delays = 20\n", + "delay_interval = 10\n", + "\n", + "models = [trajectory_circle, trajectory_oval, trajectory_walk]\n", + "dsa = DSA(models, n_delays=n_delays, delay_interval=delay_interval)\n", + "similarities = dsa.fit_score()\n", + "\n", + "labels = ['Circle', 'Oval', 'Walk']\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");" + ] + }, + { + "cell_type": "markdown", + "id": "04b0e32f", + "metadata": { + "execution": {} + }, + "source": [ + "What do you see now? We now see a much more sensible result. The DSA scores have now correctly identified that `Oval` and `Circle` are very dynamically similar! They have the highest color score according to the colorbar on the side. As is always good practice in science, let's have a look inside the `similarities` variable to look at the exact values and confirm what we see in the figure above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55fa4065", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "similarities" + ] + }, + { + "cell_type": "markdown", + "id": "59cb799f", + "metadata": { + "execution": {} + }, + "source": [ + "## Comparison with RSA\n", + "\n", + "At the start of this exercise, we saw three different trajectories and pointed out that the random walk and oval shapes were most similar from a geometric perspective, both ellipse-like but not similar in their dynamic similarity. To better show the difference between DSA and RSA, we encourage you to run another comparison where we consider each time step to be a pair in the X,Y space and we will look at the the similarity between of `Oval` with both `Circle` and `Walk`. If our understanding is correct, then RSA should indicate a higher geometric similarity between (`Oval` and `Walk`) than with (`Oval` and `Circle`)." + ] + }, + { + "cell_type": "markdown", + "id": "87cf4e6e", + "metadata": { + "execution": {} + }, + "source": [ + "---\n", + "# (Bonus) Representational Geometry of Recurrent Models\n", + "\n", + "Transformations of representations can occur across space and time, e.g., layers of a neural network and steps of recurrent computation. We've looked at the temporal dimension today and earlier today in the other tutorials we looked mainly at spatial representations.\n", + "\n", + "Just as the layers in a feedforward DNN can change the representational geometry to perform a task, steps in a recurrent network can reuse the same layer to reach the same computational depth.\n", + "\n", + "In this section, we look at a very simple recurrent network with only 2650 trainable parameters." + ] + }, + { + "cell_type": "markdown", + "id": "3d613edd", + "metadata": { + "execution": {} + }, + "source": [ + "Here is a diagram of this network:\n", + "\n", + "![Recurrent convolutional neural network](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/rcnn_tutorial.png?raw=true)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f0443d3", "metadata": { "cellView": "form", "execution": {} }, "outputs": [], "source": [ - "# @title Video 1: Dynamical Similarity Analysis\n", + "# @title Grab a recurrent model\n", "\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", + "args = build_args()\n", + "train_loader, test_loader = fetch_dataloaders(args)\n", + "path = \"recurrent_model.pth\"\n", + "model_recurrent = torch.load(path, map_location=args.device, weights_only=False)" + ] + }, + { + "cell_type": "markdown", + "id": "d463c3a9", + "metadata": { + "execution": {} + }, + "source": [ + "
We can first look at the computational steps in this network. As we see below, the `conv2` operation is repeated for 5 times." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bfabacd", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "train_nodes, _ = get_graph_node_names(model_recurrent)\n", + "print('The computational steps in the network are: \\n', train_nodes)" + ] + }, + { + "cell_type": "markdown", + "id": "1d410c3a", + "metadata": { + "execution": {} + }, + "source": [ + "Plotting the RDMs after each application of the `conv2` operation shows the same progressive emergence of the blockwise structure around the diagonal, mediating the correct classification in this task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30249608", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "imgs, labels = sample_images(test_loader, n=20)\n", + "return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']\n", + "model_features = extract_features(model_recurrent, imgs.to(device), return_layers)\n", "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "rdms, rdms_dict = calc_rdms(model_features)\n", + "plot_rdms(rdms_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "248329c3", + "metadata": { + "execution": {} + }, + "source": [ + "We can also look at how the different dimensionality reduction techniques capture the dynamics of changing geometry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b0e2cdf", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']\n", "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", + "imgs, labels = sample_images(test_loader, n=50) #grab 500 samples from the test set\n", + "model_features = extract_features(model_recurrent, imgs.to(device), return_layers)\n", "\n", - "video_ids = [('Youtube', 'FHikIsQFQvM'), ('Bilibili', 'BV1qm421g7hV')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" + "plot_dim_reduction(model_features, labels, transformer_funcs =['PCA', 'MDS', 't-SNE'])" + ] + }, + { + "cell_type": "markdown", + "id": "1aaf5f4a", + "metadata": { + "execution": {} + }, + "source": [ + "## Representational geometry paths for recurrent models\n", + "\n", + "We can look at the model's recurrent computational steps as a path in the representational geometry space." ] }, { "cell_type": "code", "execution_count": null, - "id": "d2ce83bc-7e86-44d3-a40a-4ad46fd5a6df", + "id": "7f88274a", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "imgs, labels = sample_images(test_loader, n=50) #grab 500 samples from the test set\n", + "model_features_recurrent = extract_features(model_recurrent, imgs.to(device), return_layers='all')\n", + "\n", + "#rdms, rdms_dict = calc_rdms(model_features)\n", + "features = {'recurrent model': model_features_recurrent}\n", + "model_colors = {'recurrent model': 'y'}\n", + "\n", + "rep_path(features, model_colors, labels)" + ] + }, + { + "cell_type": "markdown", + "id": "5c3fbd44", + "metadata": { + "execution": {} + }, + "source": [ + "We can also look at the paths taken by the feedforward and the recurrent models and compare them." + ] + }, + { + "cell_type": "markdown", + "id": "b25a8cc6", + "metadata": { + "execution": {} + }, + "source": [ + "If you recall back to Tutorial 2, we compared a standard feedward model's representations. We can extend our analysis of the analysis of the recurrent model's representations by making a side-by-side comparison. We can also look at the paths taken by the feedforward and the recurrent models and compare them. What we see above in the case of the recurrent model is the fast-shifting path through the geometric space from inputs to labels. This illustration serves to show that models take many different paths and can have very diverse underlying mechanisms but still arrive at a superficially similar output at the end of training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c904e840", "metadata": { "cellView": "form", "execution": {} @@ -160,7 +2697,19 @@ "outputs": [], "source": [ "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_DSA_video\")" + "content_review(f\"{feedback_prefix}_recurrent_models\")" + ] + }, + { + "cell_type": "markdown", + "id": "3ed56061", + "metadata": { + "execution": {} + }, + "source": [ + "# The Big Picture\n", + "\n", + "Today, you've looked at what it means to measure representations from different systems. These systems can be of the same type (multiple brain systems, multiple artificial models) as well as with representations between these systems. In NeuroAI, we're especially interested in such comparisons, comparing representational systems in deep learning networks, for instance, to brain recordings recorded while those biological systems experienced / perceived the same set of stimuli. Comparisons can be geometric / spatial or they can be temporal. Today, we looked at Dynamic Similarity Analysis, a method used to be able to capture the dependencies among trajectories, not just capturing the similarity of the full temporal sequence upon completion of the temporal sequence. It's often important to take into account multiple dimensions of representational similarity. A combination of tools is definitely required in the NeuroAI researcher's toolkit. We hope you have many chances to use these tools in your future work as NeuroAI researchers." ] } ], @@ -191,7 +2740,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.9.22" } }, "nbformat": 4, diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial4.ipynb b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial4.ipynb index 139d831e9..3566eb171 100644 --- a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial4.ipynb +++ b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial4.ipynb @@ -17,7 +17,7 @@ "execution": {} }, "source": [ - "# Tutorial 4: Representational geometry & noise\n", + "# (Bonus) Tutorial 4: Representational geometry & noise\n", "\n", "**Week 1, Day 3: Comparing Artificial And Biological Networks**\n", "\n", @@ -25,9 +25,9 @@ "\n", "__Content creators:__ Wenxuan Guo, Heiko Schütt\n", "\n", - "__Content reviewers:__ Alish Dipani, Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk\n", + "__Content reviewers:__ Alish Dipani, Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk, Alex Murphy\n", "\n", - "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n", + "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault, Alex Murphy\n", "\n", "Acknowledgments: the tutorial outline was written by Heiko Schütt. The content was greatly improved by discussions with Heiko, Hlib, and Alish, and the insightful illustrations presented in the paper by Walther et al. (2016)\n" ] @@ -61,7 +61,7 @@ "\n", "5. Using random projections to estimate distances. This section introduces the Johnson–Lindenstrauss Lemma, which states that random projections maintain the integrity of distance estimates in a lower-dimensional space. This concept is crucial for reducing dimensionality while preserving the relational structure of the data.\n", "\n", - "We will adhere to the notational conventions established by [Walther et al. (2016)](https://pubmed.ncbi.nlm.nih.gov/26707889/) for all discussed distance measures. " + "We will adhere to the notational conventions established by [Walther et al. (2016)](https://pubmed.ncbi.nlm.nih.gov/26707889/) for all discussed distance measures." ] }, { @@ -644,6 +644,72 @@ "display(tabs)" ] }, + { + "cell_type": "markdown", + "id": "b64eaea5", + "metadata": { + "execution": {} + }, + "source": [ + "The video below is additional information in more detail which was previously part of the introductory video for this course day. It provides some useful further information on the technical details mentioned during these tutorials. Please feel free to check it out and use it as a resource if you want to learn more or if you want to get a deeper understanding on some of the important details." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "200235dc", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Video 2 (BONUS): Extended Intro Video\n", + "\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "#assert 1 == 0, \"Upload this video\"\n", + "video_ids = [('Youtube', 'm9srqTx5ci0'), ('Bilibili', 'BV1meVjz3Eeh')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1509,6 +1575,20 @@ "3. Cross-validated distance estimators (cross-validated Euclidean or Mahalanobis distance) can remove the positive bias introduced by noise.\n", "4. The Johnson–Lindenstrauss Lemma shows that random projections preserve the Euclidean distance with some distortions. Crucially, the distortion does not depend on the dimensionality of the original space." ] + }, + { + "cell_type": "markdown", + "id": "40936ec4", + "metadata": { + "execution": {} + }, + "source": [ + "# The Big Picture\n", + "\n", + "The goal of this tutorial is to provide you with some mathematical tools for your NeuroAI researcher toolkit. What happens when you pull out the Euclidean metric from your toolkit and, while this has worked well in the past, suddenly in different scenarios it doesn't seem to perform so well. Aha, you spot the potential for correlated noise and you reach deeper into your toolkit and pull out the Mahalanobis metric, which implicitly undoes the correlated noise in the model. Perhaps you can't even tell if there is any correlated noise in your data and you try with both metrics, and Mahalanobis works well but Euclidean does not, that can be a hunch that leads you to confirm the presence of correlated noise. \n", + "\n", + "Sometimes you might be faced with dimensionalities that are just too high to practically deal with in your use case. Then, why not recall what you learned about how random projections can reduce the dimensionality of a feature space and be largely resistant to corrupting the applicability of distance metrics. These metrics also might work better in this lower dimensional space. If you apply this idea and need to justify it, just reach into your NeuroAI toolkit and pull out the Johnson-Lindenstrauss Lemma as your justification." + ] } ], "metadata": { @@ -1539,7 +1619,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.9.22" } }, "nbformat": 4, diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial5.ipynb b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial5.ipynb index 77480d80e..8c9385c4b 100644 --- a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial5.ipynb +++ b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/instructor/W1D3_Tutorial5.ipynb @@ -17,17 +17,17 @@ "execution": {} }, "source": [ - "# Bonus Material: Dynamical similarity analysis (DSA)\n", + "# Tutorial 5: Dynamical Similarity Analysis (DSA)\n", "\n", "**Week 1, Day 3: Comparing Artificial And Biological Networks**\n", "\n", "**By Neuromatch Academy**\n", "\n", - "__Content creators:__ Mitchell Ostrow\n", + "__Content creators:__ Mitchell Ostrow, Alex Murphy\n", "\n", - "__Content reviewers:__ Xaq Pitkow, Hlib Solodzhuk\n", + "__Content reviewers:__ Xaq Pitkow, Hlib Solodzhuk, Alex Murphy\n", "\n", - "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n" + "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault, Alex Murphy\n" ] }, { @@ -52,7 +52,7 @@ "source": [ "# @title Install and import feedback gadget\n", "\n", - "!pip install vibecheck --quiet\n", + "!pip install vibecheck rsatoolbox --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", @@ -67,7 +67,2087 @@ " ).render()\n", "\n", "\n", - "feedback_prefix = \"W1D3_Bonus\"" + "feedback_prefix = \"W1D5_DSA\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef9abaa3", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Helper functions\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "def generate_2d_random_process(A, B, T=1000):\n", + " \"\"\"\n", + " Generates a 2D random process with the equation x(t+1) = A.x(t) + B.noise.\n", + "\n", + " Args:\n", + " A: 2x2 transition matrix.\n", + " B: 2x2 noise scaling matrix.\n", + " T: Number of time steps.\n", + "\n", + " Returns:\n", + " A NumPy array of shape (T+1, 2) representing the trajectory.\n", + " \"\"\"\n", + " # Assuming equilibrium distribution is zero mean and identity covariance for simplicity.\n", + " # You may adjust this according to your actual equilibrium distribution\n", + " x = np.zeros(2)\n", + "\n", + " trajectory = [x.copy()] # Initialize with x(0)\n", + " for t in range(T):\n", + " noise = np.random.normal(size=2) # Standard normal noise\n", + " x = np.dot(A, x) + np.dot(B, noise)\n", + " trajectory.append(x.copy())\n", + " return np.array(trajectory)\n", + "\n", + "\"\"\"This module computes the Havok DMD model for a given dataset.\"\"\"\n", + "import torch\n", + "\n", + "def embed_signal_torch(data, n_delays, delay_interval=1):\n", + " \"\"\"\n", + " Create a delay embedding from the provided tensor data.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : torch.tensor\n", + " The data from which to create the delay embedding. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step.\n", + " \"\"\"\n", + " if isinstance(data, np.ndarray):\n", + " data = torch.from_numpy(data)\n", + " device = data.device\n", + "\n", + " if data.shape[int(data.ndim==3)] - (n_delays - 1)*delay_interval < 1:\n", + " raise ValueError(\"The number of delays is too large for the number of time points in the data!\")\n", + "\n", + " # initialize the embedding\n", + " if data.ndim == 3:\n", + " embedding = torch.zeros((data.shape[0], data.shape[1] - (n_delays - 1)*delay_interval, data.shape[2]*n_delays)).to(device)\n", + " else:\n", + " embedding = torch.zeros((data.shape[0] - (n_delays - 1)*delay_interval, data.shape[1]*n_delays)).to(device)\n", + "\n", + " for d in range(n_delays):\n", + " index = (n_delays - 1 - d)*delay_interval\n", + " ddelay = d*delay_interval\n", + "\n", + " if data.ndim == 3:\n", + " ddata = d*data.shape[2]\n", + " embedding[:,:, ddata: ddata + data.shape[2]] = data[:,index:data.shape[1] - ddelay]\n", + " else:\n", + " ddata = d*data.shape[1]\n", + " embedding[:, ddata:ddata + data.shape[1]] = data[index:data.shape[0] - ddelay]\n", + "\n", + " return embedding\n", + "\n", + "class DMD:\n", + " \"\"\"DMD class for computing and predicting with DMD models.\n", + " \"\"\"\n", + " def __init__(\n", + " self,\n", + " data,\n", + " n_delays,\n", + " delay_interval=1,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance=None,\n", + " reduced_rank_reg=False,\n", + " lamb=0,\n", + " device='cpu',\n", + " verbose=False,\n", + " send_to_cpu=False,\n", + " steps_ahead=1\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step.\n", + "\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used.\n", + "\n", + " rank_thresh : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None.\n", + "\n", + " reduced_rank_reg : bool\n", + " Determines whether to use reduced rank regression (True) or principal component regression (False)\n", + "\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to 0.\n", + "\n", + " device: string, int, or torch.device\n", + " A string, int or torch.device object to indicate the device to torch.\n", + "\n", + " verbose: bool\n", + " If True, print statements will be provided about the progress of the fitting procedure.\n", + "\n", + " send_to_cpu: bool\n", + " If True, will send all tensors in the object back to the cpu after everything is computed.\n", + " This is implemented to prevent gpu memory overload when computing multiple DMDs.\n", + "\n", + " steps_ahead: int\n", + " The number of time steps ahead to predict. Defaults to 1.\n", + " \"\"\"\n", + "\n", + " self.device = device\n", + " self._init_data(data)\n", + "\n", + " self.n_delays = n_delays\n", + " self.delay_interval = delay_interval\n", + " self.rank = rank\n", + " self.rank_thresh = rank_thresh\n", + " self.rank_explained_variance = rank_explained_variance\n", + " self.reduced_rank_reg = reduced_rank_reg\n", + " self.lamb = lamb\n", + " self.verbose = verbose\n", + " self.send_to_cpu = send_to_cpu\n", + " self.steps_ahead = steps_ahead\n", + "\n", + " # Hankel matrix\n", + " self.H = None\n", + "\n", + " # SVD attributes\n", + " self.U = None\n", + " self.S = None\n", + " self.V = None\n", + " self.S_mat = None\n", + " self.S_mat_inv = None\n", + "\n", + " # DMD attributes\n", + " self.A_v = None\n", + " self.A_havok_dmd = None\n", + "\n", + " def _init_data(self, data):\n", + " # check if the data is an np.ndarry - if so, convert it to Torch\n", + " if isinstance(data, np.ndarray):\n", + " data = torch.from_numpy(data)\n", + " self.data = data\n", + " # create attributes for the data dimensions\n", + " if self.data.ndim == 3:\n", + " self.ntrials = self.data.shape[0]\n", + " self.window = self.data.shape[1]\n", + " self.n = self.data.shape[2]\n", + " else:\n", + " self.window = self.data.shape[0]\n", + " self.n = self.data.shape[1]\n", + " self.ntrials = 1\n", + "\n", + " def compute_hankel(\n", + " self,\n", + " data=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " ):\n", + " \"\"\"\n", + " Computes the Hankel matrix from the provided data.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include. Defaults to None - provide only if you want\n", + " to override the value of n_delays from the init.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step. Defaults to None - provide only if you want\n", + " to override the value of n_delays from the init.\n", + " \"\"\"\n", + " if self.verbose:\n", + " print(\"Computing Hankel matrix ...\")\n", + "\n", + " # if parameters are provided, overwrite them from the init\n", + " self.data = self.data if data is None else self._init_data(data)\n", + " self.n_delays = self.n_delays if n_delays is None else n_delays\n", + " self.delay_interval = self.delay_interval if delay_interval is None else delay_interval\n", + " self.data = self.data.to(self.device)\n", + "\n", + " self.H = embed_signal_torch(self.data, self.n_delays, self.delay_interval)\n", + "\n", + " if self.verbose:\n", + " print(\"Hankel matrix computed!\")\n", + "\n", + " def compute_svd(self):\n", + " \"\"\"\n", + " Computes the SVD of the Hankel matrix.\n", + " \"\"\"\n", + "\n", + " if self.verbose:\n", + " print(\"Computing SVD on Hankel matrix ...\")\n", + " if self.H.ndim == 3: #flatten across trials for 3d\n", + " H = self.H.reshape(self.H.shape[0] * self.H.shape[1], self.H.shape[2])\n", + " else:\n", + " H = self.H\n", + " # compute the SVD\n", + " U, S, Vh = torch.linalg.svd(H.T, full_matrices=False)\n", + "\n", + " # update attributes\n", + " V = Vh.T\n", + " self.U = U\n", + " self.S = S\n", + " self.V = V\n", + "\n", + " # construct the singuar value matrix and its inverse\n", + " # dim = self.n_delays * self.n\n", + " # s = len(S)\n", + " # self.S_mat = torch.zeros(dim, dim,dtype=torch.float32).to(self.device)\n", + " # self.S_mat_inv = torch.zeros(dim, dim,dtype=torch.float32).to(self.device)\n", + " self.S_mat = torch.diag(S).to(self.device)\n", + " self.S_mat_inv= torch.diag(1 / S).to(self.device)\n", + "\n", + " # compute explained variance\n", + " exp_variance_inds = self.S**2 / ((self.S**2).sum())\n", + " cumulative_explained = torch.cumsum(exp_variance_inds, 0)\n", + " self.cumulative_explained_variance = cumulative_explained\n", + "\n", + " #make the X and Y components of the regression by staggering the hankel eigen-time delay coordinates by time\n", + " if self.reduced_rank_reg:\n", + " V = self.V\n", + " else:\n", + " V = self.V\n", + "\n", + " if self.ntrials > 1:\n", + " if V.numel() < self.H.numel():\n", + " raise ValueError(\"The dimension of the SVD of the Hankel matrix is smaller than the dimension of the Hankel matrix itself. \\n \\\n", + " This is likely due to the number of time points being smaller than the number of dimensions. \\n \\\n", + " Please reduce the number of delays.\")\n", + "\n", + " V = V.reshape(self.H.shape)\n", + "\n", + " #first reshape back into Hankel shape, separated by trials\n", + " newshape = (self.H.shape[0]*(self.H.shape[1]-self.steps_ahead),self.H.shape[2])\n", + " self.Vt_minus = V[:,:-self.steps_ahead].reshape(newshape)\n", + " self.Vt_plus = V[:,self.steps_ahead:].reshape(newshape)\n", + " else:\n", + " self.Vt_minus = V[:-self.steps_ahead]\n", + " self.Vt_plus = V[self.steps_ahead:]\n", + "\n", + "\n", + " if self.verbose:\n", + " print(\"SVD complete!\")\n", + "\n", + " def recalc_rank(self,rank,rank_thresh,rank_explained_variance):\n", + " '''\n", + " Parameters\n", + " ----------\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used. Provide only if you want to override the value from the init.\n", + "\n", + " rank_thresh : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None - provide only if you want\n", + " to override the value from the init.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None -\n", + " provide only if you want to overried the value from the init.\n", + " '''\n", + " # if an argument was provided, overwrite the stored rank information\n", + " none_vars = (rank is None) + (rank_thresh is None) + (rank_explained_variance is None)\n", + " if none_vars != 3:\n", + " self.rank = None\n", + " self.rank_thresh = None\n", + " self.rank_explained_variance = None\n", + "\n", + " self.rank = self.rank if rank is None else rank\n", + " self.rank_thresh = self.rank_thresh if rank_thresh is None else rank_thresh\n", + " self.rank_explained_variance = self.rank_explained_variance if rank_explained_variance is None else rank_explained_variance\n", + "\n", + " none_vars = (self.rank is None) + (self.rank_thresh is None) + (self.rank_explained_variance is None)\n", + " if none_vars < 2:\n", + " raise ValueError(\"More than one value was provided between rank, rank_thresh, and rank_explained_variance. Please provide only one of these, and ensure the others are None!\")\n", + " elif none_vars == 3:\n", + " self.rank = len(self.S)\n", + "\n", + " if self.reduced_rank_reg:\n", + " S = self.proj_mat_S\n", + " else:\n", + " S = self.S\n", + "\n", + " if rank_thresh is not None:\n", + " if S[-1] > rank_thresh:\n", + " self.rank = len(S)\n", + " else:\n", + " self.rank = torch.argmax(torch.arange(len(S), 0, -1).to(self.device)*(S < rank_thresh))\n", + "\n", + " if rank_explained_variance is not None:\n", + " self.rank = int(torch.argmax((self.cumulative_explained_variance > rank_explained_variance).type(torch.int)).cpu().numpy())\n", + "\n", + " if self.rank > self.H.shape[-1]:\n", + " self.rank = self.H.shape[-1]\n", + "\n", + " if self.rank is None:\n", + " if S[-1] > self.rank_thresh:\n", + " self.rank = len(S)\n", + " else:\n", + " self.rank = torch.argmax(torch.arange(len(S), 0, -1).to(self.device)*(S < self.rank_thresh))\n", + "\n", + " def compute_havok_dmd(self,lamb=None):\n", + " \"\"\"\n", + " Computes the Havok DMD matrix (Principal Component Regression)\n", + "\n", + " Parameters\n", + " ----------\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to 0 - provide only if you want\n", + " to override the value of n_delays from the init.\n", + "\n", + " \"\"\"\n", + " if self.verbose:\n", + " print(\"Computing least squares fits to HAVOK DMD ...\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + "\n", + " A_v = (torch.linalg.inv(self.Vt_minus[:, :self.rank].T @ self.Vt_minus[:, :self.rank] + self.lamb*torch.eye(self.rank).to(self.device)) \\\n", + " @ self.Vt_minus[:, :self.rank].T @ self.Vt_plus[:, :self.rank]).T\n", + " self.A_v = A_v\n", + " self.A_havok_dmd = self.U @ self.S_mat[:self.U.shape[1], :self.rank] @ self.A_v @ self.S_mat_inv[:self.rank, :self.U.shape[1]] @ self.U.T\n", + "\n", + " if self.verbose:\n", + " print(\"Least squares complete! \\n\")\n", + "\n", + " def compute_proj_mat(self,lamb=None):\n", + " if self.verbose:\n", + " print(\"Computing Projector Matrix for Reduced Rank Regression\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + "\n", + " self.proj_mat = self.Vt_plus.T @ self.Vt_minus @ torch.linalg.inv(self.Vt_minus.T @ self.Vt_minus +\n", + " self.lamb*torch.eye(self.Vt_minus.shape[1]).to(self.device)) @ \\\n", + " self.Vt_minus.T @ self.Vt_plus\n", + "\n", + " self.proj_mat_S, self.proj_mat_V = torch.linalg.eigh(self.proj_mat)\n", + " #todo: more efficient to flip ranks (negative index) in compute_reduced_rank_regression but also less interpretable\n", + " self.proj_mat_S = torch.flip(self.proj_mat_S, dims=(0,))\n", + " self.proj_mat_V = torch.flip(self.proj_mat_V, dims=(1,))\n", + "\n", + " if self.verbose:\n", + " print(\"Projector Matrix computed! \\n\")\n", + "\n", + " def compute_reduced_rank_regression(self,lamb=None):\n", + " if self.verbose:\n", + " print(\"Computing Reduced Rank Regression ...\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + " proj_mat = self.proj_mat_V[:,:self.rank] @ self.proj_mat_V[:,:self.rank].T\n", + " B_ols = torch.linalg.inv(self.Vt_minus.T @ self.Vt_minus + self.lamb*torch.eye(self.Vt_minus.shape[1]).to(self.device)) @ self.Vt_minus.T @ self.Vt_plus\n", + "\n", + " self.A_v = B_ols @ proj_mat\n", + " self.A_havok_dmd = self.U @ self.S_mat[:self.U.shape[1],:self.A_v.shape[1]] @ self.A_v.T @ self.S_mat_inv[:self.A_v.shape[0], :self.U.shape[1]] @ self.U.T\n", + "\n", + "\n", + " if self.verbose:\n", + " print(\"Reduced Rank Regression complete! \\n\")\n", + "\n", + " def fit(\n", + " self,\n", + " data=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance=None,\n", + " lamb=None,\n", + " device=None,\n", + " verbose=None,\n", + " steps_ahead=None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults to None -\n", + " provide only if you want to override the value from the init.\n", + "\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " rank_thresh : int\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None -\n", + " provide only if you want to overried the value from the init.\n", + "\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " device: string or int\n", + " A string or int to indicate the device to torch. For example, can be 'cpu' or 'cuda',\n", + " or alternatively 0 if the intenion is to use GPU device 0. Defaults to None - provide only\n", + " if you want to override the value from the init.\n", + "\n", + " verbose: bool\n", + " If True, print statements will be provided about the progress of the fitting procedure.\n", + " Defaults to None - provide only if you want to override the value from the init.\n", + "\n", + " steps_ahead: int\n", + " The number of time steps ahead to predict. Defaults to 1.\n", + "\n", + " \"\"\"\n", + " # if parameters are provided, overwrite them from the init\n", + " self.steps_ahead = self.steps_ahead if steps_ahead is None else steps_ahead\n", + " self.device = self.device if device is None else device\n", + " self.verbose = self.verbose if verbose is None else verbose\n", + "\n", + " self.compute_hankel(data, n_delays, delay_interval)\n", + " self.compute_svd()\n", + "\n", + " if self.reduced_rank_reg:\n", + " self.compute_proj_mat(lamb)\n", + " self.recalc_rank(rank,rank_thresh,rank_explained_variance)\n", + " self.compute_reduced_rank_regression(lamb)\n", + " else:\n", + " self.recalc_rank(rank,rank_thresh,rank_explained_variance)\n", + " self.compute_havok_dmd(lamb)\n", + "\n", + " if self.send_to_cpu:\n", + " self.all_to_device('cpu') #send back to the cpu to save memory\n", + "\n", + " def predict(\n", + " self,\n", + " test_data=None,\n", + " reseed=None,\n", + " full_return=False\n", + " ):\n", + " \"\"\"\n", + " Returns\n", + " -------\n", + " pred_data : torch.tensor\n", + " The predictions generated by the HAVOK model. Of the same shape as test_data. Note that the first\n", + " (self.n_delays - 1)*self.delay_interval + 1 time steps of the generated predictions are by construction\n", + " identical to the test_data.\n", + "\n", + " H_test_havok_dmd : torch.tensor (Optional)\n", + " Returned if full_return=True. The predicted Hankel matrix generated by the HAVOK model.\n", + " H_test : torch.tensor (Optional)\n", + " Returned if full_return=True. The true Hankel matrix\n", + " \"\"\"\n", + " # initialize test_data\n", + " if test_data is None:\n", + " test_data = self.data\n", + " if isinstance(test_data, np.ndarray):\n", + " test_data = torch.from_numpy(test_data).to(self.device)\n", + " ndim = test_data.ndim\n", + " if ndim == 2:\n", + " test_data = test_data.unsqueeze(0)\n", + " H_test = embed_signal_torch(test_data, self.n_delays, self.delay_interval)\n", + " steps_ahead = self.steps_ahead if self.steps_ahead is not None else 1\n", + "\n", + " if reseed is None:\n", + " reseed = 1\n", + "\n", + " H_test_havok_dmd = torch.zeros(H_test.shape).to(self.device)\n", + " H_test_havok_dmd[:, :steps_ahead] = H_test[:, :steps_ahead]\n", + "\n", + " A = self.A_havok_dmd.unsqueeze(0)\n", + " for t in range(steps_ahead, H_test.shape[1]):\n", + " if t % reseed == 0:\n", + " H_test_havok_dmd[:, t] = (A @ H_test[:, t - steps_ahead].transpose(-2, -1)).transpose(-2, -1)\n", + " else:\n", + " H_test_havok_dmd[:, t] = (A @ H_test_havok_dmd[:, t - steps_ahead].transpose(-2, -1)).transpose(-2, -1)\n", + " pred_data = torch.hstack([test_data[:, :(self.n_delays - 1)*self.delay_interval + steps_ahead], H_test_havok_dmd[:, steps_ahead:, :self.n]])\n", + "\n", + " if ndim == 2:\n", + " pred_data = pred_data[0]\n", + "\n", + " if full_return:\n", + " return pred_data, H_test_havok_dmd, H_test\n", + " else:\n", + " return pred_data\n", + "\n", + " def all_to_device(self,device='cpu'):\n", + " for k,v in self.__dict__.items():\n", + " if isinstance(v, torch.Tensor):\n", + " self.__dict__[k] = v.to(device)\n", + "\n", + "from typing import Literal\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from typing import Literal\n", + "import torch.nn.utils.parametrize as parametrize\n", + "from scipy.stats import wasserstein_distance\n", + "\n", + "def pad_zeros(A,B,device):\n", + "\n", + " with torch.no_grad():\n", + " dim = max(A.shape[0],B.shape[0])\n", + " A1 = torch.zeros((dim,dim)).float()\n", + " A1[:A.shape[0],:A.shape[1]] += A\n", + " A = A1.float().to(device)\n", + "\n", + " B1 = torch.zeros((dim,dim)).float()\n", + " B1[:B.shape[0],:B.shape[1]] += B\n", + " B = B1.float().to(device)\n", + "\n", + " return A,B\n", + "\n", + "class LearnableSimilarityTransform(nn.Module):\n", + " \"\"\"\n", + " Computes the similarity transform for a learnable orthonormal matrix C\n", + " \"\"\"\n", + " def __init__(self, n,orthog=True):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + " n : int\n", + " dimension of the C matrix\n", + " \"\"\"\n", + " super(LearnableSimilarityTransform, self).__init__()\n", + " #initialize orthogonal matrix as identity\n", + " self.C = nn.Parameter(torch.eye(n).float())\n", + " self.orthog = orthog\n", + "\n", + " def forward(self, B):\n", + " if self.orthog:\n", + " return self.C @ B @ self.C.transpose(-1, -2)\n", + " else:\n", + " return self.C @ B @ torch.linalg.inv(self.C)\n", + "\n", + "class Skew(nn.Module):\n", + " def __init__(self,n,device):\n", + " \"\"\"\n", + " Computes a skew-symmetric matrix X from some parameters (also called X)\n", + "\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.L1 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L2 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L3 = nn.Linear(n,n,bias = False, device = device)\n", + "\n", + " def forward(self, X):\n", + " X = torch.tanh(self.L1(X))\n", + " X = torch.tanh(self.L2(X))\n", + " X = self.L3(X)\n", + " return X - X.transpose(-1, -2)\n", + "\n", + "class Matrix(nn.Module):\n", + " def __init__(self,n,device):\n", + " \"\"\"\n", + " Computes a matrix X from some parameters (also called X)\n", + "\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.L1 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L2 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L3 = nn.Linear(n,n,bias = False, device = device)\n", + "\n", + " def forward(self, X):\n", + " X = torch.tanh(self.L1(X))\n", + " X = torch.tanh(self.L2(X))\n", + " X = self.L3(X)\n", + " return X\n", + "\n", + "class CayleyMap(nn.Module):\n", + " \"\"\"\n", + " Maps a skew-symmetric matrix to an orthogonal matrix in O(n)\n", + " \"\"\"\n", + " def __init__(self, n, device):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + "\n", + " n : int\n", + " dimension of the matrix we want to map\n", + "\n", + " device : {'cpu','cuda'} or int\n", + " hardware device on which to send the matrix\n", + " \"\"\"\n", + " super().__init__()\n", + " self.register_buffer(\"Id\", torch.eye(n,device = device))\n", + "\n", + " def forward(self, X):\n", + " # (I + X)(I - X)^{-1}\n", + " return torch.linalg.solve(self.Id + X, self.Id - X)\n", + "\n", + "class SimilarityTransformDist:\n", + " \"\"\"\n", + " Computes the Procrustes Analysis over Vector Fields\n", + " \"\"\"\n", + " def __init__(self,\n", + " iters = 200,\n", + " score_method: Literal[\"angular\", \"euclidean\",\"wasserstein\"] = \"angular\",\n", + " lr = 0.01,\n", + " device: Literal[\"cpu\",\"cuda\"] = 'cpu',\n", + " verbose = False,\n", + " group: Literal[\"O(n)\",\"SO(n)\",\"GL(n)\"] = \"O(n)\",\n", + " wasserstein_compare = None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " _________\n", + " iters : int\n", + " number of iterations to perform gradient descent\n", + "\n", + " score_method : {\"angular\",\"euclidean\",\"wasserstein\"}\n", + " specifies the type of metric to use\n", + " \"wasserstein\" will compare the singular values or eigenvalues\n", + " of the two matrices as in Redman et al., (2023)\n", + "\n", + " lr : float\n", + " learning rate\n", + "\n", + " device : {'cpu','cuda'} or int\n", + "\n", + " verbose : bool\n", + " prints when finished optimizing\n", + "\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " wasserstein_compare : {'sv','eig',None}\n", + " specifies whether to compare the singular values or eigenvalues\n", + " if score_method is \"wasserstein\", or the shapes are different\n", + " \"\"\"\n", + "\n", + " self.iters = iters\n", + " self.score_method = score_method\n", + " self.lr = lr\n", + " self.verbose = verbose\n", + " self.device = device\n", + " self.C_star = None\n", + " self.A = None\n", + " self.B = None\n", + " self.group = group\n", + " self.wasserstein_compare = wasserstein_compare\n", + "\n", + " def fit(self,\n", + " A,\n", + " B,\n", + " iters = None,\n", + " lr = None,\n", + " group = None,\n", + " ):\n", + " \"\"\"\n", + " Computes the optimal matrix C over specified group\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor\n", + " first data matrix\n", + " B : np.array or torch.tensor\n", + " second data matrix\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " Returns\n", + " _______\n", + " None\n", + " \"\"\"\n", + " assert A.shape[0] == A.shape[1]\n", + " assert B.shape[0] == B.shape[1]\n", + "\n", + " A = A.to(self.device)\n", + " B = B.to(self.device)\n", + " self.A,self.B = A,B\n", + " lr = self.lr if lr is None else lr\n", + " iters = self.iters if iters is None else iters\n", + " group = self.group if group is None else group\n", + "\n", + " if group in {\"SO(n)\", \"O(n)\"}:\n", + " self.losses, self.C_star, self.sim_net = self.optimize_C(A,\n", + " B,\n", + " lr,iters,\n", + " orthog=True,\n", + " verbose=self.verbose)\n", + " if group == \"O(n)\":\n", + " #permute the first row and column of B then rerun the optimization\n", + " P = torch.eye(B.shape[0],device=self.device)\n", + " if P.shape[0] > 1:\n", + " P[[0, 1], :] = P[[1, 0], :]\n", + " losses, C_star, sim_net = self.optimize_C(A,\n", + " P @ B @ P.T,\n", + " lr,iters,\n", + " orthog=True,\n", + " verbose=self.verbose)\n", + " if losses[-1] < self.losses[-1]:\n", + " self.losses = losses\n", + " self.C_star = C_star @ P\n", + " self.sim_net = sim_net\n", + " if group == \"GL(n)\":\n", + " self.losses, self.C_star, self.sim_net = self.optimize_C(A,\n", + " B,\n", + " lr,iters,\n", + " orthog=False,\n", + " verbose=self.verbose)\n", + "\n", + " def optimize_C(self,A,B,lr,iters,orthog,verbose):\n", + " #parameterize mapping to be orthogonal\n", + " n = A.shape[0]\n", + " sim_net = LearnableSimilarityTransform(n,orthog=orthog).to(self.device)\n", + " if orthog:\n", + " parametrize.register_parametrization(sim_net, \"C\", Skew(n,self.device))\n", + " parametrize.register_parametrization(sim_net, \"C\", CayleyMap(n,self.device))\n", + " else:\n", + " parametrize.register_parametrization(sim_net, \"C\", Matrix(n,self.device))\n", + "\n", + " simdist_loss = nn.MSELoss(reduction = 'sum')\n", + "\n", + " optimizer = optim.Adam(sim_net.parameters(), lr=lr)\n", + " # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n", + "\n", + " losses = []\n", + " A /= torch.linalg.norm(A)\n", + " B /= torch.linalg.norm(B)\n", + " for _ in range(iters):\n", + " # Zero the gradients of the optimizer.\n", + " optimizer.zero_grad()\n", + " # Compute the Frobenius norm between A and the product.\n", + " loss = simdist_loss(A, sim_net(B))\n", + "\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + " # if _ % 99:\n", + " # scheduler.step()\n", + " losses.append(loss.item())\n", + "\n", + " if verbose:\n", + " print(\"Finished optimizing C\")\n", + "\n", + " C_star = sim_net.C.detach()\n", + " return losses, C_star,sim_net\n", + "\n", + " def score(self,A=None,B=None,score_method=None,group=None):\n", + " \"\"\"\n", + " Given an optimal C already computed, calculate the metric\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor or None\n", + " first data matrix, if None defaults to the saved matrix in fit\n", + " B : np.array or torch.tensor or None\n", + " second data matrix if None, defaults to the savec matrix in fit\n", + " score_method : None or {'angular','euclidean'}\n", + " overwrites the score method in the object for this application\n", + " Returns\n", + " _______\n", + "\n", + " score : float\n", + " similarity of the data under the similarity transform w.r.t C\n", + " \"\"\"\n", + " assert self.C_star is not None\n", + " A = self.A if A is None else A\n", + " B = self.B if B is None else B\n", + " assert A is not None\n", + " assert B is not None\n", + " assert A.shape == self.C_star.shape\n", + " assert B.shape == self.C_star.shape\n", + " score_method = self.score_method if score_method is None else score_method\n", + " group = self.group if group is None else group\n", + " with torch.no_grad():\n", + " if not isinstance(A,torch.Tensor):\n", + " A = torch.from_numpy(A).float().to(self.device)\n", + " if not isinstance(B,torch.Tensor):\n", + " B = torch.from_numpy(B).float().to(self.device)\n", + " C = self.C_star.to(self.device)\n", + "\n", + " if group in {\"SO(n)\", \"O(n)\"}:\n", + " Cinv = C.T\n", + " elif group in {\"GL(n)\"}:\n", + " Cinv = torch.linalg.inv(C)\n", + " else:\n", + " raise AssertionError(\"Need proper group name\")\n", + " if score_method == 'angular':\n", + " num = torch.trace(A.T @ C @ B @ Cinv)\n", + " den = torch.norm(A,p = 'fro')*torch.norm(B,p = 'fro')\n", + " score = torch.arccos(num/den).cpu().numpy()\n", + " if np.isnan(score): #around -1 and 1, we sometimes get NaNs due to arccos\n", + " if num/den < 0:\n", + " score = np.pi\n", + " else:\n", + " score = 0\n", + " else:\n", + " score = torch.norm(A - C @ B @ Cinv,p='fro').cpu().numpy().item() #/ A.numpy().size\n", + "\n", + " return score\n", + "\n", + " def fit_score(self,\n", + " A,\n", + " B,\n", + " iters = None,\n", + " lr = None,\n", + " score_method = None,\n", + " zero_pad = True,\n", + " group = None):\n", + " \"\"\"\n", + " for efficiency, computes the optimal matrix and returns the score\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor\n", + " first data matrix\n", + " B : np.array or torch.tensor\n", + " second data matrix\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " score_method : {'angular','euclidean'} or None\n", + " overwrites parameter in the class\n", + " zero_pad : bool\n", + " if True, then the smaller matrix will be zero padded so its the same size\n", + " Returns\n", + " _______\n", + "\n", + " score : float\n", + " similarity of the data under the similarity transform w.r.t C\n", + "\n", + " \"\"\"\n", + " score_method = self.score_method if score_method is None else score_method\n", + " group = self.group if group is None else group\n", + "\n", + " if isinstance(A,np.ndarray):\n", + " A = torch.from_numpy(A).float()\n", + " if isinstance(B,np.ndarray):\n", + " B = torch.from_numpy(B).float()\n", + "\n", + " assert A.shape[0] == B.shape[1] or self.wasserstein_compare is not None\n", + " if A.shape[0] != B.shape[0]:\n", + " if self.wasserstein_compare is None:\n", + " raise AssertionError(\"Matrices must be the same size unless using wasserstein distance\")\n", + " else: #otherwise resort to L2 Wasserstein over singular or eigenvalues\n", + " print(f\"resorting to wasserstein distance over {self.wasserstein_compare}\")\n", + "\n", + " if self.score_method == \"wasserstein\":\n", + " assert self.wasserstein_compare in {\"sv\",\"eig\"}\n", + " if self.wasserstein_compare == \"sv\":\n", + " a = torch.svd(A).S.view(-1,1)\n", + " b = torch.svd(B).S.view(-1,1)\n", + " elif self.wasserstein_compare == \"eig\":\n", + " a = torch.linalg.eig(A).eigenvalues\n", + " a = torch.vstack([a.real,a.imag]).T\n", + "\n", + " b = torch.linalg.eig(B).eigenvalues\n", + " b = torch.vstack([b.real,b.imag]).T\n", + " else:\n", + " raise AssertionError(\"wasserstein_compare must be 'sv' or 'eig'\")\n", + " device = a.device\n", + " a = a#.cpu()\n", + " b = b#.cpu()\n", + " M = ot.dist(a,b)#.numpy()\n", + " a,b = torch.ones(a.shape[0])/a.shape[0],torch.ones(b.shape[0])/b.shape[0]\n", + " a,b = a.to(device),b.to(device)\n", + "\n", + " score_star = ot.emd2(a,b,M)\n", + " #wasserstein_distance(A.cpu().numpy(),B.cpu().numpy())\n", + "\n", + " else:\n", + "\n", + " self.fit(A, B,iters,lr,group)\n", + " score_star = self.score(self.A,self.B,score_method=score_method,group=group)\n", + "\n", + " return score_star\n", + "\n", + "class DSA:\n", + " \"\"\"\n", + " Computes the Dynamical Similarity Analysis (DSA) for two data matrices\n", + " \"\"\"\n", + " def __init__(self,\n", + " X,\n", + " Y=None,\n", + " n_delays=1,\n", + " delay_interval=1,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance = None,\n", + " lamb = 0.0,\n", + " send_to_cpu = True,\n", + " iters = 1500,\n", + " score_method: Literal[\"angular\", \"euclidean\",\"wasserstein\"] = \"angular\",\n", + " lr = 5e-3,\n", + " group: Literal[\"GL(n)\", \"O(n)\", \"SO(n)\"] = \"O(n)\",\n", + " zero_pad = False,\n", + " device = 'cpu',\n", + " verbose = False,\n", + " reduced_rank_reg = False,\n", + " kernel=None,\n", + " num_centers=0.1,\n", + " svd_solver='arnoldi',\n", + " wasserstein_compare: Literal['sv','eig',None] = None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + "\n", + " X : np.array or torch.tensor or list of np.arrays or torch.tensors\n", + " first data matrix/matrices\n", + "\n", + " Y : None or np.array or torch.tensor or list of np.arrays or torch.tensors\n", + " second data matrix/matrices.\n", + " * If Y is None, X is compared to itself pairwise\n", + " (must be a list)\n", + " * If Y is a single matrix, all matrices in X are compared to Y\n", + " * If Y is a list, all matrices in X are compared to all matrices in Y\n", + "\n", + " DMD parameters:\n", + "\n", + " n_delays : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " number of delays to use in constructing the Hankel matrix\n", + "\n", + " delay_interval : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " interval between samples taken in constructing Hankel matrix\n", + "\n", + " rank : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " rank of DMD matrix fit in reduced-rank regression\n", + "\n", + " rank_thresh : float or list or tuple/list: (float,float), (list,list),(list,float),(float,list)\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None.\n", + "\n", + " rank_explained_variance : float or list or tuple: (float,float), (list,list),(list,float),(float,list)\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None.\n", + "\n", + " lamb : float\n", + " L-1 regularization parameter in DMD fit\n", + "\n", + " send_to_cpu: bool\n", + " If True, will send all tensors in the object back to the cpu after everything is computed.\n", + " This is implemented to prevent gpu memory overload when computing multiple DMDs.\n", + "\n", + " NOTE: for all of these above, they can be single values or lists or tuples,\n", + " depending on the corresponding dimensions of the data\n", + " If at least one of X and Y are lists, then if they are a single value\n", + " it will default to the rank of all DMD matrices.\n", + " If they are (int,int), then they will correspond to an individual dmd matrix\n", + " OR to X and Y respectively across all matrices\n", + " If it is (list,list), then each element will correspond to an individual\n", + " dmd matrix indexed at the same position\n", + "\n", + " SimDist parameters:\n", + "\n", + " iters : int\n", + " number of optimization iterations in Procrustes over vector fields\n", + "\n", + " score_method : {'angular','euclidean'}\n", + " type of metric to compute, angular vs euclidean distance\n", + "\n", + " lr : float\n", + " learning rate of the Procrustes over vector fields optimization\n", + "\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " zero_pad : bool\n", + " whether or not to zero-pad if the dimensions are different\n", + "\n", + " device : 'cpu' or 'cuda' or int\n", + " hardware to use in both DMD and PoVF\n", + "\n", + " verbose : bool\n", + " whether or not print when sections of the analysis is completed\n", + "\n", + " wasserstein_compare : {'sv','eig',None}\n", + " specifies whether to compare the singular values or eigenvalues\n", + " if score_method is \"wasserstein\", or the shapes are different\n", + " \"\"\"\n", + " self.X = X\n", + " self.Y = Y\n", + " if self.X is None and isinstance(self.Y,list):\n", + " self.X, self.Y = self.Y, self.X #swap so code is easy\n", + "\n", + " self.check_method()\n", + " if self.method == 'self-pairwise':\n", + " self.data = [self.X]\n", + " else:\n", + " self.data = [self.X, self.Y]\n", + "\n", + " self.n_delays = self.broadcast_params(n_delays,cast=int)\n", + " self.delay_interval = self.broadcast_params(delay_interval,cast=int)\n", + " self.rank = self.broadcast_params(rank,cast=int)\n", + " self.rank_thresh = self.broadcast_params(rank_thresh)\n", + " self.rank_explained_variance = self.broadcast_params(rank_explained_variance)\n", + " self.lamb = self.broadcast_params(lamb)\n", + " self.send_to_cpu = send_to_cpu\n", + " self.iters = iters\n", + " self.score_method = score_method\n", + " self.lr = lr\n", + " self.device = device\n", + " self.verbose = verbose\n", + " self.zero_pad = zero_pad\n", + " self.group = group\n", + " self.reduced_rank_reg = reduced_rank_reg\n", + " self.kernel = kernel\n", + " self.wasserstein_compare = wasserstein_compare\n", + "\n", + " if kernel is None:\n", + " #get a list of all DMDs here\n", + " self.dmds = [[DMD(Xi,\n", + " self.n_delays[i][j],\n", + " delay_interval=self.delay_interval[i][j],\n", + " rank=self.rank[i][j],\n", + " rank_thresh=self.rank_thresh[i][j],\n", + " rank_explained_variance=self.rank_explained_variance[i][j],\n", + " reduced_rank_reg=self.reduced_rank_reg,\n", + " lamb=self.lamb[i][j],\n", + " device=self.device,\n", + " verbose=self.verbose,\n", + " send_to_cpu=self.send_to_cpu) for j,Xi in enumerate(dat)] for i,dat in enumerate(self.data)]\n", + " else:\n", + " #get a list of all DMDs here\n", + " self.dmds = [[KernelDMD(Xi,\n", + " self.n_delays[i][j],\n", + " kernel=self.kernel,\n", + " num_centers=num_centers,\n", + " delay_interval=self.delay_interval[i][j],\n", + " rank=self.rank[i][j],\n", + " reduced_rank_reg=self.reduced_rank_reg,\n", + " lamb=self.lamb[i][j],\n", + " verbose=self.verbose,\n", + " svd_solver=svd_solver,\n", + " ) for j,Xi in enumerate(dat)] for i,dat in enumerate(self.data)]\n", + "\n", + " self.simdist = SimilarityTransformDist(iters,score_method,lr,device,verbose,group,wasserstein_compare)\n", + "\n", + " def check_method(self):\n", + " '''\n", + " helper function to identify what type of dsa we're running\n", + " '''\n", + " tensor_or_np = lambda x: isinstance(x,(np.ndarray,torch.Tensor))\n", + "\n", + " if isinstance(self.X,list):\n", + " if self.Y is None:\n", + " self.method = 'self-pairwise'\n", + " elif isinstance(self.Y,list):\n", + " self.method = 'bipartite-pairwise'\n", + " elif tensor_or_np(self.Y):\n", + " self.method = 'list-to-one'\n", + " self.Y = [self.Y] #wrap in a list for iteration\n", + " else:\n", + " raise ValueError('unknown type of Y')\n", + " elif tensor_or_np(self.X):\n", + " self.X = [self.X]\n", + " if self.Y is None:\n", + " raise ValueError('only one element provided')\n", + " elif isinstance(self.Y,list):\n", + " self.method = 'one-to-list'\n", + " elif tensor_or_np(self.Y):\n", + " self.method = 'default'\n", + " self.Y = [self.Y]\n", + " else:\n", + " raise ValueError('unknown type of Y')\n", + " else:\n", + " raise ValueError('unknown type of X')\n", + "\n", + " def broadcast_params(self,param,cast=None):\n", + " '''\n", + " aligns the dimensionality of the parameters with the data so it's one-to-one\n", + " '''\n", + " out = []\n", + " if isinstance(param,(int,float,np.integer)) or param is None: #self.X has already been mapped to [self.X]\n", + " out.append([param] * len(self.X))\n", + " if self.Y is not None:\n", + " out.append([param] * len(self.Y))\n", + " elif isinstance(param,(tuple,list,np.ndarray)):\n", + " if self.method == 'self-pairwise' and len(param) >= len(self.X):\n", + " out = [param]\n", + " else:\n", + " assert len(param) <= 2 #only 2 elements max\n", + "\n", + " #if the inner terms are singly valued, we broadcast, otherwise needs to be the same dimensions\n", + " for i,data in enumerate([self.X,self.Y]):\n", + " if data is None:\n", + " continue\n", + " if isinstance(param[i],(int,float)):\n", + " out.append([param[i]] * len(data))\n", + " elif isinstance(param[i],(list,np.ndarray,tuple)):\n", + " assert len(param[i]) >= len(data)\n", + " out.append(param[i][:len(data)])\n", + " else:\n", + " raise ValueError(\"unknown type entered for parameter\")\n", + "\n", + " if cast is not None and param is not None:\n", + " out = [[cast(x) for x in dat] for dat in out]\n", + "\n", + " return out\n", + "\n", + " def fit_dmds(self,\n", + " X=None,\n", + " Y=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " rank=None,\n", + " rank_thresh = None,\n", + " rank_explained_variance=None,\n", + " reduced_rank_reg=None,\n", + " lamb = None,\n", + " device='cpu',\n", + " verbose=False,\n", + " send_to_cpu=True\n", + " ):\n", + " \"\"\"\n", + " Recomputes only the DMDs with a single set of hyperparameters. This will not compare, that will need to be done with the full procedure\n", + " \"\"\"\n", + " X = self.X if X is None else X\n", + " Y = self.Y if Y is None else Y\n", + " n_delays = self.n_delays if n_delays is None else n_delays\n", + " delay_interval = self.delay_interval if delay_interval is None else delay_interval\n", + " rank = self.rank if rank is None else rank\n", + " lamb = self.lamb if lamb is None else lamb\n", + " data = []\n", + " if isinstance(X,list):\n", + " data.append(X)\n", + " else:\n", + " data.append([X])\n", + " if Y is not None:\n", + " if isinstance(Y,list):\n", + " data.append(Y)\n", + " else:\n", + " data.append([Y])\n", + "\n", + " dmds = [[DMD(Xi,n_delays,delay_interval,\n", + " rank,rank_thresh,rank_explained_variance,reduced_rank_reg,\n", + " lamb,device,verbose,send_to_cpu) for Xi in dat] for dat in data]\n", + "\n", + " for dmd_sets in dmds:\n", + " for dmd in dmd_sets:\n", + " dmd.fit()\n", + "\n", + " return dmds\n", + "\n", + " def fit_score(self):\n", + " \"\"\"\n", + " Standard fitting function for both DMDs and PoVF\n", + "\n", + " Parameters\n", + " __________\n", + "\n", + " Returns\n", + " _______\n", + "\n", + " sims : np.array\n", + " data matrix of the similarity scores between the specific sets of data\n", + " \"\"\"\n", + " for dmd_sets in self.dmds:\n", + " for dmd in dmd_sets:\n", + " dmd.fit()\n", + "\n", + " return self.score()\n", + "\n", + " def score(self,iters=None,lr=None,score_method=None):\n", + " \"\"\"\n", + " Rescore DSA with precomputed dmds if you want to try again\n", + "\n", + " Parameters\n", + " __________\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " score_method : None or {'angular','euclidean'}\n", + " overwrites the score method in the object for this application\n", + "\n", + " Returns\n", + " ________\n", + " score : float\n", + " similarity score of the two precomputed DMDs\n", + " \"\"\"\n", + "\n", + " iters = self.iters if iters is None else iters\n", + " lr = self.lr if lr is None else lr\n", + " score_method = self.score_method if score_method is None else score_method\n", + "\n", + " ind2 = 1 - int(self.method == 'self-pairwise')\n", + " # 0 if self.pairwise (want to compare the set to itself)\n", + "\n", + " self.sims = np.zeros((len(self.dmds[0]),len(self.dmds[ind2])))\n", + " for i,dmd1 in enumerate(self.dmds[0]):\n", + " for j,dmd2 in enumerate(self.dmds[ind2]):\n", + " if self.method == 'self-pairwise':\n", + " if j >= i:\n", + " continue\n", + " if self.verbose:\n", + " print(f'computing similarity between DMDs {i} and {j}')\n", + "\n", + " self.sims[i,j] = self.simdist.fit_score(dmd1.A_v,dmd2.A_v,iters,lr,score_method,zero_pad=self.zero_pad)\n", + "\n", + " if self.method == 'self-pairwise':\n", + " self.sims[j,i] = self.sims[i,j]\n", + "\n", + "\n", + " if self.method == 'default':\n", + " return self.sims[0,0]\n", + "\n", + " return self.sims" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eced3162", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Helper functions (Bonus Section)\n", + "\n", + "import contextlib\n", + "import io\n", + "import argparse\n", + "# Standard library imports\n", + "from collections import OrderedDict\n", + "import logging\n", + "\n", + "# External libraries: General utilities\n", + "import argparse\n", + "import numpy as np\n", + "\n", + "# PyTorch related imports\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torch.optim.lr_scheduler import StepLR\n", + "from torchvision import datasets, transforms\n", + "from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names\n", + "from torchvision.utils import make_grid\n", + "\n", + "# Matplotlib for plotting\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "# SciPy for statistical functions\n", + "from scipy import stats\n", + "\n", + "# Scikit-Learn for machine learning utilities\n", + "from sklearn.decomposition import PCA\n", + "from sklearn import manifold\n", + "\n", + "# RSA toolbox specific imports\n", + "import rsatoolbox\n", + "from rsatoolbox.data import Dataset\n", + "from rsatoolbox.rdm.calc import calc_rdm\n", + "\n", + "class Net(nn.Module):\n", + " \"\"\"\n", + " A neural network model for image classification, consisting of two convolutional layers,\n", + " followed by two fully connected layers with dropout regularization.\n", + "\n", + " Methods:\n", + " - forward(input): Defines the forward pass of the network.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + " \"\"\"\n", + " Initializes the network layers.\n", + "\n", + " Layers:\n", + " - conv1: First convolutional layer with 1 input channel, 32 output channels, and a 3x3 kernel.\n", + " - conv2: Second convolutional layer with 32 input channels, 64 output channels, and a 3x3 kernel.\n", + " - dropout1: Dropout layer with a dropout probability of 0.25.\n", + " - dropout2: Dropout layer with a dropout probability of 0.5.\n", + " - fc1: First fully connected layer with 9216 input features and 128 output features.\n", + " - fc2: Second fully connected layer with 128 input features and 10 output features.\n", + " \"\"\"\n", + " super(Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", + " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", + " self.dropout1 = nn.Dropout(0.25)\n", + " self.dropout2 = nn.Dropout(0.5)\n", + " self.fc1 = nn.Linear(9216, 128)\n", + " self.fc2 = nn.Linear(128, 10)\n", + "\n", + " def forward(self, input):\n", + " \"\"\"\n", + " Defines the forward pass of the network.\n", + "\n", + " Inputs:\n", + " - input (torch.Tensor): Input tensor of shape (batch_size, 1, height, width).\n", + "\n", + " Outputs:\n", + " - output (torch.Tensor): Output tensor of shape (batch_size, 10) representing the class probabilities for each input sample.\n", + " \"\"\"\n", + " x = self.conv1(input)\n", + " x = F.relu(x)\n", + " x = self.conv2(x)\n", + " x = F.relu(x)\n", + " x = F.max_pool2d(x, 2)\n", + " x = self.dropout1(x)\n", + " x = torch.flatten(x, 1)\n", + " x = self.fc1(x)\n", + " x = F.relu(x)\n", + " x = self.dropout2(x)\n", + " x = self.fc2(x)\n", + " output = F.softmax(x, dim=1)\n", + " return output\n", + "\n", + "class recurrent_Net(nn.Module):\n", + " \"\"\"\n", + " A recurrent neural network model for image classification, consisting of two convolutional layers\n", + " with recurrent connections and a readout layer.\n", + "\n", + " Methods:\n", + " - __init__(time_steps=5): Initializes the network layers and sets the number of time steps for recurrence.\n", + " - forward(input): Defines the forward pass of the network.\n", + " \"\"\"\n", + "\n", + " def __init__(self, time_steps=5):\n", + " \"\"\"\n", + " Initializes the network layers and sets the number of time steps for recurrence.\n", + "\n", + " Layers:\n", + " - conv1: First convolutional layer with 1 input channel, 16 output channels, and a 3x3 kernel with a stride of 3.\n", + " - conv2: Second convolutional layer with 16 input channels, 16 output channels, and a 3x3 kernel with padding of 1.\n", + " - readout: A sequential layer containing:\n", + " - dropout: Dropout layer with a dropout probability of 0.25.\n", + " - avgpool: Adaptive average pooling layer to reduce spatial dimensions to 1x1.\n", + " - flatten: Flatten layer to convert the 2D pooled output to 1D.\n", + " - linear: Fully connected layer with 16 input features and 10 output features.\n", + " - time_steps (int): Number of time steps for the recurrent connection.\n", + " \"\"\"\n", + " super(recurrent_Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 16, 3, 3)\n", + " self.conv2 = nn.Conv2d(16, 16, 3, 1, padding=1)\n", + " self.readout = nn.Sequential(OrderedDict([\n", + " ('dropout', nn.Dropout(0.25)),\n", + " ('avgpool', nn.AdaptiveAvgPool2d(1)),\n", + " ('flatten', nn.Flatten()),\n", + " ('linear', nn.Linear(16, 10))\n", + " ]))\n", + " self.time_steps = time_steps\n", + "\n", + " def forward(self, input):\n", + " \"\"\"\n", + " Defines the forward pass of the network.\n", + "\n", + " Inputs:\n", + " - input (torch.Tensor): Input tensor of shape (batch_size, 1, height, width).\n", + "\n", + " Outputs:\n", + " - output (torch.Tensor): Output tensor of shape (batch_size, 10) representing the class probabilities for each input sample.\n", + " \"\"\"\n", + " input = self.conv1(input)\n", + " x = input\n", + " for t in range(0, self.time_steps):\n", + " x = input + self.conv2(x)\n", + " x = F.relu(x)\n", + "\n", + " x = self.readout(x)\n", + " output = F.softmax(x, dim=1)\n", + " return output\n", + "\n", + "\n", + "def train_one_epoch(args, model, device, train_loader, optimizer, epoch):\n", + " \"\"\"\n", + " Trains the model for one epoch.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Arguments for training configuration.\n", + " - model (torch.nn.Module): The model to be trained.\n", + " - device (torch.device): The device to use for training (CPU/GPU).\n", + " - train_loader (torch.utils.data.DataLoader): DataLoader for the training data.\n", + " - optimizer (torch.optim.Optimizer): Optimizer for updating the model parameters.\n", + " - epoch (int): The current epoch number.\n", + " \"\"\"\n", + " model.train()\n", + " for batch_idx, (data, target) in enumerate(train_loader):\n", + " data, target = data.to(device), target.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " output = torch.log(output) # to make it a log_softmax\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " if batch_idx % args.log_interval == 0:\n", + " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " epoch, batch_idx * len(data), len(train_loader.dataset),\n", + " 100. * batch_idx / len(train_loader), loss.item()))\n", + " if args.dry_run:\n", + " break\n", + "\n", + "def test(model, device, test_loader, return_features=False):\n", + " \"\"\"\n", + " Evaluates the model on the test dataset.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to be evaluated.\n", + " - device (torch.device): The device to use for evaluation (CPU/GPU).\n", + " - test_loader (torch.utils.data.DataLoader): DataLoader for the test data.\n", + " - return_features (bool): If True, returns the features from the model. Default is False.\n", + " \"\"\"\n", + " model.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " output = model(data)\n", + " output = torch.log(output)\n", + " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", + " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + " test_loss /= len(test_loader.dataset)\n", + "\n", + " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, correct, len(test_loader.dataset),\n", + " 100. * correct / len(test_loader.dataset)))\n", + "\n", + "def build_args():\n", + " \"\"\"\n", + " Builds and parses command-line arguments for training.\n", + " \"\"\"\n", + " parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", + " parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n", + " help='input batch size for training (default: 64)')\n", + " parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", + " help='input batch size for testing (default: 1000)')\n", + " parser.add_argument('--epochs', type=int, default=2, metavar='N',\n", + " help='number of epochs to train (default: 14)')\n", + " parser.add_argument('--lr', type=float, default=1.0, metavar='LR',\n", + " help='learning rate (default: 1.0)')\n", + " parser.add_argument('--gamma', type=float, default=0.7, metavar='M',\n", + " help='Learning rate step gamma (default: 0.7)')\n", + " parser.add_argument('--no-cuda', action='store_true', default=False,\n", + " help='disables CUDA training')\n", + " parser.add_argument('--no-mps', action='store_true', default=False,\n", + " help='disables macOS GPU training')\n", + " parser.add_argument('--dry-run', action='store_true', default=False,\n", + " help='quickly check a single pass')\n", + " parser.add_argument('--seed', type=int, default=1, metavar='S',\n", + " help='random seed (default: 1)')\n", + " parser.add_argument('--log-interval', type=int, default=50, metavar='N',\n", + " help='how many batches to wait before logging training status')\n", + " parser.add_argument('--save-model', action='store_true', default=False,\n", + " help='For Saving the current Model')\n", + " args = parser.parse_args('')\n", + "\n", + " use_cuda = torch.cuda.is_available() #not args.no_cuda and\n", + "\n", + " if use_cuda:\n", + " device = torch.device(\"cuda\")\n", + " else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + " args.use_cuda = use_cuda\n", + " args.device = device\n", + " return args\n", + "\n", + "def fetch_dataloaders(args):\n", + " \"\"\"\n", + " Fetches the data loaders for training and testing datasets.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Parsed arguments with training configuration.\n", + "\n", + " Outputs:\n", + " - train_loader (torch.utils.data.DataLoader): DataLoader for the training data.\n", + " - test_loader (torch.utils.data.DataLoader): DataLoader for the test data.\n", + " \"\"\"\n", + " train_kwargs = {'batch_size': args.batch_size}\n", + " test_kwargs = {'batch_size': args.test_batch_size}\n", + " if args.use_cuda:\n", + " cuda_kwargs = {'num_workers': 1,\n", + " 'pin_memory': True,\n", + " 'shuffle': True}\n", + " train_kwargs.update(cuda_kwargs)\n", + " test_kwargs.update(cuda_kwargs)\n", + "\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + " with contextlib.redirect_stdout(io.StringIO()): #to suppress output\n", + " dataset1 = datasets.MNIST('../data', train=True, download=True,\n", + " transform=transform)\n", + " dataset2 = datasets.MNIST('../data', train=False,\n", + " transform=transform)\n", + " train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)\n", + " test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n", + " return train_loader, test_loader\n", + "\n", + "def train_model(args, model, optimizer):\n", + " \"\"\"\n", + " Trains the model using the specified arguments and optimizer.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Parsed arguments with training configuration.\n", + " - model (torch.nn.Module): The model to be trained.\n", + " - optimizer (torch.optim.Optimizer): Optimizer for updating the model parameters.\n", + "\n", + " Outputs:\n", + " - None: The function trains the model and optionally saves it.\n", + " \"\"\"\n", + " scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n", + " for epoch in range(1, args.epochs + 1):\n", + " train_one_epoch(args, model, args.device, train_loader, optimizer, epoch)\n", + " test(model, args.device, test_loader)\n", + " scheduler.step()\n", + "\n", + " if args.save_model:\n", + " torch.save(model.state_dict(), \"mnist_cnn.pt\")\n", + "\n", + "\n", + "def calc_rdms(model_features, method='correlation'):\n", + " \"\"\"\n", + " Calculates representational dissimilarity matrices (RDMs) for model features.\n", + "\n", + " Inputs:\n", + " - model_features (dict): A dictionary where keys are layer names and values are features of the layers.\n", + " - method (str): The method to calculate RDMs, e.g., 'correlation'. Default is 'correlation'.\n", + "\n", + " Outputs:\n", + " - rdms (pyrsa.rdm.RDMs): RDMs object containing dissimilarity matrices.\n", + " - rdms_dict (dict): A dictionary with layer names as keys and their corresponding RDMs as values.\n", + " \"\"\"\n", + " ds_list = []\n", + " for l in range(len(model_features)):\n", + " layer = list(model_features.keys())[l]\n", + " feats = model_features[layer]\n", + "\n", + " if type(feats) is list:\n", + " feats = feats[-1]\n", + "\n", + " if args.use_cuda:\n", + " feats = feats.cpu()\n", + "\n", + " if len(feats.shape) > 2:\n", + " feats = feats.flatten(1)\n", + "\n", + " feats = feats.detach().numpy()\n", + " ds = Dataset(feats, descriptors=dict(layer=layer))\n", + " ds_list.append(ds)\n", + "\n", + " rdms = calc_rdm(ds_list, method=method)\n", + " rdms_dict = {list(model_features.keys())[i]: rdms.get_matrices()[i] for i in range(len(model_features))}\n", + "\n", + " return rdms, rdms_dict\n", + "\n", + "def fgsm_attack(image, epsilon, data_grad):\n", + " \"\"\"\n", + " Performs FGSM attack on an image.\n", + "\n", + " Inputs:\n", + " - image (torch.Tensor): Original image.\n", + " - epsilon (float): Perturbation magnitude.\n", + " - data_grad (torch.Tensor): Gradient of the data.\n", + "\n", + " Outputs:\n", + " - perturbed_image (torch.Tensor): Perturbed image after FGSM attack.\n", + " \"\"\"\n", + " sign_data_grad = data_grad.sign()\n", + " perturbed_image = image + epsilon * sign_data_grad\n", + " perturbed_image = torch.clamp(perturbed_image, 0, 1)\n", + " return perturbed_image\n", + "\n", + "def denorm(batch, mean=[0.1307], std=[0.3081]):\n", + " \"\"\"\n", + " Converts a batch of normalized tensors to their original scale.\n", + "\n", + " Inputs:\n", + " - batch (torch.Tensor): Batch of normalized tensors.\n", + " - mean (torch.Tensor or list): Mean used for normalization.\n", + " - std (torch.Tensor or list): Standard deviation used for normalization.\n", + "\n", + " Outputs:\n", + " - torch.Tensor: Batch of tensors without normalization applied to them.\n", + " \"\"\"\n", + " if isinstance(mean, list):\n", + " mean = torch.tensor(mean).to(batch.device)\n", + " if isinstance(std, list):\n", + " std = torch.tensor(std).to(batch.device)\n", + "\n", + " return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)\n", + "\n", + "def generate_adversarial(model, imgs, targets, epsilon):\n", + " \"\"\"\n", + " Generates adversarial examples using FGSM attack.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to attack.\n", + " - imgs (torch.Tensor): Batch of images.\n", + " - targets (torch.Tensor): Batch of target labels.\n", + " - epsilon (float): Perturbation magnitude.\n", + "\n", + " Outputs:\n", + " - adv_imgs (torch.Tensor): Batch of adversarial images.\n", + " \"\"\"\n", + " adv_imgs = []\n", + "\n", + " for img, target in zip(imgs, targets):\n", + " img = img.unsqueeze(0)\n", + " target = target.unsqueeze(0)\n", + " img.requires_grad = True\n", + "\n", + " output = model(img)\n", + " output = torch.log(output)\n", + " loss = F.nll_loss(output, target)\n", + "\n", + " model.zero_grad()\n", + " loss.backward()\n", + "\n", + " data_grad = img.grad.data\n", + " data_denorm = denorm(img)\n", + " perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)\n", + " perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)\n", + "\n", + " adv_imgs.append(perturbed_data_normalized.detach())\n", + "\n", + " return torch.cat(adv_imgs)\n", + "\n", + "def test_adversarial(model, imgs, targets):\n", + " \"\"\"\n", + " Tests the model on adversarial examples and prints the accuracy.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to be tested.\n", + " - imgs (torch.Tensor): Batch of adversarial images.\n", + " - targets (torch.Tensor): Batch of target labels.\n", + " \"\"\"\n", + " correct = 0\n", + " output = model(imgs)\n", + " output = torch.log(output)\n", + " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", + " correct += pred.eq(targets.view_as(pred)).sum().item()\n", + "\n", + " final_acc = correct / float(len(imgs))\n", + " print(f\"adversarial test accuracy = {correct} / {len(imgs)} = {final_acc}\")\n", + "\n", + "def extract_features(model, imgs, return_layers, plot='none'):\n", + " \"\"\"\n", + " Extracts features from specified layers of the model.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model from which to extract features.\n", + " - imgs (torch.Tensor): Batch of input images.\n", + " - return_layers (list): List of layer names from which to extract features.\n", + " - plot (str): Option to plot the features. Default is 'none'.\n", + "\n", + " Outputs:\n", + " - model_features (dict): A dictionary with layer names as keys and extracted features as values.\n", + " \"\"\"\n", + " if return_layers == 'all':\n", + " return_layers, _ = get_graph_node_names(model)\n", + " elif return_layers == 'layers':\n", + " layers, _ = get_graph_node_names(model)\n", + " return_layers = [l for l in layers if 'input' in l or 'conv' in l or 'fc' in l]\n", + "\n", + " feature_extractor = create_feature_extractor(model, return_nodes=return_layers)\n", + " model_features = feature_extractor(imgs)\n", + "\n", + " return model_features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be4a4946", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Plotting functions (Bonus)\n", + "\n", + "def sample_images(data_loader, n=5, plot=False):\n", + " \"\"\"\n", + " Samples a specified number of images from a data loader.\n", + "\n", + " Inputs:\n", + " - data_loader (torch.utils.data.DataLoader): Data loader containing images and labels.\n", + " - n (int): Number of images to sample per class.\n", + " - plot (bool): Whether to plot the sampled images using matplotlib.\n", + "\n", + " Outputs:\n", + " - imgs (torch.Tensor): Sampled images.\n", + " - labels (torch.Tensor): Corresponding labels for the sampled images.\n", + " \"\"\"\n", + "\n", + " with plt.xkcd():\n", + " imgs, targets = next(iter(data_loader))\n", + "\n", + " imgs_o = []\n", + " labels = []\n", + " for value in range(10):\n", + " cat_imgs = imgs[np.where(targets == value)][0:n]\n", + " imgs_o.append(cat_imgs)\n", + " labels.append([value]*len(cat_imgs))\n", + "\n", + " imgs = torch.cat(imgs_o, dim=0)\n", + " labels = torch.tensor(labels).flatten()\n", + "\n", + " if plot:\n", + " plt.imshow(torch.moveaxis(make_grid(imgs, nrow=5, padding=0, normalize=False, pad_value=0), 0,-1))\n", + " plt.axis('off')\n", + "\n", + " return imgs, labels\n", + "\n", + "\n", + "def plot_rdms(model_rdms):\n", + " \"\"\"\n", + " Plots the Representational Dissimilarity Matrices (RDMs) for each layer of a model.\n", + "\n", + " Inputs:\n", + " - model_rdms (dict): A dictionary where keys are layer names and values are the corresponding RDMs.\n", + " \"\"\"\n", + "\n", + " with plt.xkcd():\n", + " fig = plt.figure(figsize=(8, 4))\n", + " gs = fig.add_gridspec(1, len(model_rdms))\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " for l in range(len(model_rdms)):\n", + "\n", + " layer = list(model_rdms.keys())[l]\n", + " rdm = np.squeeze(model_rdms[layer])\n", + "\n", + " if len(rdm.shape) < 2:\n", + " rdm = rdm.reshape( (int(np.sqrt(rdm.shape[0])), int(np.sqrt(rdm.shape[0]))) )\n", + "\n", + " rdm = rdm / np.max(rdm)\n", + "\n", + " ax = plt.subplot(gs[0,l])\n", + " ax_ = ax.imshow(rdm, cmap='magma_r')\n", + " ax.set_title(f'{layer}')\n", + "\n", + " fig.subplots_adjust(right=0.9)\n", + " cbar_ax = fig.add_axes([1.01, 0.18, 0.01, 0.53])\n", + " cbar_ax.text(-2.3, 0.05, 'Normalized euclidean distance', size=10, rotation=90)\n", + " fig.colorbar(ax_, cax=cbar_ax)\n", + "\n", + " plt.show()\n", + "\n", + "def rep_path(model_features, model_colors, labels=None, rdm_calc_method='euclidean', rdm_comp_method='cosine'):\n", + " \"\"\"\n", + " Represents paths of model features in a reduced-dimensional space.\n", + "\n", + " Inputs:\n", + " - model_features (dict): Dictionary containing model features for each model.\n", + " - model_colors (dict): Dictionary mapping model names to colors for visualization.\n", + " - labels (array-like, optional): Array of labels corresponding to the model features.\n", + " - rdm_calc_method (str, optional): Method for calculating RDMS ('euclidean' or 'correlation').\n", + " - rdm_comp_method (str, optional): Method for comparing RDMS ('cosine' or 'corr').\n", + " \"\"\"\n", + " with plt.xkcd():\n", + " path_len = []\n", + " path_colors = []\n", + " rdms_list = []\n", + " ax_ticks = []\n", + " tick_colors = []\n", + " model_names = list(model_features.keys())\n", + " for m in range(len(model_names)):\n", + " model_name = model_names[m]\n", + " features = model_features[model_name]\n", + " path_colors.append(model_colors[model_name])\n", + " path_len.append(len(features))\n", + " ax_ticks.append(list(features.keys()))\n", + " tick_colors.append([model_colors[model_name]]*len(features))\n", + " rdms, _ = calc_rdms(features, method=rdm_calc_method)\n", + " rdms_list.append(rdms)\n", + "\n", + " path_len = np.insert(np.cumsum(path_len),0,0)\n", + "\n", + " if labels is not None:\n", + " rdms, _ = calc_rdms({'labels' : F.one_hot(labels).float().to(device)}, method=rdm_calc_method)\n", + " rdms_list.append(rdms)\n", + " ax_ticks.append(['labels'])\n", + " tick_colors.append(['m'])\n", + " idx_labels = -1\n", + "\n", + " rdms = rsatoolbox.rdm.concat(rdms_list)\n", + "\n", + " #Flatten the list\n", + " ax_ticks = [l for model_layers in ax_ticks for l in model_layers]\n", + " tick_colors = [l for model_layers in tick_colors for l in model_layers]\n", + " tick_colors = ['k' if tick == 'input' else color for tick, color in zip(ax_ticks, tick_colors)]\n", + "\n", + " rdms_comp = rsatoolbox.rdm.compare(rdms, rdms, method=rdm_comp_method)\n", + " if rdm_comp_method == 'cosine':\n", + " rdms_comp = np.arccos(rdms_comp)\n", + " rdms_comp = np.nan_to_num(rdms_comp, nan=0.0)\n", + "\n", + " # Symmetrize\n", + " rdms_comp = (rdms_comp + rdms_comp.T) / 2.0\n", + "\n", + " # reduce dim to 2\n", + " transformer = manifold.MDS(n_components = 2, max_iter=1000, n_init=10, normalized_stress='auto', dissimilarity=\"precomputed\")\n", + " dims= transformer.fit_transform(rdms_comp)\n", + "\n", + " # remove duplicates of the input layer from multiple models\n", + " remove_duplicates = np.where(np.array(ax_ticks) == 'input')[0][1:]\n", + " for index in remove_duplicates:\n", + " del ax_ticks[index]\n", + " del tick_colors[index]\n", + " rdms_comp = np.delete(np.delete(rdms_comp, index, axis=0), index, axis=1)\n", + "\n", + " fig = plt.figure(figsize=(8, 4))\n", + " gs = fig.add_gridspec(1, 2)\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " ax = plt.subplot(gs[0,0])\n", + " ax_ = ax.imshow(rdms_comp, cmap='viridis_r')\n", + " fig.subplots_adjust(left=0.2)\n", + " cbar_ax = fig.add_axes([-0.01, 0.2, 0.01, 0.5])\n", + " #cbar_ax.text(-7, 0.05, 'dissimilarity between rdms', size=10, rotation=90)\n", + " fig.colorbar(ax_, cax=cbar_ax,location='left')\n", + " ax.set_title('Dissimilarity between layer rdms', fontdict = {'fontsize': 14})\n", + " ax.set_xticks(np.arange(len(ax_ticks)), labels=ax_ticks, fontsize=7, rotation=83)\n", + " ax.set_yticks(np.arange(len(ax_ticks)), labels=ax_ticks, fontsize=7)\n", + " [t.set_color(i) for (i,t) in zip(tick_colors, ax.xaxis.get_ticklabels())]\n", + " [t.set_color(i) for (i,t) in zip(tick_colors, ax.yaxis.get_ticklabels())]\n", + "\n", + " ax = plt.subplot(gs[0,1])\n", + " amin, amax = dims.min(), dims.max()\n", + " amin, amax = (amin + amax) / 2 - (amax - amin) * 5/8, (amin + amax) / 2 + (amax - amin) * 5/8\n", + "\n", + " for i in range(len(rdms_list)-1):\n", + "\n", + " path_indices = np.arange(path_len[i], path_len[i+1])\n", + " ax.plot(dims[path_indices, 0], dims[path_indices, 1], color=path_colors[i], marker='.')\n", + " ax.set_title('Representational geometry path', fontdict = {'fontsize': 14})\n", + " ax.set_xlim([amin, amax])\n", + " ax.set_ylim([amin, amax])\n", + " ax.set_xlabel(f\"dim 1\")\n", + " ax.set_ylabel(f\"dim 2\")\n", + "\n", + " # if idx_input is not None:\n", + " idx_input = 0\n", + " ax.plot(dims[idx_input, 0], dims[idx_input, 1], color='k', marker='s')\n", + "\n", + " if labels is not None:\n", + " ax.plot(dims[idx_labels, 0], dims[idx_labels, 1], color='m', marker='*')\n", + "\n", + " ax.legend(model_names, fontsize=8)\n", + " fig.tight_layout()\n", + "\n", + "def plot_dim_reduction(model_features, labels, transformer_funcs):\n", + " \"\"\"\n", + " Plots the dimensionality reduction results for model features using various transformers.\n", + "\n", + " Inputs:\n", + " - model_features (dict): Dictionary containing model features for each layer.\n", + " - labels (array-like): Array of labels corresponding to the model features.\n", + " - transformer_funcs (list): List of dimensionality reduction techniques to apply ('PCA', 'MDS', 't-SNE').\n", + " \"\"\"\n", + " with plt.xkcd():\n", + "\n", + " transformers = []\n", + " for t in transformer_funcs:\n", + " if t == 'PCA': transformers.append(PCA(n_components=2))\n", + " if t == 'MDS': transformers.append(manifold.MDS(n_components = 2, normalized_stress='auto'))\n", + " if t == 't-SNE': transformers.append(manifold.TSNE(n_components = 2, perplexity=40, verbose=0))\n", + "\n", + " fig = plt.figure(figsize=(8, 2.5*len(transformers)))\n", + " # and we add one plot per reference point\n", + " gs = fig.add_gridspec(len(transformers), len(model_features))\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " return_layers = list(model_features.keys())\n", + "\n", + " for f in range(len(transformer_funcs)):\n", + "\n", + " for l in range(len(return_layers)):\n", + " layer = return_layers[l]\n", + " feats = model_features[layer].detach().cpu().flatten(1)\n", + " feats_transformed= transformers[f].fit_transform(feats)\n", + "\n", + " amin, amax = feats_transformed.min(), feats_transformed.max()\n", + " amin, amax = (amin + amax) / 2 - (amax - amin) * 5/8, (amin + amax) / 2 + (amax - amin) * 5/8\n", + " ax = plt.subplot(gs[f,l])\n", + " ax.set_xlim([amin, amax])\n", + " ax.set_ylim([amin, amax])\n", + " ax.axis(\"off\")\n", + " #ax.set_title(f'{layer}')\n", + " if f == 0: ax.text(0.5, 1.12, f'{layer}', size=16, ha=\"center\", transform=ax.transAxes)\n", + " if l == 0: ax.text(-0.3, 0.5, transformer_funcs[f], size=16, ha=\"center\", transform=ax.transAxes)\n", + " # Create a discrete color map based on unique labels\n", + " num_colors = len(np.unique(labels))\n", + " cmap = plt.get_cmap('viridis_r', num_colors) # 10 discrete colors\n", + " norm = mpl.colors.BoundaryNorm(np.arange(-0.5,num_colors), cmap.N)\n", + " ax_ = ax.scatter(feats_transformed[:, 0], feats_transformed[:, 1], c=labels, cmap=cmap, norm=norm)\n", + "\n", + " fig.subplots_adjust(right=0.9)\n", + " cbar_ax = fig.add_axes([1.01, 0.18, 0.01, 0.53])\n", + " fig.colorbar(ax_, cax=cbar_ax, ticks=np.linspace(0,9,10))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21f68945", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Data retrieval\n", + "\n", + "import os\n", + "import requests\n", + "import hashlib\n", + "\n", + "# Variables for file and download URL\n", + "fnames = [\"standard_model.pth\", \"adversarial_model.pth\", \"recurrent_model.pth\"] # The names of the files to be downloaded\n", + "urls = [\"https://osf.io/s5rt6/download\", \"https://osf.io/qv5eb/download\", \"https://osf.io/6hnwk/download\"] # URLs from where the files will be downloaded\n", + "expected_md5s = [\"2e63c2cd77bc9f1fa67673d956ec910d\", \"25fb34497377921b54368317f68a7aa7\", \"ee5cea3baa264cb78300102fa6ed66e8\"] # MD5 hashes for verifying files integrity\n", + "\n", + "for fname, url, expected_md5 in zip(fnames, urls, expected_md5s):\n", + " if not os.path.isfile(fname):\n", + " try:\n", + " # Attempt to download the file\n", + " r = requests.get(url) # Make a GET request to the specified URL\n", + " except requests.ConnectionError:\n", + " # Handle connection errors during the download\n", + " print(\"!!! Failed to download data !!!\")\n", + " else:\n", + " # No connection errors, proceed to check the response\n", + " if r.status_code != requests.codes.ok:\n", + " # Check if the HTTP response status code indicates a successful download\n", + " print(\"!!! Failed to download data !!!\")\n", + " elif hashlib.md5(r.content).hexdigest() != expected_md5:\n", + " # Verify the integrity of the downloaded file using MD5 checksum\n", + " print(\"!!! Data download appears corrupted !!!\")\n", + " else:\n", + " # If download is successful and data is not corrupted, save the file\n", + " with open(fname, \"wb\") as fid:\n", + " fid.write(r.content) # Write the downloaded content to a file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93aeca0a", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Figure settings\n", + "\n", + "logging.getLogger('matplotlib.font_manager').disabled = True\n", + "\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'retina' # perfrom high definition rendering for images and plots\n", + "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd8052d5", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Set device (GPU or CPU)\n", + "\n", + "# inform the user if the notebook uses GPU or CPU.\n", + "\n", + "def set_device():\n", + " \"\"\"\n", + " Determines and sets the computational device for PyTorch operations based on the availability of a CUDA-capable GPU.\n", + "\n", + " Outputs:\n", + " - device (str): The device that PyTorch will use for computations ('cuda' or 'cpu'). This string can be directly used\n", + " in PyTorch operations to specify the device.\n", + " \"\"\"\n", + "\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + " if device != \"cuda\":\n", + " print(\"GPU is not enabled in this notebook. \\n\"\n", + " \"If you want to enable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `GPU` from the dropdown menu\")\n", + " else:\n", + " print(\"GPU is enabled in this notebook. \\n\"\n", + " \"If you want to disable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `None` from the dropdown menu\")\n", + "\n", + " return device\n", + "\n", + "device = set_device()" ] }, { @@ -75,84 +2155,543 @@ "execution_count": null, "id": "c28a92e7-e76c-48de-b574-15a1272717cf", "metadata": { - "cellView": "form", + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Load Slides\n", + "\n", + "from IPython.display import IFrame\n", + "from ipywidgets import widgets\n", + "out = widgets.Output()\n", + "\n", + "link_id = \"8fx23\"\n", + "\n", + "with out:\n", + " print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", + " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", + "display(out)" + ] + }, + { + "cell_type": "markdown", + "id": "407ace26", + "metadata": { + "execution": {} + }, + "source": [ + "---\n", + "\n", + "# Intro\n", + "\n", + "Welcome to Tutorial 5 of Day 3 (W1D3) of the NeuroAI course. In this tutorial we are going to look at an exciting method that measures similarity from a slightly different perspective, a temporal one. The prior methods we have looked at were centeed around geometry and spatial representations, where we looked at metrics such as the Euclidean and Mahalanobis distance metrics. However, one thing we often want to study in neuroscience and in AI separately - is the temporal domain. Even more so in our own field of NeuroAI, we often deal with time series of neuronal / biological recordings. One thing you should already have a broad level of awareness of is that end structures can end up looking the same even though the paths taken to arrive at those end structures were very different.\n", + "\n", + "In NeuroAI, we're often confronted with systems that seem to have some sort of overlap and we want to study whether this implies there is a shared computation pairs up with the shared task (we looked at this in detail yesterday in our *Comparing Tasks* day). Today, we will begin by watching a short intro video by Mitchell Ostrow, who will describe his method to compare representations over temporal sequences (the method is called Dynamic Similarity Analysis). Then we are going to introduce three simple dynamical systems and we will explore them from the perspective of Dynamic Similarity Analysis and also describe the conceptual relationship to Representational Similarity Analysis. You will have a short coding exercise on the topic of temporal similarity analysis on three different types of trajectories. \n", + "\n", + "At the end of the tutorial, we will finally look at a further aspect of temporal sequences using RNNs. This is an adaptation of the ideas introduced in Tutorial 2 but now based around recurrent representations from RNNs. We hope you enjoy this tutorial today and that it gets you thinking not just what similarity values mean, but which ones are appropriate (here, from a spatial or temporal perspective). We aim to continually expand the tools necessary in the NeuroAI researcher's toolkit. Complementary tools, when applicable, can often tell a far richer story than just using a single method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5d6178f-ddf5-41ae-b676-15e452dc8b78", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Video 1: Dynamical Similarity Analysis\n", + "\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "video_ids = [('Youtube', 'FHikIsQFQvM'), ('Bilibili', 'BV1qm421g7hV')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2ce83bc-7e86-44d3-a40a-4ad46fd5a6df", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_DSA_video\")" + ] + }, + { + "cell_type": "markdown", + "id": "937041e9", + "metadata": { + "execution": {} + }, + "source": [ + "## Section 1: Visualization of Three Temporal Sequences\n", + "\n", + "We are going to be working with the analysis of three temporal sequences today:\n", + "\n", + "* The circular time series (`Circle`)\n", + "* The oval time series (`Oval`)\n", + "* The random walk (`R-Walk`)\n", + "\n", + "The random walk is going to be broadly *oval shaped*. Now, what do you think, from a geometric perspective, might result from a spatial analysis of these three different *representations*? You will probably assume because the random walk has an oval shape and there is also an oval time series (that's not a random walk) that these would result in a higher spatial similarity. You'd be right to assume this. However, what we're going to do with the `Circle` and `Oval` time series is to include an oscillator at a specific frequency, shared amongst these two time series. In effect, this means that although when plotted in totality the shapes are different, during the dynamic (temporal) evolution of these time series, a very similar shared pattern is emerging. We want methods that are sensitive to these changes to give higher scores for time series sharing similar temporal patterns (e.g. both containing oscillating patterns at similar frequences) rather than just a random walk that resembles (geometrically) one of the other shapes (`R-Walk`). Before we continue, we'll just define this random walk in a little more detail. A random walk at a specific location / timepoint takes a random step of fixed length in a specific direction, but this can be broadly controlled to resemble geometric shapes. We've taken a random walk and then reframed it to be similar in shape to `Oval`. \n", + "\n", + "Let's now visualize these three temporal sequences, to make the previous paragraph a little clearer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b57dfe1a", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# Circle\n", + "r = .1; # rotation\n", + "A = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])\n", + "B = np.array([[1, 0], [0, 1]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_circle = trajectory\n", + "\n", + "# Oval\n", + "r = .1; # rotation\n", + "s = 4; # scaling\n", + "S = np.array([[1, 0], [0, s]])\n", + "Si = np.array([[1, 0], [0, 1/s]])\n", + "V = np.array([[1, 1], [-1, 1]])/np.sqrt(2)\n", + "Vi = np.array([[1, -1], [1, 1]])/np.sqrt(2)\n", + "R = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])\n", + "A = np.linalg.multi_dot([V,Si,R,S,Vi])\n", + "B = np.array([[1, 0], [0, 1]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_oval = trajectory\n", + "\n", + "# R-Walk (random walk)\n", + "r = .1; # rotation\n", + "A = np.array([[.9, 0], [0, .9]])\n", + "c = -.95; # correlation coefficient\n", + "B = np.array([[1, c], [0, np.sqrt(1-c*c)]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_walk = trajectory" + ] + }, + { + "cell_type": "markdown", + "id": "113a0dee", + "metadata": { + "execution": {} + }, + "source": [ + "Can you see how the spatial / geometric similarity of `R-Walk` and `Oval` are more similar, but the oscillations during the temporal sequence are shared between `Circle` and `Oval`? Let's run Dynamic Similarity Analysis on these temporal sequences and see what scores are returned.\n", + "\n", + "We calcularted `trajectory_oval` and `trajectory_circle` above, so let's plug these into the `DSA` function imported earlier (in the helper function cell) and see what the similarity score is." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3e36d59", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# Define the DSA computation class\n", + "dsa = DSA(X=trajectory_oval, Y=trajectory_circle, n_delays=1)\n", + "\n", + "# Call the fit method and save the result\n", + "similarities_oval_circle = dsa.fit_score()\n", + "\n", + "print(f\"DSA similarity between Oval and Circle: {similarities_oval_circle:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9f1fb622", + "metadata": { + "execution": {} + }, + "source": [ + "## Multi-way Comparison\n", + "\n", + "We're now going to run DSA on our three trajectories and fit the model, returning the scores which we can investigate by plotting a confusion matrix with a heatmap to show the DSA scores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ee9e8e8", + "metadata": { "execution": {} }, "outputs": [], "source": [ - "# @title Bonus material slides\n", + "n_delays = 1\n", + "delay_interval = 1\n", "\n", - "from IPython.display import IFrame\n", - "from ipywidgets import widgets\n", - "out = widgets.Output()\n", + "models = [trajectory_circle, trajectory_oval, trajectory_walk]\n", + "dsa = DSA(models, n_delays=n_delays, delay_interval=delay_interval)\n", + "similarities = dsa.fit_score()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18318ddb", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "labels = ['Circle', 'Oval', 'Walk']\n", + "data = np.random.rand(len(labels), len(labels))\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");" + ] + }, + { + "cell_type": "markdown", + "id": "ffd49b4b", + "metadata": { + "execution": {} + }, + "source": [ + "This heatmap across the three model comparisons shows that the DSA scores between (`Walk` and `Circle`) and (`Walk` and `Oval`) to be (relatively) high, while the comparison between (`Circle` and `Oval`) is very low. Please note that this confusion matrix is symmetrical, meaning that the analysis between `trajectory_A` and `trajectory_B` returns the same dynamic similarity score as `trajectory_B` and `trajectory_A`. This is a common feature we have also seen in comparison metrics in standard RSA. One thing to note in the calculation of DSA is that comparisons among identical trajectories is `0`. This is unlike in RSA where we expect the correlation among the same stimuli to be `1.0`. This is why we see black squares along the diagonal.\n", "\n", - "link_id = \"8fx23\"\n", + "Let's put our thinking caps on for a moment: This isn't really the result we would have expected, right? What do you think might be going on here? Have a look back at the *hyperparameters* and try to make an educated guess!" + ] + }, + { + "cell_type": "markdown", + "id": "d0ff5faa", + "metadata": { + "execution": {} + }, + "source": [ + "## DSA Hyperparameters (`n_delays` and `delay_interval`)\n", "\n", - "with out:\n", - " print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", - " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", - "display(out)" + "We'll now give you a hint as to why the setting of these hyperparameters is important when considering dynamic similarity analysis. The oscillators we have placed in the trajectories of `Circle` and `Oval` are not immediately apparent if you study only the previous time step for each element. It's only when considering the recurring pattern across a few different temporal delays and at what delay interval you want those to be, that we would expect to be able to detect recurring oscillations that provide us with the information we need to conclude that `Oval` and `Circle` are actually *dynamically* similar.\n", + "\n", + "You should change the values below to be more sensible hyperparameter settings and re-run the model and plot the new confusion matrix. Try using `n_delays` equal to `20` and `delay_interval` equal to `10`. Don't forget to define `models` (see above example if you get stuck)." + ] + }, + { + "cell_type": "markdown", + "id": "9d8d4c03", + "metadata": { + "colab_type": "text", + "execution": {} + }, + "source": [ + "```python\n", + "#################################################\n", + "## TODO for students: fill in the missing parts ##\n", + "raise NotImplementedError(\"Student exercise\")\n", + "#################################################\n", + "\n", + "n_delays = ...\n", + "delay_interval = ...\n", + "\n", + "models = ...\n", + "dsa = DSA(...)\n", + "similarities = ...\n", + "\n", + "labels = ['Circle', 'Oval', 'Walk']\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");\n", + "\n", + "```" ] }, { "cell_type": "code", "execution_count": null, - "id": "b5d6178f-ddf5-41ae-b676-15e452dc8b78", + "id": "a6377c65", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# to_remove solution\n", + "\n", + "n_delays = 20\n", + "delay_interval = 10\n", + "\n", + "models = [trajectory_circle, trajectory_oval, trajectory_walk]\n", + "dsa = DSA(models, n_delays=n_delays, delay_interval=delay_interval)\n", + "similarities = dsa.fit_score()\n", + "\n", + "labels = ['Circle', 'Oval', 'Walk']\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");" + ] + }, + { + "cell_type": "markdown", + "id": "04b0e32f", + "metadata": { + "execution": {} + }, + "source": [ + "What do you see now? We now see a much more sensible result. The DSA scores have now correctly identified that `Oval` and `Circle` are very dynamically similar! They have the highest color score according to the colorbar on the side. As is always good practice in science, let's have a look inside the `similarities` variable to look at the exact values and confirm what we see in the figure above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55fa4065", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "similarities" + ] + }, + { + "cell_type": "markdown", + "id": "59cb799f", + "metadata": { + "execution": {} + }, + "source": [ + "## Comparison with RSA\n", + "\n", + "At the start of this exercise, we saw three different trajectories and pointed out that the random walk and oval shapes were most similar from a geometric perspective, both ellipse-like but not similar in their dynamic similarity. To better show the difference between DSA and RSA, we encourage you to run another comparison where we consider each time step to be a pair in the X,Y space and we will look at the the similarity between of `Oval` with both `Circle` and `Walk`. If our understanding is correct, then RSA should indicate a higher geometric similarity between (`Oval` and `Walk`) than with (`Oval` and `Circle`)." + ] + }, + { + "cell_type": "markdown", + "id": "87cf4e6e", + "metadata": { + "execution": {} + }, + "source": [ + "---\n", + "# (Bonus) Representational Geometry of Recurrent Models\n", + "\n", + "Transformations of representations can occur across space and time, e.g., layers of a neural network and steps of recurrent computation. We've looked at the temporal dimension today and earlier today in the other tutorials we looked mainly at spatial representations.\n", + "\n", + "Just as the layers in a feedforward DNN can change the representational geometry to perform a task, steps in a recurrent network can reuse the same layer to reach the same computational depth.\n", + "\n", + "In this section, we look at a very simple recurrent network with only 2650 trainable parameters." + ] + }, + { + "cell_type": "markdown", + "id": "3d613edd", + "metadata": { + "execution": {} + }, + "source": [ + "Here is a diagram of this network:\n", + "\n", + "![Recurrent convolutional neural network](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/rcnn_tutorial.png?raw=true)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f0443d3", "metadata": { "cellView": "form", "execution": {} }, "outputs": [], "source": [ - "# @title Video 1: Dynamical Similarity Analysis\n", + "# @title Grab a recurrent model\n", "\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", + "args = build_args()\n", + "train_loader, test_loader = fetch_dataloaders(args)\n", + "path = \"recurrent_model.pth\"\n", + "model_recurrent = torch.load(path, map_location=args.device, weights_only=False)" + ] + }, + { + "cell_type": "markdown", + "id": "d463c3a9", + "metadata": { + "execution": {} + }, + "source": [ + "
We can first look at the computational steps in this network. As we see below, the `conv2` operation is repeated for 5 times." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bfabacd", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "train_nodes, _ = get_graph_node_names(model_recurrent)\n", + "print('The computational steps in the network are: \\n', train_nodes)" + ] + }, + { + "cell_type": "markdown", + "id": "1d410c3a", + "metadata": { + "execution": {} + }, + "source": [ + "Plotting the RDMs after each application of the `conv2` operation shows the same progressive emergence of the blockwise structure around the diagonal, mediating the correct classification in this task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30249608", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "imgs, labels = sample_images(test_loader, n=20)\n", + "return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']\n", + "model_features = extract_features(model_recurrent, imgs.to(device), return_layers)\n", "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "rdms, rdms_dict = calc_rdms(model_features)\n", + "plot_rdms(rdms_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "248329c3", + "metadata": { + "execution": {} + }, + "source": [ + "We can also look at how the different dimensionality reduction techniques capture the dynamics of changing geometry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b0e2cdf", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']\n", "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", + "imgs, labels = sample_images(test_loader, n=50) #grab 500 samples from the test set\n", + "model_features = extract_features(model_recurrent, imgs.to(device), return_layers)\n", "\n", - "video_ids = [('Youtube', 'FHikIsQFQvM'), ('Bilibili', 'BV1qm421g7hV')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" + "plot_dim_reduction(model_features, labels, transformer_funcs =['PCA', 'MDS', 't-SNE'])" + ] + }, + { + "cell_type": "markdown", + "id": "1aaf5f4a", + "metadata": { + "execution": {} + }, + "source": [ + "## Representational geometry paths for recurrent models\n", + "\n", + "We can look at the model's recurrent computational steps as a path in the representational geometry space." ] }, { "cell_type": "code", "execution_count": null, - "id": "d2ce83bc-7e86-44d3-a40a-4ad46fd5a6df", + "id": "7f88274a", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "imgs, labels = sample_images(test_loader, n=50) #grab 500 samples from the test set\n", + "model_features_recurrent = extract_features(model_recurrent, imgs.to(device), return_layers='all')\n", + "\n", + "#rdms, rdms_dict = calc_rdms(model_features)\n", + "features = {'recurrent model': model_features_recurrent}\n", + "model_colors = {'recurrent model': 'y'}\n", + "\n", + "rep_path(features, model_colors, labels)" + ] + }, + { + "cell_type": "markdown", + "id": "5c3fbd44", + "metadata": { + "execution": {} + }, + "source": [ + "We can also look at the paths taken by the feedforward and the recurrent models and compare them." + ] + }, + { + "cell_type": "markdown", + "id": "b25a8cc6", + "metadata": { + "execution": {} + }, + "source": [ + "If you recall back to Tutorial 2, we compared a standard feedward model's representations. We can extend our analysis of the analysis of the recurrent model's representations by making a side-by-side comparison. We can also look at the paths taken by the feedforward and the recurrent models and compare them. What we see above in the case of the recurrent model is the fast-shifting path through the geometric space from inputs to labels. This illustration serves to show that models take many different paths and can have very diverse underlying mechanisms but still arrive at a superficially similar output at the end of training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c904e840", "metadata": { "cellView": "form", "execution": {} @@ -160,7 +2699,19 @@ "outputs": [], "source": [ "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_DSA_video\")" + "content_review(f\"{feedback_prefix}_recurrent_models\")" + ] + }, + { + "cell_type": "markdown", + "id": "3ed56061", + "metadata": { + "execution": {} + }, + "source": [ + "# The Big Picture\n", + "\n", + "Today, you've looked at what it means to measure representations from different systems. These systems can be of the same type (multiple brain systems, multiple artificial models) as well as with representations between these systems. In NeuroAI, we're especially interested in such comparisons, comparing representational systems in deep learning networks, for instance, to brain recordings recorded while those biological systems experienced / perceived the same set of stimuli. Comparisons can be geometric / spatial or they can be temporal. Today, we looked at Dynamic Similarity Analysis, a method used to be able to capture the dependencies among trajectories, not just capturing the similarity of the full temporal sequence upon completion of the temporal sequence. It's often important to take into account multiple dimensions of representational similarity. A combination of tools is definitely required in the NeuroAI researcher's toolkit. We hope you have many chances to use these tools in your future work as NeuroAI researchers." ] } ], @@ -191,7 +2742,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.9.22" } }, "nbformat": 4, diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/solutions/W1D3_Tutorial5_Solution_0467919d.py b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/solutions/W1D3_Tutorial5_Solution_0467919d.py new file mode 100644 index 000000000..88c68e281 --- /dev/null +++ b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/solutions/W1D3_Tutorial5_Solution_0467919d.py @@ -0,0 +1,13 @@ + +n_delays = 20 +delay_interval = 10 + +models = [trajectory_circle, trajectory_oval, trajectory_walk] +dsa = DSA(models, n_delays=n_delays, delay_interval=delay_interval) +similarities = dsa.fit_score() + +labels = ['Circle', 'Oval', 'Walk'] +ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels) +cbar = ax.collections[0].colorbar +cbar.ax.set_ylabel('DSA Score'); +plt.title("Dynamic Similarity Analysis Score among Trajectories"); \ No newline at end of file diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/W1D3_Tutorial4_Solution_1ac2083f_0.png b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/W1D3_Tutorial4_Solution_1ac2083f_0.png index b3f67e4b8..8cc3e6f97 100644 Binary files a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/W1D3_Tutorial4_Solution_1ac2083f_0.png and b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/W1D3_Tutorial4_Solution_1ac2083f_0.png differ diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/W1D3_Tutorial5_Solution_0467919d_0.png b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/W1D3_Tutorial5_Solution_0467919d_0.png new file mode 100644 index 000000000..bf42ab982 Binary files /dev/null and b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/W1D3_Tutorial5_Solution_0467919d_0.png differ diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial4.ipynb b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial4.ipynb index 43a8c3119..5f8510a44 100644 --- a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial4.ipynb +++ b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial4.ipynb @@ -17,7 +17,7 @@ "execution": {} }, "source": [ - "# Tutorial 4: Representational geometry & noise\n", + "# (Bonus) Tutorial 4: Representational geometry & noise\n", "\n", "**Week 1, Day 3: Comparing Artificial And Biological Networks**\n", "\n", @@ -25,9 +25,9 @@ "\n", "__Content creators:__ Wenxuan Guo, Heiko Schütt\n", "\n", - "__Content reviewers:__ Alish Dipani, Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk\n", + "__Content reviewers:__ Alish Dipani, Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk, Alex Murphy\n", "\n", - "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n", + "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault, Alex Murphy\n", "\n", "Acknowledgments: the tutorial outline was written by Heiko Schütt. The content was greatly improved by discussions with Heiko, Hlib, and Alish, and the insightful illustrations presented in the paper by Walther et al. (2016)\n" ] @@ -61,7 +61,7 @@ "\n", "5. Using random projections to estimate distances. This section introduces the Johnson–Lindenstrauss Lemma, which states that random projections maintain the integrity of distance estimates in a lower-dimensional space. This concept is crucial for reducing dimensionality while preserving the relational structure of the data.\n", "\n", - "We will adhere to the notational conventions established by [Walther et al. (2016)](https://pubmed.ncbi.nlm.nih.gov/26707889/) for all discussed distance measures. " + "We will adhere to the notational conventions established by [Walther et al. (2016)](https://pubmed.ncbi.nlm.nih.gov/26707889/) for all discussed distance measures." ] }, { @@ -644,6 +644,72 @@ "display(tabs)" ] }, + { + "cell_type": "markdown", + "id": "b64eaea5", + "metadata": { + "execution": {} + }, + "source": [ + "The video below is additional information in more detail which was previously part of the introductory video for this course day. It provides some useful further information on the technical details mentioned during these tutorials. Please feel free to check it out and use it as a resource if you want to learn more or if you want to get a deeper understanding on some of the important details." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "200235dc", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Video 2 (BONUS): Extended Intro Video\n", + "\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "#assert 1 == 0, \"Upload this video\"\n", + "video_ids = [('Youtube', 'm9srqTx5ci0'), ('Bilibili', 'BV1meVjz3Eeh')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1461,6 +1527,20 @@ "3. Cross-validated distance estimators (cross-validated Euclidean or Mahalanobis distance) can remove the positive bias introduced by noise.\n", "4. The Johnson–Lindenstrauss Lemma shows that random projections preserve the Euclidean distance with some distortions. Crucially, the distortion does not depend on the dimensionality of the original space." ] + }, + { + "cell_type": "markdown", + "id": "40936ec4", + "metadata": { + "execution": {} + }, + "source": [ + "# The Big Picture\n", + "\n", + "The goal of this tutorial is to provide you with some mathematical tools for your NeuroAI researcher toolkit. What happens when you pull out the Euclidean metric from your toolkit and, while this has worked well in the past, suddenly in different scenarios it doesn't seem to perform so well. Aha, you spot the potential for correlated noise and you reach deeper into your toolkit and pull out the Mahalanobis metric, which implicitly undoes the correlated noise in the model. Perhaps you can't even tell if there is any correlated noise in your data and you try with both metrics, and Mahalanobis works well but Euclidean does not, that can be a hunch that leads you to confirm the presence of correlated noise. \n", + "\n", + "Sometimes you might be faced with dimensionalities that are just too high to practically deal with in your use case. Then, why not recall what you learned about how random projections can reduce the dimensionality of a feature space and be largely resistant to corrupting the applicability of distance metrics. These metrics also might work better in this lower dimensional space. If you apply this idea and need to justify it, just reach into your NeuroAI toolkit and pull out the Johnson-Lindenstrauss Lemma as your justification." + ] } ], "metadata": { @@ -1491,7 +1571,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.9.22" } }, "nbformat": 4, diff --git a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial5.ipynb b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial5.ipynb index 77480d80e..191d2f283 100644 --- a/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial5.ipynb +++ b/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial5.ipynb @@ -17,17 +17,17 @@ "execution": {} }, "source": [ - "# Bonus Material: Dynamical similarity analysis (DSA)\n", + "# Tutorial 5: Dynamical Similarity Analysis (DSA)\n", "\n", "**Week 1, Day 3: Comparing Artificial And Biological Networks**\n", "\n", "**By Neuromatch Academy**\n", "\n", - "__Content creators:__ Mitchell Ostrow\n", + "__Content creators:__ Mitchell Ostrow, Alex Murphy\n", "\n", - "__Content reviewers:__ Xaq Pitkow, Hlib Solodzhuk\n", + "__Content reviewers:__ Xaq Pitkow, Hlib Solodzhuk, Alex Murphy\n", "\n", - "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n" + "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault, Alex Murphy\n" ] }, { @@ -52,7 +52,7 @@ "source": [ "# @title Install and import feedback gadget\n", "\n", - "!pip install vibecheck --quiet\n", + "!pip install vibecheck rsatoolbox --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", @@ -67,7 +67,2087 @@ " ).render()\n", "\n", "\n", - "feedback_prefix = \"W1D3_Bonus\"" + "feedback_prefix = \"W1D5_DSA\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef9abaa3", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Helper functions\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "def generate_2d_random_process(A, B, T=1000):\n", + " \"\"\"\n", + " Generates a 2D random process with the equation x(t+1) = A.x(t) + B.noise.\n", + "\n", + " Args:\n", + " A: 2x2 transition matrix.\n", + " B: 2x2 noise scaling matrix.\n", + " T: Number of time steps.\n", + "\n", + " Returns:\n", + " A NumPy array of shape (T+1, 2) representing the trajectory.\n", + " \"\"\"\n", + " # Assuming equilibrium distribution is zero mean and identity covariance for simplicity.\n", + " # You may adjust this according to your actual equilibrium distribution\n", + " x = np.zeros(2)\n", + "\n", + " trajectory = [x.copy()] # Initialize with x(0)\n", + " for t in range(T):\n", + " noise = np.random.normal(size=2) # Standard normal noise\n", + " x = np.dot(A, x) + np.dot(B, noise)\n", + " trajectory.append(x.copy())\n", + " return np.array(trajectory)\n", + "\n", + "\"\"\"This module computes the Havok DMD model for a given dataset.\"\"\"\n", + "import torch\n", + "\n", + "def embed_signal_torch(data, n_delays, delay_interval=1):\n", + " \"\"\"\n", + " Create a delay embedding from the provided tensor data.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : torch.tensor\n", + " The data from which to create the delay embedding. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step.\n", + " \"\"\"\n", + " if isinstance(data, np.ndarray):\n", + " data = torch.from_numpy(data)\n", + " device = data.device\n", + "\n", + " if data.shape[int(data.ndim==3)] - (n_delays - 1)*delay_interval < 1:\n", + " raise ValueError(\"The number of delays is too large for the number of time points in the data!\")\n", + "\n", + " # initialize the embedding\n", + " if data.ndim == 3:\n", + " embedding = torch.zeros((data.shape[0], data.shape[1] - (n_delays - 1)*delay_interval, data.shape[2]*n_delays)).to(device)\n", + " else:\n", + " embedding = torch.zeros((data.shape[0] - (n_delays - 1)*delay_interval, data.shape[1]*n_delays)).to(device)\n", + "\n", + " for d in range(n_delays):\n", + " index = (n_delays - 1 - d)*delay_interval\n", + " ddelay = d*delay_interval\n", + "\n", + " if data.ndim == 3:\n", + " ddata = d*data.shape[2]\n", + " embedding[:,:, ddata: ddata + data.shape[2]] = data[:,index:data.shape[1] - ddelay]\n", + " else:\n", + " ddata = d*data.shape[1]\n", + " embedding[:, ddata:ddata + data.shape[1]] = data[index:data.shape[0] - ddelay]\n", + "\n", + " return embedding\n", + "\n", + "class DMD:\n", + " \"\"\"DMD class for computing and predicting with DMD models.\n", + " \"\"\"\n", + " def __init__(\n", + " self,\n", + " data,\n", + " n_delays,\n", + " delay_interval=1,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance=None,\n", + " reduced_rank_reg=False,\n", + " lamb=0,\n", + " device='cpu',\n", + " verbose=False,\n", + " send_to_cpu=False,\n", + " steps_ahead=1\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step.\n", + "\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used.\n", + "\n", + " rank_thresh : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None.\n", + "\n", + " reduced_rank_reg : bool\n", + " Determines whether to use reduced rank regression (True) or principal component regression (False)\n", + "\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to 0.\n", + "\n", + " device: string, int, or torch.device\n", + " A string, int or torch.device object to indicate the device to torch.\n", + "\n", + " verbose: bool\n", + " If True, print statements will be provided about the progress of the fitting procedure.\n", + "\n", + " send_to_cpu: bool\n", + " If True, will send all tensors in the object back to the cpu after everything is computed.\n", + " This is implemented to prevent gpu memory overload when computing multiple DMDs.\n", + "\n", + " steps_ahead: int\n", + " The number of time steps ahead to predict. Defaults to 1.\n", + " \"\"\"\n", + "\n", + " self.device = device\n", + " self._init_data(data)\n", + "\n", + " self.n_delays = n_delays\n", + " self.delay_interval = delay_interval\n", + " self.rank = rank\n", + " self.rank_thresh = rank_thresh\n", + " self.rank_explained_variance = rank_explained_variance\n", + " self.reduced_rank_reg = reduced_rank_reg\n", + " self.lamb = lamb\n", + " self.verbose = verbose\n", + " self.send_to_cpu = send_to_cpu\n", + " self.steps_ahead = steps_ahead\n", + "\n", + " # Hankel matrix\n", + " self.H = None\n", + "\n", + " # SVD attributes\n", + " self.U = None\n", + " self.S = None\n", + " self.V = None\n", + " self.S_mat = None\n", + " self.S_mat_inv = None\n", + "\n", + " # DMD attributes\n", + " self.A_v = None\n", + " self.A_havok_dmd = None\n", + "\n", + " def _init_data(self, data):\n", + " # check if the data is an np.ndarry - if so, convert it to Torch\n", + " if isinstance(data, np.ndarray):\n", + " data = torch.from_numpy(data)\n", + " self.data = data\n", + " # create attributes for the data dimensions\n", + " if self.data.ndim == 3:\n", + " self.ntrials = self.data.shape[0]\n", + " self.window = self.data.shape[1]\n", + " self.n = self.data.shape[2]\n", + " else:\n", + " self.window = self.data.shape[0]\n", + " self.n = self.data.shape[1]\n", + " self.ntrials = 1\n", + "\n", + " def compute_hankel(\n", + " self,\n", + " data=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " ):\n", + " \"\"\"\n", + " Computes the Hankel matrix from the provided data.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include. Defaults to None - provide only if you want\n", + " to override the value of n_delays from the init.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults\n", + " to 1 time step. Defaults to None - provide only if you want\n", + " to override the value of n_delays from the init.\n", + " \"\"\"\n", + " if self.verbose:\n", + " print(\"Computing Hankel matrix ...\")\n", + "\n", + " # if parameters are provided, overwrite them from the init\n", + " self.data = self.data if data is None else self._init_data(data)\n", + " self.n_delays = self.n_delays if n_delays is None else n_delays\n", + " self.delay_interval = self.delay_interval if delay_interval is None else delay_interval\n", + " self.data = self.data.to(self.device)\n", + "\n", + " self.H = embed_signal_torch(self.data, self.n_delays, self.delay_interval)\n", + "\n", + " if self.verbose:\n", + " print(\"Hankel matrix computed!\")\n", + "\n", + " def compute_svd(self):\n", + " \"\"\"\n", + " Computes the SVD of the Hankel matrix.\n", + " \"\"\"\n", + "\n", + " if self.verbose:\n", + " print(\"Computing SVD on Hankel matrix ...\")\n", + " if self.H.ndim == 3: #flatten across trials for 3d\n", + " H = self.H.reshape(self.H.shape[0] * self.H.shape[1], self.H.shape[2])\n", + " else:\n", + " H = self.H\n", + " # compute the SVD\n", + " U, S, Vh = torch.linalg.svd(H.T, full_matrices=False)\n", + "\n", + " # update attributes\n", + " V = Vh.T\n", + " self.U = U\n", + " self.S = S\n", + " self.V = V\n", + "\n", + " # construct the singuar value matrix and its inverse\n", + " # dim = self.n_delays * self.n\n", + " # s = len(S)\n", + " # self.S_mat = torch.zeros(dim, dim,dtype=torch.float32).to(self.device)\n", + " # self.S_mat_inv = torch.zeros(dim, dim,dtype=torch.float32).to(self.device)\n", + " self.S_mat = torch.diag(S).to(self.device)\n", + " self.S_mat_inv= torch.diag(1 / S).to(self.device)\n", + "\n", + " # compute explained variance\n", + " exp_variance_inds = self.S**2 / ((self.S**2).sum())\n", + " cumulative_explained = torch.cumsum(exp_variance_inds, 0)\n", + " self.cumulative_explained_variance = cumulative_explained\n", + "\n", + " #make the X and Y components of the regression by staggering the hankel eigen-time delay coordinates by time\n", + " if self.reduced_rank_reg:\n", + " V = self.V\n", + " else:\n", + " V = self.V\n", + "\n", + " if self.ntrials > 1:\n", + " if V.numel() < self.H.numel():\n", + " raise ValueError(\"The dimension of the SVD of the Hankel matrix is smaller than the dimension of the Hankel matrix itself. \\n \\\n", + " This is likely due to the number of time points being smaller than the number of dimensions. \\n \\\n", + " Please reduce the number of delays.\")\n", + "\n", + " V = V.reshape(self.H.shape)\n", + "\n", + " #first reshape back into Hankel shape, separated by trials\n", + " newshape = (self.H.shape[0]*(self.H.shape[1]-self.steps_ahead),self.H.shape[2])\n", + " self.Vt_minus = V[:,:-self.steps_ahead].reshape(newshape)\n", + " self.Vt_plus = V[:,self.steps_ahead:].reshape(newshape)\n", + " else:\n", + " self.Vt_minus = V[:-self.steps_ahead]\n", + " self.Vt_plus = V[self.steps_ahead:]\n", + "\n", + "\n", + " if self.verbose:\n", + " print(\"SVD complete!\")\n", + "\n", + " def recalc_rank(self,rank,rank_thresh,rank_explained_variance):\n", + " '''\n", + " Parameters\n", + " ----------\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used. Provide only if you want to override the value from the init.\n", + "\n", + " rank_thresh : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None - provide only if you want\n", + " to override the value from the init.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None -\n", + " provide only if you want to overried the value from the init.\n", + " '''\n", + " # if an argument was provided, overwrite the stored rank information\n", + " none_vars = (rank is None) + (rank_thresh is None) + (rank_explained_variance is None)\n", + " if none_vars != 3:\n", + " self.rank = None\n", + " self.rank_thresh = None\n", + " self.rank_explained_variance = None\n", + "\n", + " self.rank = self.rank if rank is None else rank\n", + " self.rank_thresh = self.rank_thresh if rank_thresh is None else rank_thresh\n", + " self.rank_explained_variance = self.rank_explained_variance if rank_explained_variance is None else rank_explained_variance\n", + "\n", + " none_vars = (self.rank is None) + (self.rank_thresh is None) + (self.rank_explained_variance is None)\n", + " if none_vars < 2:\n", + " raise ValueError(\"More than one value was provided between rank, rank_thresh, and rank_explained_variance. Please provide only one of these, and ensure the others are None!\")\n", + " elif none_vars == 3:\n", + " self.rank = len(self.S)\n", + "\n", + " if self.reduced_rank_reg:\n", + " S = self.proj_mat_S\n", + " else:\n", + " S = self.S\n", + "\n", + " if rank_thresh is not None:\n", + " if S[-1] > rank_thresh:\n", + " self.rank = len(S)\n", + " else:\n", + " self.rank = torch.argmax(torch.arange(len(S), 0, -1).to(self.device)*(S < rank_thresh))\n", + "\n", + " if rank_explained_variance is not None:\n", + " self.rank = int(torch.argmax((self.cumulative_explained_variance > rank_explained_variance).type(torch.int)).cpu().numpy())\n", + "\n", + " if self.rank > self.H.shape[-1]:\n", + " self.rank = self.H.shape[-1]\n", + "\n", + " if self.rank is None:\n", + " if S[-1] > self.rank_thresh:\n", + " self.rank = len(S)\n", + " else:\n", + " self.rank = torch.argmax(torch.arange(len(S), 0, -1).to(self.device)*(S < self.rank_thresh))\n", + "\n", + " def compute_havok_dmd(self,lamb=None):\n", + " \"\"\"\n", + " Computes the Havok DMD matrix (Principal Component Regression)\n", + "\n", + " Parameters\n", + " ----------\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to 0 - provide only if you want\n", + " to override the value of n_delays from the init.\n", + "\n", + " \"\"\"\n", + " if self.verbose:\n", + " print(\"Computing least squares fits to HAVOK DMD ...\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + "\n", + " A_v = (torch.linalg.inv(self.Vt_minus[:, :self.rank].T @ self.Vt_minus[:, :self.rank] + self.lamb*torch.eye(self.rank).to(self.device)) \\\n", + " @ self.Vt_minus[:, :self.rank].T @ self.Vt_plus[:, :self.rank]).T\n", + " self.A_v = A_v\n", + " self.A_havok_dmd = self.U @ self.S_mat[:self.U.shape[1], :self.rank] @ self.A_v @ self.S_mat_inv[:self.rank, :self.U.shape[1]] @ self.U.T\n", + "\n", + " if self.verbose:\n", + " print(\"Least squares complete! \\n\")\n", + "\n", + " def compute_proj_mat(self,lamb=None):\n", + " if self.verbose:\n", + " print(\"Computing Projector Matrix for Reduced Rank Regression\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + "\n", + " self.proj_mat = self.Vt_plus.T @ self.Vt_minus @ torch.linalg.inv(self.Vt_minus.T @ self.Vt_minus +\n", + " self.lamb*torch.eye(self.Vt_minus.shape[1]).to(self.device)) @ \\\n", + " self.Vt_minus.T @ self.Vt_plus\n", + "\n", + " self.proj_mat_S, self.proj_mat_V = torch.linalg.eigh(self.proj_mat)\n", + " #todo: more efficient to flip ranks (negative index) in compute_reduced_rank_regression but also less interpretable\n", + " self.proj_mat_S = torch.flip(self.proj_mat_S, dims=(0,))\n", + " self.proj_mat_V = torch.flip(self.proj_mat_V, dims=(1,))\n", + "\n", + " if self.verbose:\n", + " print(\"Projector Matrix computed! \\n\")\n", + "\n", + " def compute_reduced_rank_regression(self,lamb=None):\n", + " if self.verbose:\n", + " print(\"Computing Reduced Rank Regression ...\")\n", + "\n", + " self.lamb = self.lamb if lamb is None else lamb\n", + " proj_mat = self.proj_mat_V[:,:self.rank] @ self.proj_mat_V[:,:self.rank].T\n", + " B_ols = torch.linalg.inv(self.Vt_minus.T @ self.Vt_minus + self.lamb*torch.eye(self.Vt_minus.shape[1]).to(self.device)) @ self.Vt_minus.T @ self.Vt_plus\n", + "\n", + " self.A_v = B_ols @ proj_mat\n", + " self.A_havok_dmd = self.U @ self.S_mat[:self.U.shape[1],:self.A_v.shape[1]] @ self.A_v.T @ self.S_mat_inv[:self.A_v.shape[0], :self.U.shape[1]] @ self.U.T\n", + "\n", + "\n", + " if self.verbose:\n", + " print(\"Reduced Rank Regression complete! \\n\")\n", + "\n", + " def fit(\n", + " self,\n", + " data=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance=None,\n", + " lamb=None,\n", + " device=None,\n", + " verbose=None,\n", + " steps_ahead=None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " data : np.ndarray or torch.tensor\n", + " The data to fit the DMD model to. Must be either: (1) a\n", + " 2-dimensional array/tensor of shape T x N where T is the number\n", + " of time points and N is the number of observed dimensions\n", + " at each time point, or (2) a 3-dimensional array/tensor of shape\n", + " K x T x N where K is the number of \"trials\" and T and N are\n", + " as defined above. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " n_delays : int\n", + " Parameter that controls the size of the delay embedding. Explicitly,\n", + " the number of delays to include. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " delay_interval : int\n", + " The number of time steps between each delay in the delay embedding. Defaults to None -\n", + " provide only if you want to override the value from the init.\n", + "\n", + " rank : int\n", + " The rank of V in fitting HAVOK DMD - i.e., the number of columns of V to\n", + " use to fit the DMD model. Defaults to None, in which case all columns of V\n", + " will be used - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " rank_thresh : int\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " rank_explained_variance : float\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None -\n", + " provide only if you want to overried the value from the init.\n", + "\n", + " lamb : float\n", + " Regularization parameter for ridge regression. Defaults to None - provide only if you want to\n", + " override the value from the init.\n", + "\n", + " device: string or int\n", + " A string or int to indicate the device to torch. For example, can be 'cpu' or 'cuda',\n", + " or alternatively 0 if the intenion is to use GPU device 0. Defaults to None - provide only\n", + " if you want to override the value from the init.\n", + "\n", + " verbose: bool\n", + " If True, print statements will be provided about the progress of the fitting procedure.\n", + " Defaults to None - provide only if you want to override the value from the init.\n", + "\n", + " steps_ahead: int\n", + " The number of time steps ahead to predict. Defaults to 1.\n", + "\n", + " \"\"\"\n", + " # if parameters are provided, overwrite them from the init\n", + " self.steps_ahead = self.steps_ahead if steps_ahead is None else steps_ahead\n", + " self.device = self.device if device is None else device\n", + " self.verbose = self.verbose if verbose is None else verbose\n", + "\n", + " self.compute_hankel(data, n_delays, delay_interval)\n", + " self.compute_svd()\n", + "\n", + " if self.reduced_rank_reg:\n", + " self.compute_proj_mat(lamb)\n", + " self.recalc_rank(rank,rank_thresh,rank_explained_variance)\n", + " self.compute_reduced_rank_regression(lamb)\n", + " else:\n", + " self.recalc_rank(rank,rank_thresh,rank_explained_variance)\n", + " self.compute_havok_dmd(lamb)\n", + "\n", + " if self.send_to_cpu:\n", + " self.all_to_device('cpu') #send back to the cpu to save memory\n", + "\n", + " def predict(\n", + " self,\n", + " test_data=None,\n", + " reseed=None,\n", + " full_return=False\n", + " ):\n", + " \"\"\"\n", + " Returns\n", + " -------\n", + " pred_data : torch.tensor\n", + " The predictions generated by the HAVOK model. Of the same shape as test_data. Note that the first\n", + " (self.n_delays - 1)*self.delay_interval + 1 time steps of the generated predictions are by construction\n", + " identical to the test_data.\n", + "\n", + " H_test_havok_dmd : torch.tensor (Optional)\n", + " Returned if full_return=True. The predicted Hankel matrix generated by the HAVOK model.\n", + " H_test : torch.tensor (Optional)\n", + " Returned if full_return=True. The true Hankel matrix\n", + " \"\"\"\n", + " # initialize test_data\n", + " if test_data is None:\n", + " test_data = self.data\n", + " if isinstance(test_data, np.ndarray):\n", + " test_data = torch.from_numpy(test_data).to(self.device)\n", + " ndim = test_data.ndim\n", + " if ndim == 2:\n", + " test_data = test_data.unsqueeze(0)\n", + " H_test = embed_signal_torch(test_data, self.n_delays, self.delay_interval)\n", + " steps_ahead = self.steps_ahead if self.steps_ahead is not None else 1\n", + "\n", + " if reseed is None:\n", + " reseed = 1\n", + "\n", + " H_test_havok_dmd = torch.zeros(H_test.shape).to(self.device)\n", + " H_test_havok_dmd[:, :steps_ahead] = H_test[:, :steps_ahead]\n", + "\n", + " A = self.A_havok_dmd.unsqueeze(0)\n", + " for t in range(steps_ahead, H_test.shape[1]):\n", + " if t % reseed == 0:\n", + " H_test_havok_dmd[:, t] = (A @ H_test[:, t - steps_ahead].transpose(-2, -1)).transpose(-2, -1)\n", + " else:\n", + " H_test_havok_dmd[:, t] = (A @ H_test_havok_dmd[:, t - steps_ahead].transpose(-2, -1)).transpose(-2, -1)\n", + " pred_data = torch.hstack([test_data[:, :(self.n_delays - 1)*self.delay_interval + steps_ahead], H_test_havok_dmd[:, steps_ahead:, :self.n]])\n", + "\n", + " if ndim == 2:\n", + " pred_data = pred_data[0]\n", + "\n", + " if full_return:\n", + " return pred_data, H_test_havok_dmd, H_test\n", + " else:\n", + " return pred_data\n", + "\n", + " def all_to_device(self,device='cpu'):\n", + " for k,v in self.__dict__.items():\n", + " if isinstance(v, torch.Tensor):\n", + " self.__dict__[k] = v.to(device)\n", + "\n", + "from typing import Literal\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from typing import Literal\n", + "import torch.nn.utils.parametrize as parametrize\n", + "from scipy.stats import wasserstein_distance\n", + "\n", + "def pad_zeros(A,B,device):\n", + "\n", + " with torch.no_grad():\n", + " dim = max(A.shape[0],B.shape[0])\n", + " A1 = torch.zeros((dim,dim)).float()\n", + " A1[:A.shape[0],:A.shape[1]] += A\n", + " A = A1.float().to(device)\n", + "\n", + " B1 = torch.zeros((dim,dim)).float()\n", + " B1[:B.shape[0],:B.shape[1]] += B\n", + " B = B1.float().to(device)\n", + "\n", + " return A,B\n", + "\n", + "class LearnableSimilarityTransform(nn.Module):\n", + " \"\"\"\n", + " Computes the similarity transform for a learnable orthonormal matrix C\n", + " \"\"\"\n", + " def __init__(self, n,orthog=True):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + " n : int\n", + " dimension of the C matrix\n", + " \"\"\"\n", + " super(LearnableSimilarityTransform, self).__init__()\n", + " #initialize orthogonal matrix as identity\n", + " self.C = nn.Parameter(torch.eye(n).float())\n", + " self.orthog = orthog\n", + "\n", + " def forward(self, B):\n", + " if self.orthog:\n", + " return self.C @ B @ self.C.transpose(-1, -2)\n", + " else:\n", + " return self.C @ B @ torch.linalg.inv(self.C)\n", + "\n", + "class Skew(nn.Module):\n", + " def __init__(self,n,device):\n", + " \"\"\"\n", + " Computes a skew-symmetric matrix X from some parameters (also called X)\n", + "\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.L1 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L2 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L3 = nn.Linear(n,n,bias = False, device = device)\n", + "\n", + " def forward(self, X):\n", + " X = torch.tanh(self.L1(X))\n", + " X = torch.tanh(self.L2(X))\n", + " X = self.L3(X)\n", + " return X - X.transpose(-1, -2)\n", + "\n", + "class Matrix(nn.Module):\n", + " def __init__(self,n,device):\n", + " \"\"\"\n", + " Computes a matrix X from some parameters (also called X)\n", + "\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.L1 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L2 = nn.Linear(n,n,bias = False, device = device)\n", + " self.L3 = nn.Linear(n,n,bias = False, device = device)\n", + "\n", + " def forward(self, X):\n", + " X = torch.tanh(self.L1(X))\n", + " X = torch.tanh(self.L2(X))\n", + " X = self.L3(X)\n", + " return X\n", + "\n", + "class CayleyMap(nn.Module):\n", + " \"\"\"\n", + " Maps a skew-symmetric matrix to an orthogonal matrix in O(n)\n", + " \"\"\"\n", + " def __init__(self, n, device):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + "\n", + " n : int\n", + " dimension of the matrix we want to map\n", + "\n", + " device : {'cpu','cuda'} or int\n", + " hardware device on which to send the matrix\n", + " \"\"\"\n", + " super().__init__()\n", + " self.register_buffer(\"Id\", torch.eye(n,device = device))\n", + "\n", + " def forward(self, X):\n", + " # (I + X)(I - X)^{-1}\n", + " return torch.linalg.solve(self.Id + X, self.Id - X)\n", + "\n", + "class SimilarityTransformDist:\n", + " \"\"\"\n", + " Computes the Procrustes Analysis over Vector Fields\n", + " \"\"\"\n", + " def __init__(self,\n", + " iters = 200,\n", + " score_method: Literal[\"angular\", \"euclidean\",\"wasserstein\"] = \"angular\",\n", + " lr = 0.01,\n", + " device: Literal[\"cpu\",\"cuda\"] = 'cpu',\n", + " verbose = False,\n", + " group: Literal[\"O(n)\",\"SO(n)\",\"GL(n)\"] = \"O(n)\",\n", + " wasserstein_compare = None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " _________\n", + " iters : int\n", + " number of iterations to perform gradient descent\n", + "\n", + " score_method : {\"angular\",\"euclidean\",\"wasserstein\"}\n", + " specifies the type of metric to use\n", + " \"wasserstein\" will compare the singular values or eigenvalues\n", + " of the two matrices as in Redman et al., (2023)\n", + "\n", + " lr : float\n", + " learning rate\n", + "\n", + " device : {'cpu','cuda'} or int\n", + "\n", + " verbose : bool\n", + " prints when finished optimizing\n", + "\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " wasserstein_compare : {'sv','eig',None}\n", + " specifies whether to compare the singular values or eigenvalues\n", + " if score_method is \"wasserstein\", or the shapes are different\n", + " \"\"\"\n", + "\n", + " self.iters = iters\n", + " self.score_method = score_method\n", + " self.lr = lr\n", + " self.verbose = verbose\n", + " self.device = device\n", + " self.C_star = None\n", + " self.A = None\n", + " self.B = None\n", + " self.group = group\n", + " self.wasserstein_compare = wasserstein_compare\n", + "\n", + " def fit(self,\n", + " A,\n", + " B,\n", + " iters = None,\n", + " lr = None,\n", + " group = None,\n", + " ):\n", + " \"\"\"\n", + " Computes the optimal matrix C over specified group\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor\n", + " first data matrix\n", + " B : np.array or torch.tensor\n", + " second data matrix\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " Returns\n", + " _______\n", + " None\n", + " \"\"\"\n", + " assert A.shape[0] == A.shape[1]\n", + " assert B.shape[0] == B.shape[1]\n", + "\n", + " A = A.to(self.device)\n", + " B = B.to(self.device)\n", + " self.A,self.B = A,B\n", + " lr = self.lr if lr is None else lr\n", + " iters = self.iters if iters is None else iters\n", + " group = self.group if group is None else group\n", + "\n", + " if group in {\"SO(n)\", \"O(n)\"}:\n", + " self.losses, self.C_star, self.sim_net = self.optimize_C(A,\n", + " B,\n", + " lr,iters,\n", + " orthog=True,\n", + " verbose=self.verbose)\n", + " if group == \"O(n)\":\n", + " #permute the first row and column of B then rerun the optimization\n", + " P = torch.eye(B.shape[0],device=self.device)\n", + " if P.shape[0] > 1:\n", + " P[[0, 1], :] = P[[1, 0], :]\n", + " losses, C_star, sim_net = self.optimize_C(A,\n", + " P @ B @ P.T,\n", + " lr,iters,\n", + " orthog=True,\n", + " verbose=self.verbose)\n", + " if losses[-1] < self.losses[-1]:\n", + " self.losses = losses\n", + " self.C_star = C_star @ P\n", + " self.sim_net = sim_net\n", + " if group == \"GL(n)\":\n", + " self.losses, self.C_star, self.sim_net = self.optimize_C(A,\n", + " B,\n", + " lr,iters,\n", + " orthog=False,\n", + " verbose=self.verbose)\n", + "\n", + " def optimize_C(self,A,B,lr,iters,orthog,verbose):\n", + " #parameterize mapping to be orthogonal\n", + " n = A.shape[0]\n", + " sim_net = LearnableSimilarityTransform(n,orthog=orthog).to(self.device)\n", + " if orthog:\n", + " parametrize.register_parametrization(sim_net, \"C\", Skew(n,self.device))\n", + " parametrize.register_parametrization(sim_net, \"C\", CayleyMap(n,self.device))\n", + " else:\n", + " parametrize.register_parametrization(sim_net, \"C\", Matrix(n,self.device))\n", + "\n", + " simdist_loss = nn.MSELoss(reduction = 'sum')\n", + "\n", + " optimizer = optim.Adam(sim_net.parameters(), lr=lr)\n", + " # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n", + "\n", + " losses = []\n", + " A /= torch.linalg.norm(A)\n", + " B /= torch.linalg.norm(B)\n", + " for _ in range(iters):\n", + " # Zero the gradients of the optimizer.\n", + " optimizer.zero_grad()\n", + " # Compute the Frobenius norm between A and the product.\n", + " loss = simdist_loss(A, sim_net(B))\n", + "\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + " # if _ % 99:\n", + " # scheduler.step()\n", + " losses.append(loss.item())\n", + "\n", + " if verbose:\n", + " print(\"Finished optimizing C\")\n", + "\n", + " C_star = sim_net.C.detach()\n", + " return losses, C_star,sim_net\n", + "\n", + " def score(self,A=None,B=None,score_method=None,group=None):\n", + " \"\"\"\n", + " Given an optimal C already computed, calculate the metric\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor or None\n", + " first data matrix, if None defaults to the saved matrix in fit\n", + " B : np.array or torch.tensor or None\n", + " second data matrix if None, defaults to the savec matrix in fit\n", + " score_method : None or {'angular','euclidean'}\n", + " overwrites the score method in the object for this application\n", + " Returns\n", + " _______\n", + "\n", + " score : float\n", + " similarity of the data under the similarity transform w.r.t C\n", + " \"\"\"\n", + " assert self.C_star is not None\n", + " A = self.A if A is None else A\n", + " B = self.B if B is None else B\n", + " assert A is not None\n", + " assert B is not None\n", + " assert A.shape == self.C_star.shape\n", + " assert B.shape == self.C_star.shape\n", + " score_method = self.score_method if score_method is None else score_method\n", + " group = self.group if group is None else group\n", + " with torch.no_grad():\n", + " if not isinstance(A,torch.Tensor):\n", + " A = torch.from_numpy(A).float().to(self.device)\n", + " if not isinstance(B,torch.Tensor):\n", + " B = torch.from_numpy(B).float().to(self.device)\n", + " C = self.C_star.to(self.device)\n", + "\n", + " if group in {\"SO(n)\", \"O(n)\"}:\n", + " Cinv = C.T\n", + " elif group in {\"GL(n)\"}:\n", + " Cinv = torch.linalg.inv(C)\n", + " else:\n", + " raise AssertionError(\"Need proper group name\")\n", + " if score_method == 'angular':\n", + " num = torch.trace(A.T @ C @ B @ Cinv)\n", + " den = torch.norm(A,p = 'fro')*torch.norm(B,p = 'fro')\n", + " score = torch.arccos(num/den).cpu().numpy()\n", + " if np.isnan(score): #around -1 and 1, we sometimes get NaNs due to arccos\n", + " if num/den < 0:\n", + " score = np.pi\n", + " else:\n", + " score = 0\n", + " else:\n", + " score = torch.norm(A - C @ B @ Cinv,p='fro').cpu().numpy().item() #/ A.numpy().size\n", + "\n", + " return score\n", + "\n", + " def fit_score(self,\n", + " A,\n", + " B,\n", + " iters = None,\n", + " lr = None,\n", + " score_method = None,\n", + " zero_pad = True,\n", + " group = None):\n", + " \"\"\"\n", + " for efficiency, computes the optimal matrix and returns the score\n", + "\n", + " Parameters\n", + " __________\n", + " A : np.array or torch.tensor\n", + " first data matrix\n", + " B : np.array or torch.tensor\n", + " second data matrix\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " score_method : {'angular','euclidean'} or None\n", + " overwrites parameter in the class\n", + " zero_pad : bool\n", + " if True, then the smaller matrix will be zero padded so its the same size\n", + " Returns\n", + " _______\n", + "\n", + " score : float\n", + " similarity of the data under the similarity transform w.r.t C\n", + "\n", + " \"\"\"\n", + " score_method = self.score_method if score_method is None else score_method\n", + " group = self.group if group is None else group\n", + "\n", + " if isinstance(A,np.ndarray):\n", + " A = torch.from_numpy(A).float()\n", + " if isinstance(B,np.ndarray):\n", + " B = torch.from_numpy(B).float()\n", + "\n", + " assert A.shape[0] == B.shape[1] or self.wasserstein_compare is not None\n", + " if A.shape[0] != B.shape[0]:\n", + " if self.wasserstein_compare is None:\n", + " raise AssertionError(\"Matrices must be the same size unless using wasserstein distance\")\n", + " else: #otherwise resort to L2 Wasserstein over singular or eigenvalues\n", + " print(f\"resorting to wasserstein distance over {self.wasserstein_compare}\")\n", + "\n", + " if self.score_method == \"wasserstein\":\n", + " assert self.wasserstein_compare in {\"sv\",\"eig\"}\n", + " if self.wasserstein_compare == \"sv\":\n", + " a = torch.svd(A).S.view(-1,1)\n", + " b = torch.svd(B).S.view(-1,1)\n", + " elif self.wasserstein_compare == \"eig\":\n", + " a = torch.linalg.eig(A).eigenvalues\n", + " a = torch.vstack([a.real,a.imag]).T\n", + "\n", + " b = torch.linalg.eig(B).eigenvalues\n", + " b = torch.vstack([b.real,b.imag]).T\n", + " else:\n", + " raise AssertionError(\"wasserstein_compare must be 'sv' or 'eig'\")\n", + " device = a.device\n", + " a = a#.cpu()\n", + " b = b#.cpu()\n", + " M = ot.dist(a,b)#.numpy()\n", + " a,b = torch.ones(a.shape[0])/a.shape[0],torch.ones(b.shape[0])/b.shape[0]\n", + " a,b = a.to(device),b.to(device)\n", + "\n", + " score_star = ot.emd2(a,b,M)\n", + " #wasserstein_distance(A.cpu().numpy(),B.cpu().numpy())\n", + "\n", + " else:\n", + "\n", + " self.fit(A, B,iters,lr,group)\n", + " score_star = self.score(self.A,self.B,score_method=score_method,group=group)\n", + "\n", + " return score_star\n", + "\n", + "class DSA:\n", + " \"\"\"\n", + " Computes the Dynamical Similarity Analysis (DSA) for two data matrices\n", + " \"\"\"\n", + " def __init__(self,\n", + " X,\n", + " Y=None,\n", + " n_delays=1,\n", + " delay_interval=1,\n", + " rank=None,\n", + " rank_thresh=None,\n", + " rank_explained_variance = None,\n", + " lamb = 0.0,\n", + " send_to_cpu = True,\n", + " iters = 1500,\n", + " score_method: Literal[\"angular\", \"euclidean\",\"wasserstein\"] = \"angular\",\n", + " lr = 5e-3,\n", + " group: Literal[\"GL(n)\", \"O(n)\", \"SO(n)\"] = \"O(n)\",\n", + " zero_pad = False,\n", + " device = 'cpu',\n", + " verbose = False,\n", + " reduced_rank_reg = False,\n", + " kernel=None,\n", + " num_centers=0.1,\n", + " svd_solver='arnoldi',\n", + " wasserstein_compare: Literal['sv','eig',None] = None\n", + " ):\n", + " \"\"\"\n", + " Parameters\n", + " __________\n", + "\n", + " X : np.array or torch.tensor or list of np.arrays or torch.tensors\n", + " first data matrix/matrices\n", + "\n", + " Y : None or np.array or torch.tensor or list of np.arrays or torch.tensors\n", + " second data matrix/matrices.\n", + " * If Y is None, X is compared to itself pairwise\n", + " (must be a list)\n", + " * If Y is a single matrix, all matrices in X are compared to Y\n", + " * If Y is a list, all matrices in X are compared to all matrices in Y\n", + "\n", + " DMD parameters:\n", + "\n", + " n_delays : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " number of delays to use in constructing the Hankel matrix\n", + "\n", + " delay_interval : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " interval between samples taken in constructing Hankel matrix\n", + "\n", + " rank : int or list or tuple/list: (int,int), (list,list),(list,int),(int,list)\n", + " rank of DMD matrix fit in reduced-rank regression\n", + "\n", + " rank_thresh : float or list or tuple/list: (float,float), (list,list),(list,float),(float,list)\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by dictating a threshold\n", + " of singular values to use. Explicitly, the rank of V will be the number of singular\n", + " values greater than rank_thresh. Defaults to None.\n", + "\n", + " rank_explained_variance : float or list or tuple: (float,float), (list,list),(list,float),(float,list)\n", + " Parameter that controls the rank of V in fitting HAVOK DMD by indicating the percentage of\n", + " cumulative explained variance that should be explained by the columns of V. Defaults to None.\n", + "\n", + " lamb : float\n", + " L-1 regularization parameter in DMD fit\n", + "\n", + " send_to_cpu: bool\n", + " If True, will send all tensors in the object back to the cpu after everything is computed.\n", + " This is implemented to prevent gpu memory overload when computing multiple DMDs.\n", + "\n", + " NOTE: for all of these above, they can be single values or lists or tuples,\n", + " depending on the corresponding dimensions of the data\n", + " If at least one of X and Y are lists, then if they are a single value\n", + " it will default to the rank of all DMD matrices.\n", + " If they are (int,int), then they will correspond to an individual dmd matrix\n", + " OR to X and Y respectively across all matrices\n", + " If it is (list,list), then each element will correspond to an individual\n", + " dmd matrix indexed at the same position\n", + "\n", + " SimDist parameters:\n", + "\n", + " iters : int\n", + " number of optimization iterations in Procrustes over vector fields\n", + "\n", + " score_method : {'angular','euclidean'}\n", + " type of metric to compute, angular vs euclidean distance\n", + "\n", + " lr : float\n", + " learning rate of the Procrustes over vector fields optimization\n", + "\n", + " group : {'SO(n)','O(n)', 'GL(n)'}\n", + " specifies the group of matrices to optimize over\n", + "\n", + " zero_pad : bool\n", + " whether or not to zero-pad if the dimensions are different\n", + "\n", + " device : 'cpu' or 'cuda' or int\n", + " hardware to use in both DMD and PoVF\n", + "\n", + " verbose : bool\n", + " whether or not print when sections of the analysis is completed\n", + "\n", + " wasserstein_compare : {'sv','eig',None}\n", + " specifies whether to compare the singular values or eigenvalues\n", + " if score_method is \"wasserstein\", or the shapes are different\n", + " \"\"\"\n", + " self.X = X\n", + " self.Y = Y\n", + " if self.X is None and isinstance(self.Y,list):\n", + " self.X, self.Y = self.Y, self.X #swap so code is easy\n", + "\n", + " self.check_method()\n", + " if self.method == 'self-pairwise':\n", + " self.data = [self.X]\n", + " else:\n", + " self.data = [self.X, self.Y]\n", + "\n", + " self.n_delays = self.broadcast_params(n_delays,cast=int)\n", + " self.delay_interval = self.broadcast_params(delay_interval,cast=int)\n", + " self.rank = self.broadcast_params(rank,cast=int)\n", + " self.rank_thresh = self.broadcast_params(rank_thresh)\n", + " self.rank_explained_variance = self.broadcast_params(rank_explained_variance)\n", + " self.lamb = self.broadcast_params(lamb)\n", + " self.send_to_cpu = send_to_cpu\n", + " self.iters = iters\n", + " self.score_method = score_method\n", + " self.lr = lr\n", + " self.device = device\n", + " self.verbose = verbose\n", + " self.zero_pad = zero_pad\n", + " self.group = group\n", + " self.reduced_rank_reg = reduced_rank_reg\n", + " self.kernel = kernel\n", + " self.wasserstein_compare = wasserstein_compare\n", + "\n", + " if kernel is None:\n", + " #get a list of all DMDs here\n", + " self.dmds = [[DMD(Xi,\n", + " self.n_delays[i][j],\n", + " delay_interval=self.delay_interval[i][j],\n", + " rank=self.rank[i][j],\n", + " rank_thresh=self.rank_thresh[i][j],\n", + " rank_explained_variance=self.rank_explained_variance[i][j],\n", + " reduced_rank_reg=self.reduced_rank_reg,\n", + " lamb=self.lamb[i][j],\n", + " device=self.device,\n", + " verbose=self.verbose,\n", + " send_to_cpu=self.send_to_cpu) for j,Xi in enumerate(dat)] for i,dat in enumerate(self.data)]\n", + " else:\n", + " #get a list of all DMDs here\n", + " self.dmds = [[KernelDMD(Xi,\n", + " self.n_delays[i][j],\n", + " kernel=self.kernel,\n", + " num_centers=num_centers,\n", + " delay_interval=self.delay_interval[i][j],\n", + " rank=self.rank[i][j],\n", + " reduced_rank_reg=self.reduced_rank_reg,\n", + " lamb=self.lamb[i][j],\n", + " verbose=self.verbose,\n", + " svd_solver=svd_solver,\n", + " ) for j,Xi in enumerate(dat)] for i,dat in enumerate(self.data)]\n", + "\n", + " self.simdist = SimilarityTransformDist(iters,score_method,lr,device,verbose,group,wasserstein_compare)\n", + "\n", + " def check_method(self):\n", + " '''\n", + " helper function to identify what type of dsa we're running\n", + " '''\n", + " tensor_or_np = lambda x: isinstance(x,(np.ndarray,torch.Tensor))\n", + "\n", + " if isinstance(self.X,list):\n", + " if self.Y is None:\n", + " self.method = 'self-pairwise'\n", + " elif isinstance(self.Y,list):\n", + " self.method = 'bipartite-pairwise'\n", + " elif tensor_or_np(self.Y):\n", + " self.method = 'list-to-one'\n", + " self.Y = [self.Y] #wrap in a list for iteration\n", + " else:\n", + " raise ValueError('unknown type of Y')\n", + " elif tensor_or_np(self.X):\n", + " self.X = [self.X]\n", + " if self.Y is None:\n", + " raise ValueError('only one element provided')\n", + " elif isinstance(self.Y,list):\n", + " self.method = 'one-to-list'\n", + " elif tensor_or_np(self.Y):\n", + " self.method = 'default'\n", + " self.Y = [self.Y]\n", + " else:\n", + " raise ValueError('unknown type of Y')\n", + " else:\n", + " raise ValueError('unknown type of X')\n", + "\n", + " def broadcast_params(self,param,cast=None):\n", + " '''\n", + " aligns the dimensionality of the parameters with the data so it's one-to-one\n", + " '''\n", + " out = []\n", + " if isinstance(param,(int,float,np.integer)) or param is None: #self.X has already been mapped to [self.X]\n", + " out.append([param] * len(self.X))\n", + " if self.Y is not None:\n", + " out.append([param] * len(self.Y))\n", + " elif isinstance(param,(tuple,list,np.ndarray)):\n", + " if self.method == 'self-pairwise' and len(param) >= len(self.X):\n", + " out = [param]\n", + " else:\n", + " assert len(param) <= 2 #only 2 elements max\n", + "\n", + " #if the inner terms are singly valued, we broadcast, otherwise needs to be the same dimensions\n", + " for i,data in enumerate([self.X,self.Y]):\n", + " if data is None:\n", + " continue\n", + " if isinstance(param[i],(int,float)):\n", + " out.append([param[i]] * len(data))\n", + " elif isinstance(param[i],(list,np.ndarray,tuple)):\n", + " assert len(param[i]) >= len(data)\n", + " out.append(param[i][:len(data)])\n", + " else:\n", + " raise ValueError(\"unknown type entered for parameter\")\n", + "\n", + " if cast is not None and param is not None:\n", + " out = [[cast(x) for x in dat] for dat in out]\n", + "\n", + " return out\n", + "\n", + " def fit_dmds(self,\n", + " X=None,\n", + " Y=None,\n", + " n_delays=None,\n", + " delay_interval=None,\n", + " rank=None,\n", + " rank_thresh = None,\n", + " rank_explained_variance=None,\n", + " reduced_rank_reg=None,\n", + " lamb = None,\n", + " device='cpu',\n", + " verbose=False,\n", + " send_to_cpu=True\n", + " ):\n", + " \"\"\"\n", + " Recomputes only the DMDs with a single set of hyperparameters. This will not compare, that will need to be done with the full procedure\n", + " \"\"\"\n", + " X = self.X if X is None else X\n", + " Y = self.Y if Y is None else Y\n", + " n_delays = self.n_delays if n_delays is None else n_delays\n", + " delay_interval = self.delay_interval if delay_interval is None else delay_interval\n", + " rank = self.rank if rank is None else rank\n", + " lamb = self.lamb if lamb is None else lamb\n", + " data = []\n", + " if isinstance(X,list):\n", + " data.append(X)\n", + " else:\n", + " data.append([X])\n", + " if Y is not None:\n", + " if isinstance(Y,list):\n", + " data.append(Y)\n", + " else:\n", + " data.append([Y])\n", + "\n", + " dmds = [[DMD(Xi,n_delays,delay_interval,\n", + " rank,rank_thresh,rank_explained_variance,reduced_rank_reg,\n", + " lamb,device,verbose,send_to_cpu) for Xi in dat] for dat in data]\n", + "\n", + " for dmd_sets in dmds:\n", + " for dmd in dmd_sets:\n", + " dmd.fit()\n", + "\n", + " return dmds\n", + "\n", + " def fit_score(self):\n", + " \"\"\"\n", + " Standard fitting function for both DMDs and PoVF\n", + "\n", + " Parameters\n", + " __________\n", + "\n", + " Returns\n", + " _______\n", + "\n", + " sims : np.array\n", + " data matrix of the similarity scores between the specific sets of data\n", + " \"\"\"\n", + " for dmd_sets in self.dmds:\n", + " for dmd in dmd_sets:\n", + " dmd.fit()\n", + "\n", + " return self.score()\n", + "\n", + " def score(self,iters=None,lr=None,score_method=None):\n", + " \"\"\"\n", + " Rescore DSA with precomputed dmds if you want to try again\n", + "\n", + " Parameters\n", + " __________\n", + " iters : int or None\n", + " number of optimization steps, if None then resorts to saved self.iters\n", + " lr : float or None\n", + " learning rate, if None then resorts to saved self.lr\n", + " score_method : None or {'angular','euclidean'}\n", + " overwrites the score method in the object for this application\n", + "\n", + " Returns\n", + " ________\n", + " score : float\n", + " similarity score of the two precomputed DMDs\n", + " \"\"\"\n", + "\n", + " iters = self.iters if iters is None else iters\n", + " lr = self.lr if lr is None else lr\n", + " score_method = self.score_method if score_method is None else score_method\n", + "\n", + " ind2 = 1 - int(self.method == 'self-pairwise')\n", + " # 0 if self.pairwise (want to compare the set to itself)\n", + "\n", + " self.sims = np.zeros((len(self.dmds[0]),len(self.dmds[ind2])))\n", + " for i,dmd1 in enumerate(self.dmds[0]):\n", + " for j,dmd2 in enumerate(self.dmds[ind2]):\n", + " if self.method == 'self-pairwise':\n", + " if j >= i:\n", + " continue\n", + " if self.verbose:\n", + " print(f'computing similarity between DMDs {i} and {j}')\n", + "\n", + " self.sims[i,j] = self.simdist.fit_score(dmd1.A_v,dmd2.A_v,iters,lr,score_method,zero_pad=self.zero_pad)\n", + "\n", + " if self.method == 'self-pairwise':\n", + " self.sims[j,i] = self.sims[i,j]\n", + "\n", + "\n", + " if self.method == 'default':\n", + " return self.sims[0,0]\n", + "\n", + " return self.sims" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eced3162", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Helper functions (Bonus Section)\n", + "\n", + "import contextlib\n", + "import io\n", + "import argparse\n", + "# Standard library imports\n", + "from collections import OrderedDict\n", + "import logging\n", + "\n", + "# External libraries: General utilities\n", + "import argparse\n", + "import numpy as np\n", + "\n", + "# PyTorch related imports\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torch.optim.lr_scheduler import StepLR\n", + "from torchvision import datasets, transforms\n", + "from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names\n", + "from torchvision.utils import make_grid\n", + "\n", + "# Matplotlib for plotting\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "# SciPy for statistical functions\n", + "from scipy import stats\n", + "\n", + "# Scikit-Learn for machine learning utilities\n", + "from sklearn.decomposition import PCA\n", + "from sklearn import manifold\n", + "\n", + "# RSA toolbox specific imports\n", + "import rsatoolbox\n", + "from rsatoolbox.data import Dataset\n", + "from rsatoolbox.rdm.calc import calc_rdm\n", + "\n", + "class Net(nn.Module):\n", + " \"\"\"\n", + " A neural network model for image classification, consisting of two convolutional layers,\n", + " followed by two fully connected layers with dropout regularization.\n", + "\n", + " Methods:\n", + " - forward(input): Defines the forward pass of the network.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + " \"\"\"\n", + " Initializes the network layers.\n", + "\n", + " Layers:\n", + " - conv1: First convolutional layer with 1 input channel, 32 output channels, and a 3x3 kernel.\n", + " - conv2: Second convolutional layer with 32 input channels, 64 output channels, and a 3x3 kernel.\n", + " - dropout1: Dropout layer with a dropout probability of 0.25.\n", + " - dropout2: Dropout layer with a dropout probability of 0.5.\n", + " - fc1: First fully connected layer with 9216 input features and 128 output features.\n", + " - fc2: Second fully connected layer with 128 input features and 10 output features.\n", + " \"\"\"\n", + " super(Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", + " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", + " self.dropout1 = nn.Dropout(0.25)\n", + " self.dropout2 = nn.Dropout(0.5)\n", + " self.fc1 = nn.Linear(9216, 128)\n", + " self.fc2 = nn.Linear(128, 10)\n", + "\n", + " def forward(self, input):\n", + " \"\"\"\n", + " Defines the forward pass of the network.\n", + "\n", + " Inputs:\n", + " - input (torch.Tensor): Input tensor of shape (batch_size, 1, height, width).\n", + "\n", + " Outputs:\n", + " - output (torch.Tensor): Output tensor of shape (batch_size, 10) representing the class probabilities for each input sample.\n", + " \"\"\"\n", + " x = self.conv1(input)\n", + " x = F.relu(x)\n", + " x = self.conv2(x)\n", + " x = F.relu(x)\n", + " x = F.max_pool2d(x, 2)\n", + " x = self.dropout1(x)\n", + " x = torch.flatten(x, 1)\n", + " x = self.fc1(x)\n", + " x = F.relu(x)\n", + " x = self.dropout2(x)\n", + " x = self.fc2(x)\n", + " output = F.softmax(x, dim=1)\n", + " return output\n", + "\n", + "class recurrent_Net(nn.Module):\n", + " \"\"\"\n", + " A recurrent neural network model for image classification, consisting of two convolutional layers\n", + " with recurrent connections and a readout layer.\n", + "\n", + " Methods:\n", + " - __init__(time_steps=5): Initializes the network layers and sets the number of time steps for recurrence.\n", + " - forward(input): Defines the forward pass of the network.\n", + " \"\"\"\n", + "\n", + " def __init__(self, time_steps=5):\n", + " \"\"\"\n", + " Initializes the network layers and sets the number of time steps for recurrence.\n", + "\n", + " Layers:\n", + " - conv1: First convolutional layer with 1 input channel, 16 output channels, and a 3x3 kernel with a stride of 3.\n", + " - conv2: Second convolutional layer with 16 input channels, 16 output channels, and a 3x3 kernel with padding of 1.\n", + " - readout: A sequential layer containing:\n", + " - dropout: Dropout layer with a dropout probability of 0.25.\n", + " - avgpool: Adaptive average pooling layer to reduce spatial dimensions to 1x1.\n", + " - flatten: Flatten layer to convert the 2D pooled output to 1D.\n", + " - linear: Fully connected layer with 16 input features and 10 output features.\n", + " - time_steps (int): Number of time steps for the recurrent connection.\n", + " \"\"\"\n", + " super(recurrent_Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 16, 3, 3)\n", + " self.conv2 = nn.Conv2d(16, 16, 3, 1, padding=1)\n", + " self.readout = nn.Sequential(OrderedDict([\n", + " ('dropout', nn.Dropout(0.25)),\n", + " ('avgpool', nn.AdaptiveAvgPool2d(1)),\n", + " ('flatten', nn.Flatten()),\n", + " ('linear', nn.Linear(16, 10))\n", + " ]))\n", + " self.time_steps = time_steps\n", + "\n", + " def forward(self, input):\n", + " \"\"\"\n", + " Defines the forward pass of the network.\n", + "\n", + " Inputs:\n", + " - input (torch.Tensor): Input tensor of shape (batch_size, 1, height, width).\n", + "\n", + " Outputs:\n", + " - output (torch.Tensor): Output tensor of shape (batch_size, 10) representing the class probabilities for each input sample.\n", + " \"\"\"\n", + " input = self.conv1(input)\n", + " x = input\n", + " for t in range(0, self.time_steps):\n", + " x = input + self.conv2(x)\n", + " x = F.relu(x)\n", + "\n", + " x = self.readout(x)\n", + " output = F.softmax(x, dim=1)\n", + " return output\n", + "\n", + "\n", + "def train_one_epoch(args, model, device, train_loader, optimizer, epoch):\n", + " \"\"\"\n", + " Trains the model for one epoch.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Arguments for training configuration.\n", + " - model (torch.nn.Module): The model to be trained.\n", + " - device (torch.device): The device to use for training (CPU/GPU).\n", + " - train_loader (torch.utils.data.DataLoader): DataLoader for the training data.\n", + " - optimizer (torch.optim.Optimizer): Optimizer for updating the model parameters.\n", + " - epoch (int): The current epoch number.\n", + " \"\"\"\n", + " model.train()\n", + " for batch_idx, (data, target) in enumerate(train_loader):\n", + " data, target = data.to(device), target.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " output = torch.log(output) # to make it a log_softmax\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " if batch_idx % args.log_interval == 0:\n", + " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " epoch, batch_idx * len(data), len(train_loader.dataset),\n", + " 100. * batch_idx / len(train_loader), loss.item()))\n", + " if args.dry_run:\n", + " break\n", + "\n", + "def test(model, device, test_loader, return_features=False):\n", + " \"\"\"\n", + " Evaluates the model on the test dataset.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to be evaluated.\n", + " - device (torch.device): The device to use for evaluation (CPU/GPU).\n", + " - test_loader (torch.utils.data.DataLoader): DataLoader for the test data.\n", + " - return_features (bool): If True, returns the features from the model. Default is False.\n", + " \"\"\"\n", + " model.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " output = model(data)\n", + " output = torch.log(output)\n", + " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", + " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + " test_loss /= len(test_loader.dataset)\n", + "\n", + " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, correct, len(test_loader.dataset),\n", + " 100. * correct / len(test_loader.dataset)))\n", + "\n", + "def build_args():\n", + " \"\"\"\n", + " Builds and parses command-line arguments for training.\n", + " \"\"\"\n", + " parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", + " parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n", + " help='input batch size for training (default: 64)')\n", + " parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", + " help='input batch size for testing (default: 1000)')\n", + " parser.add_argument('--epochs', type=int, default=2, metavar='N',\n", + " help='number of epochs to train (default: 14)')\n", + " parser.add_argument('--lr', type=float, default=1.0, metavar='LR',\n", + " help='learning rate (default: 1.0)')\n", + " parser.add_argument('--gamma', type=float, default=0.7, metavar='M',\n", + " help='Learning rate step gamma (default: 0.7)')\n", + " parser.add_argument('--no-cuda', action='store_true', default=False,\n", + " help='disables CUDA training')\n", + " parser.add_argument('--no-mps', action='store_true', default=False,\n", + " help='disables macOS GPU training')\n", + " parser.add_argument('--dry-run', action='store_true', default=False,\n", + " help='quickly check a single pass')\n", + " parser.add_argument('--seed', type=int, default=1, metavar='S',\n", + " help='random seed (default: 1)')\n", + " parser.add_argument('--log-interval', type=int, default=50, metavar='N',\n", + " help='how many batches to wait before logging training status')\n", + " parser.add_argument('--save-model', action='store_true', default=False,\n", + " help='For Saving the current Model')\n", + " args = parser.parse_args('')\n", + "\n", + " use_cuda = torch.cuda.is_available() #not args.no_cuda and\n", + "\n", + " if use_cuda:\n", + " device = torch.device(\"cuda\")\n", + " else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + " args.use_cuda = use_cuda\n", + " args.device = device\n", + " return args\n", + "\n", + "def fetch_dataloaders(args):\n", + " \"\"\"\n", + " Fetches the data loaders for training and testing datasets.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Parsed arguments with training configuration.\n", + "\n", + " Outputs:\n", + " - train_loader (torch.utils.data.DataLoader): DataLoader for the training data.\n", + " - test_loader (torch.utils.data.DataLoader): DataLoader for the test data.\n", + " \"\"\"\n", + " train_kwargs = {'batch_size': args.batch_size}\n", + " test_kwargs = {'batch_size': args.test_batch_size}\n", + " if args.use_cuda:\n", + " cuda_kwargs = {'num_workers': 1,\n", + " 'pin_memory': True,\n", + " 'shuffle': True}\n", + " train_kwargs.update(cuda_kwargs)\n", + " test_kwargs.update(cuda_kwargs)\n", + "\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + " with contextlib.redirect_stdout(io.StringIO()): #to suppress output\n", + " dataset1 = datasets.MNIST('../data', train=True, download=True,\n", + " transform=transform)\n", + " dataset2 = datasets.MNIST('../data', train=False,\n", + " transform=transform)\n", + " train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)\n", + " test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n", + " return train_loader, test_loader\n", + "\n", + "def train_model(args, model, optimizer):\n", + " \"\"\"\n", + " Trains the model using the specified arguments and optimizer.\n", + "\n", + " Inputs:\n", + " - args (Namespace): Parsed arguments with training configuration.\n", + " - model (torch.nn.Module): The model to be trained.\n", + " - optimizer (torch.optim.Optimizer): Optimizer for updating the model parameters.\n", + "\n", + " Outputs:\n", + " - None: The function trains the model and optionally saves it.\n", + " \"\"\"\n", + " scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n", + " for epoch in range(1, args.epochs + 1):\n", + " train_one_epoch(args, model, args.device, train_loader, optimizer, epoch)\n", + " test(model, args.device, test_loader)\n", + " scheduler.step()\n", + "\n", + " if args.save_model:\n", + " torch.save(model.state_dict(), \"mnist_cnn.pt\")\n", + "\n", + "\n", + "def calc_rdms(model_features, method='correlation'):\n", + " \"\"\"\n", + " Calculates representational dissimilarity matrices (RDMs) for model features.\n", + "\n", + " Inputs:\n", + " - model_features (dict): A dictionary where keys are layer names and values are features of the layers.\n", + " - method (str): The method to calculate RDMs, e.g., 'correlation'. Default is 'correlation'.\n", + "\n", + " Outputs:\n", + " - rdms (pyrsa.rdm.RDMs): RDMs object containing dissimilarity matrices.\n", + " - rdms_dict (dict): A dictionary with layer names as keys and their corresponding RDMs as values.\n", + " \"\"\"\n", + " ds_list = []\n", + " for l in range(len(model_features)):\n", + " layer = list(model_features.keys())[l]\n", + " feats = model_features[layer]\n", + "\n", + " if type(feats) is list:\n", + " feats = feats[-1]\n", + "\n", + " if args.use_cuda:\n", + " feats = feats.cpu()\n", + "\n", + " if len(feats.shape) > 2:\n", + " feats = feats.flatten(1)\n", + "\n", + " feats = feats.detach().numpy()\n", + " ds = Dataset(feats, descriptors=dict(layer=layer))\n", + " ds_list.append(ds)\n", + "\n", + " rdms = calc_rdm(ds_list, method=method)\n", + " rdms_dict = {list(model_features.keys())[i]: rdms.get_matrices()[i] for i in range(len(model_features))}\n", + "\n", + " return rdms, rdms_dict\n", + "\n", + "def fgsm_attack(image, epsilon, data_grad):\n", + " \"\"\"\n", + " Performs FGSM attack on an image.\n", + "\n", + " Inputs:\n", + " - image (torch.Tensor): Original image.\n", + " - epsilon (float): Perturbation magnitude.\n", + " - data_grad (torch.Tensor): Gradient of the data.\n", + "\n", + " Outputs:\n", + " - perturbed_image (torch.Tensor): Perturbed image after FGSM attack.\n", + " \"\"\"\n", + " sign_data_grad = data_grad.sign()\n", + " perturbed_image = image + epsilon * sign_data_grad\n", + " perturbed_image = torch.clamp(perturbed_image, 0, 1)\n", + " return perturbed_image\n", + "\n", + "def denorm(batch, mean=[0.1307], std=[0.3081]):\n", + " \"\"\"\n", + " Converts a batch of normalized tensors to their original scale.\n", + "\n", + " Inputs:\n", + " - batch (torch.Tensor): Batch of normalized tensors.\n", + " - mean (torch.Tensor or list): Mean used for normalization.\n", + " - std (torch.Tensor or list): Standard deviation used for normalization.\n", + "\n", + " Outputs:\n", + " - torch.Tensor: Batch of tensors without normalization applied to them.\n", + " \"\"\"\n", + " if isinstance(mean, list):\n", + " mean = torch.tensor(mean).to(batch.device)\n", + " if isinstance(std, list):\n", + " std = torch.tensor(std).to(batch.device)\n", + "\n", + " return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)\n", + "\n", + "def generate_adversarial(model, imgs, targets, epsilon):\n", + " \"\"\"\n", + " Generates adversarial examples using FGSM attack.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to attack.\n", + " - imgs (torch.Tensor): Batch of images.\n", + " - targets (torch.Tensor): Batch of target labels.\n", + " - epsilon (float): Perturbation magnitude.\n", + "\n", + " Outputs:\n", + " - adv_imgs (torch.Tensor): Batch of adversarial images.\n", + " \"\"\"\n", + " adv_imgs = []\n", + "\n", + " for img, target in zip(imgs, targets):\n", + " img = img.unsqueeze(0)\n", + " target = target.unsqueeze(0)\n", + " img.requires_grad = True\n", + "\n", + " output = model(img)\n", + " output = torch.log(output)\n", + " loss = F.nll_loss(output, target)\n", + "\n", + " model.zero_grad()\n", + " loss.backward()\n", + "\n", + " data_grad = img.grad.data\n", + " data_denorm = denorm(img)\n", + " perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)\n", + " perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)\n", + "\n", + " adv_imgs.append(perturbed_data_normalized.detach())\n", + "\n", + " return torch.cat(adv_imgs)\n", + "\n", + "def test_adversarial(model, imgs, targets):\n", + " \"\"\"\n", + " Tests the model on adversarial examples and prints the accuracy.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model to be tested.\n", + " - imgs (torch.Tensor): Batch of adversarial images.\n", + " - targets (torch.Tensor): Batch of target labels.\n", + " \"\"\"\n", + " correct = 0\n", + " output = model(imgs)\n", + " output = torch.log(output)\n", + " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", + " correct += pred.eq(targets.view_as(pred)).sum().item()\n", + "\n", + " final_acc = correct / float(len(imgs))\n", + " print(f\"adversarial test accuracy = {correct} / {len(imgs)} = {final_acc}\")\n", + "\n", + "def extract_features(model, imgs, return_layers, plot='none'):\n", + " \"\"\"\n", + " Extracts features from specified layers of the model.\n", + "\n", + " Inputs:\n", + " - model (torch.nn.Module): The model from which to extract features.\n", + " - imgs (torch.Tensor): Batch of input images.\n", + " - return_layers (list): List of layer names from which to extract features.\n", + " - plot (str): Option to plot the features. Default is 'none'.\n", + "\n", + " Outputs:\n", + " - model_features (dict): A dictionary with layer names as keys and extracted features as values.\n", + " \"\"\"\n", + " if return_layers == 'all':\n", + " return_layers, _ = get_graph_node_names(model)\n", + " elif return_layers == 'layers':\n", + " layers, _ = get_graph_node_names(model)\n", + " return_layers = [l for l in layers if 'input' in l or 'conv' in l or 'fc' in l]\n", + "\n", + " feature_extractor = create_feature_extractor(model, return_nodes=return_layers)\n", + " model_features = feature_extractor(imgs)\n", + "\n", + " return model_features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be4a4946", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Plotting functions (Bonus)\n", + "\n", + "def sample_images(data_loader, n=5, plot=False):\n", + " \"\"\"\n", + " Samples a specified number of images from a data loader.\n", + "\n", + " Inputs:\n", + " - data_loader (torch.utils.data.DataLoader): Data loader containing images and labels.\n", + " - n (int): Number of images to sample per class.\n", + " - plot (bool): Whether to plot the sampled images using matplotlib.\n", + "\n", + " Outputs:\n", + " - imgs (torch.Tensor): Sampled images.\n", + " - labels (torch.Tensor): Corresponding labels for the sampled images.\n", + " \"\"\"\n", + "\n", + " with plt.xkcd():\n", + " imgs, targets = next(iter(data_loader))\n", + "\n", + " imgs_o = []\n", + " labels = []\n", + " for value in range(10):\n", + " cat_imgs = imgs[np.where(targets == value)][0:n]\n", + " imgs_o.append(cat_imgs)\n", + " labels.append([value]*len(cat_imgs))\n", + "\n", + " imgs = torch.cat(imgs_o, dim=0)\n", + " labels = torch.tensor(labels).flatten()\n", + "\n", + " if plot:\n", + " plt.imshow(torch.moveaxis(make_grid(imgs, nrow=5, padding=0, normalize=False, pad_value=0), 0,-1))\n", + " plt.axis('off')\n", + "\n", + " return imgs, labels\n", + "\n", + "\n", + "def plot_rdms(model_rdms):\n", + " \"\"\"\n", + " Plots the Representational Dissimilarity Matrices (RDMs) for each layer of a model.\n", + "\n", + " Inputs:\n", + " - model_rdms (dict): A dictionary where keys are layer names and values are the corresponding RDMs.\n", + " \"\"\"\n", + "\n", + " with plt.xkcd():\n", + " fig = plt.figure(figsize=(8, 4))\n", + " gs = fig.add_gridspec(1, len(model_rdms))\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " for l in range(len(model_rdms)):\n", + "\n", + " layer = list(model_rdms.keys())[l]\n", + " rdm = np.squeeze(model_rdms[layer])\n", + "\n", + " if len(rdm.shape) < 2:\n", + " rdm = rdm.reshape( (int(np.sqrt(rdm.shape[0])), int(np.sqrt(rdm.shape[0]))) )\n", + "\n", + " rdm = rdm / np.max(rdm)\n", + "\n", + " ax = plt.subplot(gs[0,l])\n", + " ax_ = ax.imshow(rdm, cmap='magma_r')\n", + " ax.set_title(f'{layer}')\n", + "\n", + " fig.subplots_adjust(right=0.9)\n", + " cbar_ax = fig.add_axes([1.01, 0.18, 0.01, 0.53])\n", + " cbar_ax.text(-2.3, 0.05, 'Normalized euclidean distance', size=10, rotation=90)\n", + " fig.colorbar(ax_, cax=cbar_ax)\n", + "\n", + " plt.show()\n", + "\n", + "def rep_path(model_features, model_colors, labels=None, rdm_calc_method='euclidean', rdm_comp_method='cosine'):\n", + " \"\"\"\n", + " Represents paths of model features in a reduced-dimensional space.\n", + "\n", + " Inputs:\n", + " - model_features (dict): Dictionary containing model features for each model.\n", + " - model_colors (dict): Dictionary mapping model names to colors for visualization.\n", + " - labels (array-like, optional): Array of labels corresponding to the model features.\n", + " - rdm_calc_method (str, optional): Method for calculating RDMS ('euclidean' or 'correlation').\n", + " - rdm_comp_method (str, optional): Method for comparing RDMS ('cosine' or 'corr').\n", + " \"\"\"\n", + " with plt.xkcd():\n", + " path_len = []\n", + " path_colors = []\n", + " rdms_list = []\n", + " ax_ticks = []\n", + " tick_colors = []\n", + " model_names = list(model_features.keys())\n", + " for m in range(len(model_names)):\n", + " model_name = model_names[m]\n", + " features = model_features[model_name]\n", + " path_colors.append(model_colors[model_name])\n", + " path_len.append(len(features))\n", + " ax_ticks.append(list(features.keys()))\n", + " tick_colors.append([model_colors[model_name]]*len(features))\n", + " rdms, _ = calc_rdms(features, method=rdm_calc_method)\n", + " rdms_list.append(rdms)\n", + "\n", + " path_len = np.insert(np.cumsum(path_len),0,0)\n", + "\n", + " if labels is not None:\n", + " rdms, _ = calc_rdms({'labels' : F.one_hot(labels).float().to(device)}, method=rdm_calc_method)\n", + " rdms_list.append(rdms)\n", + " ax_ticks.append(['labels'])\n", + " tick_colors.append(['m'])\n", + " idx_labels = -1\n", + "\n", + " rdms = rsatoolbox.rdm.concat(rdms_list)\n", + "\n", + " #Flatten the list\n", + " ax_ticks = [l for model_layers in ax_ticks for l in model_layers]\n", + " tick_colors = [l for model_layers in tick_colors for l in model_layers]\n", + " tick_colors = ['k' if tick == 'input' else color for tick, color in zip(ax_ticks, tick_colors)]\n", + "\n", + " rdms_comp = rsatoolbox.rdm.compare(rdms, rdms, method=rdm_comp_method)\n", + " if rdm_comp_method == 'cosine':\n", + " rdms_comp = np.arccos(rdms_comp)\n", + " rdms_comp = np.nan_to_num(rdms_comp, nan=0.0)\n", + "\n", + " # Symmetrize\n", + " rdms_comp = (rdms_comp + rdms_comp.T) / 2.0\n", + "\n", + " # reduce dim to 2\n", + " transformer = manifold.MDS(n_components = 2, max_iter=1000, n_init=10, normalized_stress='auto', dissimilarity=\"precomputed\")\n", + " dims= transformer.fit_transform(rdms_comp)\n", + "\n", + " # remove duplicates of the input layer from multiple models\n", + " remove_duplicates = np.where(np.array(ax_ticks) == 'input')[0][1:]\n", + " for index in remove_duplicates:\n", + " del ax_ticks[index]\n", + " del tick_colors[index]\n", + " rdms_comp = np.delete(np.delete(rdms_comp, index, axis=0), index, axis=1)\n", + "\n", + " fig = plt.figure(figsize=(8, 4))\n", + " gs = fig.add_gridspec(1, 2)\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " ax = plt.subplot(gs[0,0])\n", + " ax_ = ax.imshow(rdms_comp, cmap='viridis_r')\n", + " fig.subplots_adjust(left=0.2)\n", + " cbar_ax = fig.add_axes([-0.01, 0.2, 0.01, 0.5])\n", + " #cbar_ax.text(-7, 0.05, 'dissimilarity between rdms', size=10, rotation=90)\n", + " fig.colorbar(ax_, cax=cbar_ax,location='left')\n", + " ax.set_title('Dissimilarity between layer rdms', fontdict = {'fontsize': 14})\n", + " ax.set_xticks(np.arange(len(ax_ticks)), labels=ax_ticks, fontsize=7, rotation=83)\n", + " ax.set_yticks(np.arange(len(ax_ticks)), labels=ax_ticks, fontsize=7)\n", + " [t.set_color(i) for (i,t) in zip(tick_colors, ax.xaxis.get_ticklabels())]\n", + " [t.set_color(i) for (i,t) in zip(tick_colors, ax.yaxis.get_ticklabels())]\n", + "\n", + " ax = plt.subplot(gs[0,1])\n", + " amin, amax = dims.min(), dims.max()\n", + " amin, amax = (amin + amax) / 2 - (amax - amin) * 5/8, (amin + amax) / 2 + (amax - amin) * 5/8\n", + "\n", + " for i in range(len(rdms_list)-1):\n", + "\n", + " path_indices = np.arange(path_len[i], path_len[i+1])\n", + " ax.plot(dims[path_indices, 0], dims[path_indices, 1], color=path_colors[i], marker='.')\n", + " ax.set_title('Representational geometry path', fontdict = {'fontsize': 14})\n", + " ax.set_xlim([amin, amax])\n", + " ax.set_ylim([amin, amax])\n", + " ax.set_xlabel(f\"dim 1\")\n", + " ax.set_ylabel(f\"dim 2\")\n", + "\n", + " # if idx_input is not None:\n", + " idx_input = 0\n", + " ax.plot(dims[idx_input, 0], dims[idx_input, 1], color='k', marker='s')\n", + "\n", + " if labels is not None:\n", + " ax.plot(dims[idx_labels, 0], dims[idx_labels, 1], color='m', marker='*')\n", + "\n", + " ax.legend(model_names, fontsize=8)\n", + " fig.tight_layout()\n", + "\n", + "def plot_dim_reduction(model_features, labels, transformer_funcs):\n", + " \"\"\"\n", + " Plots the dimensionality reduction results for model features using various transformers.\n", + "\n", + " Inputs:\n", + " - model_features (dict): Dictionary containing model features for each layer.\n", + " - labels (array-like): Array of labels corresponding to the model features.\n", + " - transformer_funcs (list): List of dimensionality reduction techniques to apply ('PCA', 'MDS', 't-SNE').\n", + " \"\"\"\n", + " with plt.xkcd():\n", + "\n", + " transformers = []\n", + " for t in transformer_funcs:\n", + " if t == 'PCA': transformers.append(PCA(n_components=2))\n", + " if t == 'MDS': transformers.append(manifold.MDS(n_components = 2, normalized_stress='auto'))\n", + " if t == 't-SNE': transformers.append(manifold.TSNE(n_components = 2, perplexity=40, verbose=0))\n", + "\n", + " fig = plt.figure(figsize=(8, 2.5*len(transformers)))\n", + " # and we add one plot per reference point\n", + " gs = fig.add_gridspec(len(transformers), len(model_features))\n", + " fig.subplots_adjust(wspace=0.2, hspace=0.2)\n", + "\n", + " return_layers = list(model_features.keys())\n", + "\n", + " for f in range(len(transformer_funcs)):\n", + "\n", + " for l in range(len(return_layers)):\n", + " layer = return_layers[l]\n", + " feats = model_features[layer].detach().cpu().flatten(1)\n", + " feats_transformed= transformers[f].fit_transform(feats)\n", + "\n", + " amin, amax = feats_transformed.min(), feats_transformed.max()\n", + " amin, amax = (amin + amax) / 2 - (amax - amin) * 5/8, (amin + amax) / 2 + (amax - amin) * 5/8\n", + " ax = plt.subplot(gs[f,l])\n", + " ax.set_xlim([amin, amax])\n", + " ax.set_ylim([amin, amax])\n", + " ax.axis(\"off\")\n", + " #ax.set_title(f'{layer}')\n", + " if f == 0: ax.text(0.5, 1.12, f'{layer}', size=16, ha=\"center\", transform=ax.transAxes)\n", + " if l == 0: ax.text(-0.3, 0.5, transformer_funcs[f], size=16, ha=\"center\", transform=ax.transAxes)\n", + " # Create a discrete color map based on unique labels\n", + " num_colors = len(np.unique(labels))\n", + " cmap = plt.get_cmap('viridis_r', num_colors) # 10 discrete colors\n", + " norm = mpl.colors.BoundaryNorm(np.arange(-0.5,num_colors), cmap.N)\n", + " ax_ = ax.scatter(feats_transformed[:, 0], feats_transformed[:, 1], c=labels, cmap=cmap, norm=norm)\n", + "\n", + " fig.subplots_adjust(right=0.9)\n", + " cbar_ax = fig.add_axes([1.01, 0.18, 0.01, 0.53])\n", + " fig.colorbar(ax_, cax=cbar_ax, ticks=np.linspace(0,9,10))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21f68945", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Data retrieval\n", + "\n", + "import os\n", + "import requests\n", + "import hashlib\n", + "\n", + "# Variables for file and download URL\n", + "fnames = [\"standard_model.pth\", \"adversarial_model.pth\", \"recurrent_model.pth\"] # The names of the files to be downloaded\n", + "urls = [\"https://osf.io/s5rt6/download\", \"https://osf.io/qv5eb/download\", \"https://osf.io/6hnwk/download\"] # URLs from where the files will be downloaded\n", + "expected_md5s = [\"2e63c2cd77bc9f1fa67673d956ec910d\", \"25fb34497377921b54368317f68a7aa7\", \"ee5cea3baa264cb78300102fa6ed66e8\"] # MD5 hashes for verifying files integrity\n", + "\n", + "for fname, url, expected_md5 in zip(fnames, urls, expected_md5s):\n", + " if not os.path.isfile(fname):\n", + " try:\n", + " # Attempt to download the file\n", + " r = requests.get(url) # Make a GET request to the specified URL\n", + " except requests.ConnectionError:\n", + " # Handle connection errors during the download\n", + " print(\"!!! Failed to download data !!!\")\n", + " else:\n", + " # No connection errors, proceed to check the response\n", + " if r.status_code != requests.codes.ok:\n", + " # Check if the HTTP response status code indicates a successful download\n", + " print(\"!!! Failed to download data !!!\")\n", + " elif hashlib.md5(r.content).hexdigest() != expected_md5:\n", + " # Verify the integrity of the downloaded file using MD5 checksum\n", + " print(\"!!! Data download appears corrupted !!!\")\n", + " else:\n", + " # If download is successful and data is not corrupted, save the file\n", + " with open(fname, \"wb\") as fid:\n", + " fid.write(r.content) # Write the downloaded content to a file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93aeca0a", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Figure settings\n", + "\n", + "logging.getLogger('matplotlib.font_manager').disabled = True\n", + "\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'retina' # perfrom high definition rendering for images and plots\n", + "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd8052d5", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Set device (GPU or CPU)\n", + "\n", + "# inform the user if the notebook uses GPU or CPU.\n", + "\n", + "def set_device():\n", + " \"\"\"\n", + " Determines and sets the computational device for PyTorch operations based on the availability of a CUDA-capable GPU.\n", + "\n", + " Outputs:\n", + " - device (str): The device that PyTorch will use for computations ('cuda' or 'cpu'). This string can be directly used\n", + " in PyTorch operations to specify the device.\n", + " \"\"\"\n", + "\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + " if device != \"cuda\":\n", + " print(\"GPU is not enabled in this notebook. \\n\"\n", + " \"If you want to enable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `GPU` from the dropdown menu\")\n", + " else:\n", + " print(\"GPU is enabled in this notebook. \\n\"\n", + " \"If you want to disable it, in the menu under `Runtime` -> \\n\"\n", + " \"`Hardware accelerator.` and select `None` from the dropdown menu\")\n", + "\n", + " return device\n", + "\n", + "device = set_device()" ] }, { @@ -75,84 +2155,532 @@ "execution_count": null, "id": "c28a92e7-e76c-48de-b574-15a1272717cf", "metadata": { - "cellView": "form", + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Load Slides\n", + "\n", + "from IPython.display import IFrame\n", + "from ipywidgets import widgets\n", + "out = widgets.Output()\n", + "\n", + "link_id = \"8fx23\"\n", + "\n", + "with out:\n", + " print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", + " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", + "display(out)" + ] + }, + { + "cell_type": "markdown", + "id": "407ace26", + "metadata": { + "execution": {} + }, + "source": [ + "---\n", + "\n", + "# Intro\n", + "\n", + "Welcome to Tutorial 5 of Day 3 (W1D3) of the NeuroAI course. In this tutorial we are going to look at an exciting method that measures similarity from a slightly different perspective, a temporal one. The prior methods we have looked at were centeed around geometry and spatial representations, where we looked at metrics such as the Euclidean and Mahalanobis distance metrics. However, one thing we often want to study in neuroscience and in AI separately - is the temporal domain. Even more so in our own field of NeuroAI, we often deal with time series of neuronal / biological recordings. One thing you should already have a broad level of awareness of is that end structures can end up looking the same even though the paths taken to arrive at those end structures were very different.\n", + "\n", + "In NeuroAI, we're often confronted with systems that seem to have some sort of overlap and we want to study whether this implies there is a shared computation pairs up with the shared task (we looked at this in detail yesterday in our *Comparing Tasks* day). Today, we will begin by watching a short intro video by Mitchell Ostrow, who will describe his method to compare representations over temporal sequences (the method is called Dynamic Similarity Analysis). Then we are going to introduce three simple dynamical systems and we will explore them from the perspective of Dynamic Similarity Analysis and also describe the conceptual relationship to Representational Similarity Analysis. You will have a short coding exercise on the topic of temporal similarity analysis on three different types of trajectories. \n", + "\n", + "At the end of the tutorial, we will finally look at a further aspect of temporal sequences using RNNs. This is an adaptation of the ideas introduced in Tutorial 2 but now based around recurrent representations from RNNs. We hope you enjoy this tutorial today and that it gets you thinking not just what similarity values mean, but which ones are appropriate (here, from a spatial or temporal perspective). We aim to continually expand the tools necessary in the NeuroAI researcher's toolkit. Complementary tools, when applicable, can often tell a far richer story than just using a single method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5d6178f-ddf5-41ae-b676-15e452dc8b78", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Video 1: Dynamical Similarity Analysis\n", + "\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "video_ids = [('Youtube', 'FHikIsQFQvM'), ('Bilibili', 'BV1qm421g7hV')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2ce83bc-7e86-44d3-a40a-4ad46fd5a6df", + "metadata": { + "cellView": "form", + "execution": {} + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_DSA_video\")" + ] + }, + { + "cell_type": "markdown", + "id": "937041e9", + "metadata": { + "execution": {} + }, + "source": [ + "## Section 1: Visualization of Three Temporal Sequences\n", + "\n", + "We are going to be working with the analysis of three temporal sequences today:\n", + "\n", + "* The circular time series (`Circle`)\n", + "* The oval time series (`Oval`)\n", + "* The random walk (`R-Walk`)\n", + "\n", + "The random walk is going to be broadly *oval shaped*. Now, what do you think, from a geometric perspective, might result from a spatial analysis of these three different *representations*? You will probably assume because the random walk has an oval shape and there is also an oval time series (that's not a random walk) that these would result in a higher spatial similarity. You'd be right to assume this. However, what we're going to do with the `Circle` and `Oval` time series is to include an oscillator at a specific frequency, shared amongst these two time series. In effect, this means that although when plotted in totality the shapes are different, during the dynamic (temporal) evolution of these time series, a very similar shared pattern is emerging. We want methods that are sensitive to these changes to give higher scores for time series sharing similar temporal patterns (e.g. both containing oscillating patterns at similar frequences) rather than just a random walk that resembles (geometrically) one of the other shapes (`R-Walk`). Before we continue, we'll just define this random walk in a little more detail. A random walk at a specific location / timepoint takes a random step of fixed length in a specific direction, but this can be broadly controlled to resemble geometric shapes. We've taken a random walk and then reframed it to be similar in shape to `Oval`. \n", + "\n", + "Let's now visualize these three temporal sequences, to make the previous paragraph a little clearer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b57dfe1a", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# Circle\n", + "r = .1; # rotation\n", + "A = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])\n", + "B = np.array([[1, 0], [0, 1]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_circle = trajectory\n", + "\n", + "# Oval\n", + "r = .1; # rotation\n", + "s = 4; # scaling\n", + "S = np.array([[1, 0], [0, s]])\n", + "Si = np.array([[1, 0], [0, 1/s]])\n", + "V = np.array([[1, 1], [-1, 1]])/np.sqrt(2)\n", + "Vi = np.array([[1, -1], [1, 1]])/np.sqrt(2)\n", + "R = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]])\n", + "A = np.linalg.multi_dot([V,Si,R,S,Vi])\n", + "B = np.array([[1, 0], [0, 1]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_oval = trajectory\n", + "\n", + "# R-Walk (random walk)\n", + "r = .1; # rotation\n", + "A = np.array([[.9, 0], [0, .9]])\n", + "c = -.95; # correlation coefficient\n", + "B = np.array([[1, c], [0, np.sqrt(1-c*c)]])\n", + "\n", + "trajectory = generate_2d_random_process(A, B)\n", + "trajectory_walk = trajectory" + ] + }, + { + "cell_type": "markdown", + "id": "113a0dee", + "metadata": { + "execution": {} + }, + "source": [ + "Can you see how the spatial / geometric similarity of `R-Walk` and `Oval` are more similar, but the oscillations during the temporal sequence are shared between `Circle` and `Oval`? Let's run Dynamic Similarity Analysis on these temporal sequences and see what scores are returned.\n", + "\n", + "We calcularted `trajectory_oval` and `trajectory_circle` above, so let's plug these into the `DSA` function imported earlier (in the helper function cell) and see what the similarity score is." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3e36d59", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "# Define the DSA computation class\n", + "dsa = DSA(X=trajectory_oval, Y=trajectory_circle, n_delays=1)\n", + "\n", + "# Call the fit method and save the result\n", + "similarities_oval_circle = dsa.fit_score()\n", + "\n", + "print(f\"DSA similarity between Oval and Circle: {similarities_oval_circle:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9f1fb622", + "metadata": { + "execution": {} + }, + "source": [ + "## Multi-way Comparison\n", + "\n", + "We're now going to run DSA on our three trajectories and fit the model, returning the scores which we can investigate by plotting a confusion matrix with a heatmap to show the DSA scores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ee9e8e8", + "metadata": { "execution": {} }, "outputs": [], "source": [ - "# @title Bonus material slides\n", + "n_delays = 1\n", + "delay_interval = 1\n", "\n", - "from IPython.display import IFrame\n", - "from ipywidgets import widgets\n", - "out = widgets.Output()\n", + "models = [trajectory_circle, trajectory_oval, trajectory_walk]\n", + "dsa = DSA(models, n_delays=n_delays, delay_interval=delay_interval)\n", + "similarities = dsa.fit_score()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18318ddb", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "labels = ['Circle', 'Oval', 'Walk']\n", + "data = np.random.rand(len(labels), len(labels))\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");" + ] + }, + { + "cell_type": "markdown", + "id": "ffd49b4b", + "metadata": { + "execution": {} + }, + "source": [ + "This heatmap across the three model comparisons shows that the DSA scores between (`Walk` and `Circle`) and (`Walk` and `Oval`) to be (relatively) high, while the comparison between (`Circle` and `Oval`) is very low. Please note that this confusion matrix is symmetrical, meaning that the analysis between `trajectory_A` and `trajectory_B` returns the same dynamic similarity score as `trajectory_B` and `trajectory_A`. This is a common feature we have also seen in comparison metrics in standard RSA. One thing to note in the calculation of DSA is that comparisons among identical trajectories is `0`. This is unlike in RSA where we expect the correlation among the same stimuli to be `1.0`. This is why we see black squares along the diagonal.\n", "\n", - "link_id = \"8fx23\"\n", + "Let's put our thinking caps on for a moment: This isn't really the result we would have expected, right? What do you think might be going on here? Have a look back at the *hyperparameters* and try to make an educated guess!" + ] + }, + { + "cell_type": "markdown", + "id": "d0ff5faa", + "metadata": { + "execution": {} + }, + "source": [ + "## DSA Hyperparameters (`n_delays` and `delay_interval`)\n", "\n", - "with out:\n", - " print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", - " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", - "display(out)" + "We'll now give you a hint as to why the setting of these hyperparameters is important when considering dynamic similarity analysis. The oscillators we have placed in the trajectories of `Circle` and `Oval` are not immediately apparent if you study only the previous time step for each element. It's only when considering the recurring pattern across a few different temporal delays and at what delay interval you want those to be, that we would expect to be able to detect recurring oscillations that provide us with the information we need to conclude that `Oval` and `Circle` are actually *dynamically* similar.\n", + "\n", + "You should change the values below to be more sensible hyperparameter settings and re-run the model and plot the new confusion matrix. Try using `n_delays` equal to `20` and `delay_interval` equal to `10`. Don't forget to define `models` (see above example if you get stuck)." ] }, { "cell_type": "code", "execution_count": null, - "id": "b5d6178f-ddf5-41ae-b676-15e452dc8b78", + "id": "9d8d4c03", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "#################################################\n", + "## TODO for students: fill in the missing parts ##\n", + "raise NotImplementedError(\"Student exercise\")\n", + "#################################################\n", + "\n", + "n_delays = ...\n", + "delay_interval = ...\n", + "\n", + "models = ...\n", + "dsa = DSA(...)\n", + "similarities = ...\n", + "\n", + "labels = ['Circle', 'Oval', 'Walk']\n", + "ax = sns.heatmap(similarities, xticklabels=labels, yticklabels=labels)\n", + "cbar = ax.collections[0].colorbar\n", + "cbar.ax.set_ylabel('DSA Score');\n", + "plt.title(\"Dynamic Similarity Analysis Score among Trajectories\");" + ] + }, + { + "cell_type": "markdown", + "id": "a6377c65", + "metadata": { + "colab_type": "text", + "execution": {} + }, + "source": [ + "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/solutions/W1D3_Tutorial5_Solution_0467919d.py)\n", + "\n", + "*Example output:*\n", + "\n", + "Solution hint\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "04b0e32f", + "metadata": { + "execution": {} + }, + "source": [ + "What do you see now? We now see a much more sensible result. The DSA scores have now correctly identified that `Oval` and `Circle` are very dynamically similar! They have the highest color score according to the colorbar on the side. As is always good practice in science, let's have a look inside the `similarities` variable to look at the exact values and confirm what we see in the figure above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55fa4065", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "similarities" + ] + }, + { + "cell_type": "markdown", + "id": "59cb799f", + "metadata": { + "execution": {} + }, + "source": [ + "## Comparison with RSA\n", + "\n", + "At the start of this exercise, we saw three different trajectories and pointed out that the random walk and oval shapes were most similar from a geometric perspective, both ellipse-like but not similar in their dynamic similarity. To better show the difference between DSA and RSA, we encourage you to run another comparison where we consider each time step to be a pair in the X,Y space and we will look at the the similarity between of `Oval` with both `Circle` and `Walk`. If our understanding is correct, then RSA should indicate a higher geometric similarity between (`Oval` and `Walk`) than with (`Oval` and `Circle`)." + ] + }, + { + "cell_type": "markdown", + "id": "87cf4e6e", + "metadata": { + "execution": {} + }, + "source": [ + "---\n", + "# (Bonus) Representational Geometry of Recurrent Models\n", + "\n", + "Transformations of representations can occur across space and time, e.g., layers of a neural network and steps of recurrent computation. We've looked at the temporal dimension today and earlier today in the other tutorials we looked mainly at spatial representations.\n", + "\n", + "Just as the layers in a feedforward DNN can change the representational geometry to perform a task, steps in a recurrent network can reuse the same layer to reach the same computational depth.\n", + "\n", + "In this section, we look at a very simple recurrent network with only 2650 trainable parameters." + ] + }, + { + "cell_type": "markdown", + "id": "3d613edd", + "metadata": { + "execution": {} + }, + "source": [ + "Here is a diagram of this network:\n", + "\n", + "![Recurrent convolutional neural network](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/rcnn_tutorial.png?raw=true)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f0443d3", "metadata": { "cellView": "form", "execution": {} }, "outputs": [], "source": [ - "# @title Video 1: Dynamical Similarity Analysis\n", + "# @title Grab a recurrent model\n", "\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", + "args = build_args()\n", + "train_loader, test_loader = fetch_dataloaders(args)\n", + "path = \"recurrent_model.pth\"\n", + "model_recurrent = torch.load(path, map_location=args.device, weights_only=False)" + ] + }, + { + "cell_type": "markdown", + "id": "d463c3a9", + "metadata": { + "execution": {} + }, + "source": [ + "
We can first look at the computational steps in this network. As we see below, the `conv2` operation is repeated for 5 times." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bfabacd", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "train_nodes, _ = get_graph_node_names(model_recurrent)\n", + "print('The computational steps in the network are: \\n', train_nodes)" + ] + }, + { + "cell_type": "markdown", + "id": "1d410c3a", + "metadata": { + "execution": {} + }, + "source": [ + "Plotting the RDMs after each application of the `conv2` operation shows the same progressive emergence of the blockwise structure around the diagonal, mediating the correct classification in this task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30249608", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "imgs, labels = sample_images(test_loader, n=20)\n", + "return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']\n", + "model_features = extract_features(model_recurrent, imgs.to(device), return_layers)\n", "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "rdms, rdms_dict = calc_rdms(model_features)\n", + "plot_rdms(rdms_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "248329c3", + "metadata": { + "execution": {} + }, + "source": [ + "We can also look at how the different dimensionality reduction techniques capture the dynamics of changing geometry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b0e2cdf", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']\n", "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", + "imgs, labels = sample_images(test_loader, n=50) #grab 500 samples from the test set\n", + "model_features = extract_features(model_recurrent, imgs.to(device), return_layers)\n", "\n", - "video_ids = [('Youtube', 'FHikIsQFQvM'), ('Bilibili', 'BV1qm421g7hV')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" + "plot_dim_reduction(model_features, labels, transformer_funcs =['PCA', 'MDS', 't-SNE'])" + ] + }, + { + "cell_type": "markdown", + "id": "1aaf5f4a", + "metadata": { + "execution": {} + }, + "source": [ + "## Representational geometry paths for recurrent models\n", + "\n", + "We can look at the model's recurrent computational steps as a path in the representational geometry space." ] }, { "cell_type": "code", "execution_count": null, - "id": "d2ce83bc-7e86-44d3-a40a-4ad46fd5a6df", + "id": "7f88274a", + "metadata": { + "execution": {} + }, + "outputs": [], + "source": [ + "imgs, labels = sample_images(test_loader, n=50) #grab 500 samples from the test set\n", + "model_features_recurrent = extract_features(model_recurrent, imgs.to(device), return_layers='all')\n", + "\n", + "#rdms, rdms_dict = calc_rdms(model_features)\n", + "features = {'recurrent model': model_features_recurrent}\n", + "model_colors = {'recurrent model': 'y'}\n", + "\n", + "rep_path(features, model_colors, labels)" + ] + }, + { + "cell_type": "markdown", + "id": "5c3fbd44", + "metadata": { + "execution": {} + }, + "source": [ + "We can also look at the paths taken by the feedforward and the recurrent models and compare them." + ] + }, + { + "cell_type": "markdown", + "id": "b25a8cc6", + "metadata": { + "execution": {} + }, + "source": [ + "If you recall back to Tutorial 2, we compared a standard feedward model's representations. We can extend our analysis of the analysis of the recurrent model's representations by making a side-by-side comparison. We can also look at the paths taken by the feedforward and the recurrent models and compare them. What we see above in the case of the recurrent model is the fast-shifting path through the geometric space from inputs to labels. This illustration serves to show that models take many different paths and can have very diverse underlying mechanisms but still arrive at a superficially similar output at the end of training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c904e840", "metadata": { "cellView": "form", "execution": {} @@ -160,7 +2688,19 @@ "outputs": [], "source": [ "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_DSA_video\")" + "content_review(f\"{feedback_prefix}_recurrent_models\")" + ] + }, + { + "cell_type": "markdown", + "id": "3ed56061", + "metadata": { + "execution": {} + }, + "source": [ + "# The Big Picture\n", + "\n", + "Today, you've looked at what it means to measure representations from different systems. These systems can be of the same type (multiple brain systems, multiple artificial models) as well as with representations between these systems. In NeuroAI, we're especially interested in such comparisons, comparing representational systems in deep learning networks, for instance, to brain recordings recorded while those biological systems experienced / perceived the same set of stimuli. Comparisons can be geometric / spatial or they can be temporal. Today, we looked at Dynamic Similarity Analysis, a method used to be able to capture the dependencies among trajectories, not just capturing the similarity of the full temporal sequence upon completion of the temporal sequence. It's often important to take into account multiple dimensions of representational similarity. A combination of tools is definitely required in the NeuroAI researcher's toolkit. We hope you have many chances to use these tools in your future work as NeuroAI researchers." ] } ], @@ -191,7 +2731,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.9.22" } }, "nbformat": 4,