Skip to content

Commit 6dbfd3d

Browse files
committed
Small fixes
1 parent af3ce37 commit 6dbfd3d

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

apps/optimizers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ def _(mo):
4242

4343
@app.cell(hide_code=True)
4444
def _(f_iterations, mo, np, x_grid, x_iterations, y_grid):
45+
import matplotlib
46+
import matplotlib.colors as mcolors
4547
import plotly.graph_objects as go
4648
from plotly.subplots import make_subplots
47-
import matplotlib.cm as cm
48-
import matplotlib.colors as mcolors
49-
import matplotlib
5049
norm = mcolors.Normalize(vmin=0, vmax=len(x_iterations) - 1)
5150
cmap = matplotlib.colormaps["jet"]
5251
colors = [mcolors.to_hex(cmap(norm(i))) for i in range(len(x_iterations))]
@@ -130,7 +129,6 @@ def grad_fn3(x: NDArray[np.float64]) -> NDArray[np.float64]:
130129

131130
@app.cell(hide_code=True)
132131
def _(NDArray, dropdown_dict, hparams, np):
133-
from typing import Callable
134132

135133
objective_fn = dropdown_dict.value[0]
136134
grad_fn = dropdown_dict.value[1]

apps/ridge_regression_tutorial.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
@app.cell
88
def _():
99
import marimo as mo
10+
1011
return (mo,)
1112

1213

@@ -63,7 +64,7 @@ def ridge_regression_1d(
6364
w, b = ridge_regression_1d(x, y, alpha=hparams["alpha"].value)
6465
y_prediction = w * x_test + b
6566

66-
mse = np.mean((y_prediction - y_test)**2)
67+
mse = np.mean((y_prediction - y_test) ** 2)
6768
return NDArray, b, mse, y_prediction
6869

6970

@@ -72,7 +73,11 @@ def _(hparams):
7273
import numpy as np
7374

7475
x = np.random.rand(hparams["N"].value)
75-
y = - 2.0 * x**2 + 1.0 + hparams["epsilon"].value * np.random.randn(hparams["N"].value)
76+
y = (
77+
-2.0 * x**2
78+
+ 1.0
79+
+ hparams["epsilon"].value * np.random.randn(hparams["N"].value)
80+
)
7681

7782
x_test = np.linspace(0, 1, 1000)
7883
y_test = -2.0 * x_test**2 + 1.0 + hparams["epsilon"].value * np.random.randn(1000)
@@ -105,12 +110,20 @@ def _(mo, mse, x, x_test, y, y_prediction, y_test):
105110
fig = go.Figure()
106111
fig.add_trace(
107112
go.Scatter(
108-
x=x, y=y, mode="markers", name="Training data", marker=dict(color="blue", opacity=1.0)
113+
x=x,
114+
y=y,
115+
mode="markers",
116+
name="Training data",
117+
marker=dict(color="blue", opacity=1.0),
109118
)
110119
)
111120
fig.add_trace(
112121
go.Scatter(
113-
x=x_test, y=y_test, mode="markers", name="Test data", marker=dict(color="green", opacity=0.1)
122+
x=x_test,
123+
y=y_test,
124+
mode="markers",
125+
name="Test data",
126+
marker=dict(color="green", opacity=0.1),
114127
)
115128
)
116129
fig.add_trace(
@@ -122,9 +135,7 @@ def _(mo, mse, x, x_test, y, y_prediction, y_test):
122135
line=dict(color="red"),
123136
)
124137
)
125-
fig.update_layout(
126-
title=f"Test error: {mse:2.4f}"
127-
)
138+
fig.update_layout(title=f"Test error: {mse:2.4f}")
128139
plot = mo.ui.plotly(fig)
129140
plot
130141
return (go,)
@@ -146,7 +157,9 @@ def _(mo):
146157

147158
@app.cell(hide_code=True)
148159
def _(mo):
149-
slider_lambda = mo.ui.slider(start=0, stop=20, step=1.0, label=r"Ridge strength $\alpha$", show_value=True)
160+
slider_lambda = mo.ui.slider(
161+
start=0, stop=20, step=1.0, label=r"Ridge strength $\alpha$", show_value=True
162+
)
150163
slider_lambda
151164
return (slider_lambda,)
152165

@@ -156,32 +169,40 @@ def _(b, go, mo, np, slider_lambda, x, y):
156169
w_grid = np.linspace(-10.0, 10.0, 100)
157170

158171
def loss_and_reg_fn(param, lmbda):
159-
return np.sum((y - param*x - b)**2), lmbda*param**2
172+
return np.sum((y - param * x - b) ** 2), lmbda * param**2
160173

161174
values = np.array([loss_and_reg_fn(w_, slider_lambda.value) for w_ in w_grid])
162175

163176
fig_w = go.Figure()
164177
fig_w.add_trace(
165178
go.Scatter(
166-
x=w_grid, y=np.sum(values, axis=1), mode="lines", name=r"$L(w) + R(w)$", line=dict(color="blue")
179+
x=w_grid,
180+
y=np.sum(values, axis=1),
181+
mode="lines",
182+
name=r"$L(w) + R(w)$",
183+
line=dict(color="blue"),
167184
)
168185
)
169186
fig_w.add_trace(
170187
go.Scatter(
171-
x=w_grid, y=values[:, 0], mode="lines", name=r"$L(w)$", line=dict(color="black")
188+
x=w_grid,
189+
y=values[:, 0],
190+
mode="lines",
191+
name=r"$L(w)$",
192+
line=dict(color="black"),
172193
)
173194
)
174195
fig_w.add_trace(
175196
go.Scatter(
176-
x=w_grid, y=values[:, 1], mode="lines", name=r"$R(w)$", line=dict(color="red")
197+
x=w_grid,
198+
y=values[:, 1],
199+
mode="lines",
200+
name=r"$R(w)$",
201+
line=dict(color="red"),
177202
)
178203
)
179204
fig_w.update_layout(
180-
xaxis=dict(
181-
title=dict(
182-
text=r"$w$"
183-
)
184-
),
205+
xaxis=dict(title=dict(text=r"$w$")),
185206
title=f"Ridge strength: {slider_lambda.value}",
186207
width=800,
187208
height=500,

0 commit comments

Comments
 (0)