7
7
@app .cell
8
8
def _ ():
9
9
import marimo as mo
10
+
10
11
return (mo ,)
11
12
12
13
@@ -63,7 +64,7 @@ def ridge_regression_1d(
63
64
w , b = ridge_regression_1d (x , y , alpha = hparams ["alpha" ].value )
64
65
y_prediction = w * x_test + b
65
66
66
- mse = np .mean ((y_prediction - y_test )** 2 )
67
+ mse = np .mean ((y_prediction - y_test ) ** 2 )
67
68
return NDArray , b , mse , y_prediction
68
69
69
70
@@ -72,7 +73,11 @@ def _(hparams):
72
73
import numpy as np
73
74
74
75
x = np .random .rand (hparams ["N" ].value )
75
- y = - 2.0 * x ** 2 + 1.0 + hparams ["epsilon" ].value * np .random .randn (hparams ["N" ].value )
76
+ y = (
77
+ - 2.0 * x ** 2
78
+ + 1.0
79
+ + hparams ["epsilon" ].value * np .random .randn (hparams ["N" ].value )
80
+ )
76
81
77
82
x_test = np .linspace (0 , 1 , 1000 )
78
83
y_test = - 2.0 * x_test ** 2 + 1.0 + hparams ["epsilon" ].value * np .random .randn (1000 )
@@ -105,12 +110,20 @@ def _(mo, mse, x, x_test, y, y_prediction, y_test):
105
110
fig = go .Figure ()
106
111
fig .add_trace (
107
112
go .Scatter (
108
- x = x , y = y , mode = "markers" , name = "Training data" , marker = dict (color = "blue" , opacity = 1.0 )
113
+ x = x ,
114
+ y = y ,
115
+ mode = "markers" ,
116
+ name = "Training data" ,
117
+ marker = dict (color = "blue" , opacity = 1.0 ),
109
118
)
110
119
)
111
120
fig .add_trace (
112
121
go .Scatter (
113
- x = x_test , y = y_test , mode = "markers" , name = "Test data" , marker = dict (color = "green" , opacity = 0.1 )
122
+ x = x_test ,
123
+ y = y_test ,
124
+ mode = "markers" ,
125
+ name = "Test data" ,
126
+ marker = dict (color = "green" , opacity = 0.1 ),
114
127
)
115
128
)
116
129
fig .add_trace (
@@ -122,9 +135,7 @@ def _(mo, mse, x, x_test, y, y_prediction, y_test):
122
135
line = dict (color = "red" ),
123
136
)
124
137
)
125
- fig .update_layout (
126
- title = f"Test error: { mse :2.4f} "
127
- )
138
+ fig .update_layout (title = f"Test error: { mse :2.4f} " )
128
139
plot = mo .ui .plotly (fig )
129
140
plot
130
141
return (go ,)
@@ -146,7 +157,9 @@ def _(mo):
146
157
147
158
@app .cell (hide_code = True )
148
159
def _ (mo ):
149
- slider_lambda = mo .ui .slider (start = 0 , stop = 20 , step = 1.0 , label = r"Ridge strength $\alpha$" , show_value = True )
160
+ slider_lambda = mo .ui .slider (
161
+ start = 0 , stop = 20 , step = 1.0 , label = r"Ridge strength $\alpha$" , show_value = True
162
+ )
150
163
slider_lambda
151
164
return (slider_lambda ,)
152
165
@@ -156,32 +169,40 @@ def _(b, go, mo, np, slider_lambda, x, y):
156
169
w_grid = np .linspace (- 10.0 , 10.0 , 100 )
157
170
158
171
def loss_and_reg_fn (param , lmbda ):
159
- return np .sum ((y - param * x - b )** 2 ), lmbda * param ** 2
172
+ return np .sum ((y - param * x - b ) ** 2 ), lmbda * param ** 2
160
173
161
174
values = np .array ([loss_and_reg_fn (w_ , slider_lambda .value ) for w_ in w_grid ])
162
175
163
176
fig_w = go .Figure ()
164
177
fig_w .add_trace (
165
178
go .Scatter (
166
- x = w_grid , y = np .sum (values , axis = 1 ), mode = "lines" , name = r"$L(w) + R(w)$" , line = dict (color = "blue" )
179
+ x = w_grid ,
180
+ y = np .sum (values , axis = 1 ),
181
+ mode = "lines" ,
182
+ name = r"$L(w) + R(w)$" ,
183
+ line = dict (color = "blue" ),
167
184
)
168
185
)
169
186
fig_w .add_trace (
170
187
go .Scatter (
171
- x = w_grid , y = values [:, 0 ], mode = "lines" , name = r"$L(w)$" , line = dict (color = "black" )
188
+ x = w_grid ,
189
+ y = values [:, 0 ],
190
+ mode = "lines" ,
191
+ name = r"$L(w)$" ,
192
+ line = dict (color = "black" ),
172
193
)
173
194
)
174
195
fig_w .add_trace (
175
196
go .Scatter (
176
- x = w_grid , y = values [:, 1 ], mode = "lines" , name = r"$R(w)$" , line = dict (color = "red" )
197
+ x = w_grid ,
198
+ y = values [:, 1 ],
199
+ mode = "lines" ,
200
+ name = r"$R(w)$" ,
201
+ line = dict (color = "red" ),
177
202
)
178
203
)
179
204
fig_w .update_layout (
180
- xaxis = dict (
181
- title = dict (
182
- text = r"$w$"
183
- )
184
- ),
205
+ xaxis = dict (title = dict (text = r"$w$" )),
185
206
title = f"Ridge strength: { slider_lambda .value } " ,
186
207
width = 800 ,
187
208
height = 500 ,
0 commit comments