|
1628 | 1628 | "\n", |
1629 | 1629 | " def forward(self, x):\n", |
1630 | 1630 | " b, num_tokens, d_in = x.shape # New batch dimension b\n", |
| 1631 | + " # For inputs where `num_tokens` exceeds `context_length`, this will result in errors\n", |
| 1632 | + " # in the mask creation further below.\n", |
| 1633 | + " # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs \n", |
| 1634 | + " # do not exceed `context_length` before reaching this forward method. \n", |
1631 | 1635 | " keys = self.W_key(x)\n", |
1632 | 1636 | " queries = self.W_query(x)\n", |
1633 | 1637 | " values = self.W_value(x)\n", |
|
1837 | 1841 | "\n", |
1838 | 1842 | " def forward(self, x):\n", |
1839 | 1843 | " b, num_tokens, d_in = x.shape\n", |
| 1844 | + " # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, \n", |
| 1845 | + " # this will result in errors in the mask creation further below. \n", |
| 1846 | + " # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs \n", |
| 1847 | + " # do not exceed `context_length` before reaching this forwar\n", |
1840 | 1848 | "\n", |
1841 | 1849 | " keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n", |
1842 | 1850 | " queries = self.W_query(x)\n", |
|
2029 | 2037 | "name": "python", |
2030 | 2038 | "nbconvert_exporter": "python", |
2031 | 2039 | "pygments_lexer": "ipython3", |
2032 | | - "version": "3.11.4" |
| 2040 | + "version": "3.10.16" |
2033 | 2041 | } |
2034 | 2042 | }, |
2035 | 2043 | "nbformat": 4, |
|
0 commit comments