Skip to content

Commit f3fdfe7

Browse files
committed
Some updates for NN notebooks
1 parent 38bec9a commit f3fdfe7

File tree

4 files changed

+293
-341
lines changed

4 files changed

+293
-341
lines changed

environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ dependencies:
3030
- pytables
3131
- pytest
3232
- python-graphviz
33-
- pytorch=2.1
34-
- pytorch-cpu=2.1
33+
- pytorch=2.6
34+
- pytorch-cpu=2.6
3535
- ruff
3636
- scikit-image
3737
- scikit-learn=1.5

ml/plots.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
rng = np.random.default_rng(0)
1313

1414

15-
colors = ['xkcd:sky', 'xkcd:grass']
15+
colors = ["xkcd:sky", "xkcd:grass"]
1616
cmap = ListedColormap(colors)
1717

18+
1819
def create_discrete_colormap(n_classes):
1920
if n_classes == 2:
2021
return cmap.copy()
21-
return ListedColormap([f'C{i}' for i in range(n_classes)])
22+
return ListedColormap([f"C{i}" for i in range(n_classes)])
2223

2324

2425
def set_plot_style():
@@ -34,11 +35,11 @@ def set_plot_style():
3435

3536
def twospirals(n_samples, noise=0.5, rng=rng):
3637
"""
37-
Returns the two spirals dataset.
38+
Returns the two spirals dataset.
3839
"""
3940
n = np.sqrt(rng.uniform(size=(n_samples, 1))) * 360 * (2 * np.pi) / 360
40-
d1x = -np.cos(n) * n + rng.uniform((n_samples, 1)) * noise
41-
d1y = np.sin(n) * n + rng.uniform((n_samples, 1)) * noise
41+
d1x = -np.cos(n) * n + rng.uniform(size=(n_samples, 1)) * noise
42+
d1y = np.sin(n) * n + rng.uniform(size=(n_samples, 1)) * noise
4243
return (
4344
np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))),
4445
np.hstack((np.zeros(n_samples), np.ones(n_samples))),
@@ -67,7 +68,9 @@ def draw_linear_regression_function(reg, ax=None, **kwargs):
6768
def plot_3d_views(X, y, cmap=cmap):
6869
from mpl_toolkits.mplot3d import Axes3D # noqa
6970

70-
fig, axs = plt.subplots(2, 2, subplot_kw={'projection': '3d'}, constrained_layout=False)
71+
fig, axs = plt.subplots(
72+
2, 2, subplot_kw={"projection": "3d"}, constrained_layout=False
73+
)
7174

7275
for ax in axs.ravel():
7376
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap=cmap, lw=0)
@@ -83,6 +86,7 @@ def plot_3d_views(X, y, cmap=cmap):
8386
axs[1, 1].view_init(90, 0)
8487
fig.subplots_adjust(wspace=0.005, hspace=0.005)
8588

89+
8690
def draw_tree(clf):
8791
import pydotplus
8892

@@ -176,7 +180,7 @@ def plot_bars_and_confusion(
176180
axes=None,
177181
vmin=None,
178182
vmax=None,
179-
cmap='inferno',
183+
cmap="inferno",
180184
title=None,
181185
bar_color=None,
182186
):
@@ -189,7 +193,7 @@ def plot_bars_and_confusion(
189193
if not isinstance(prediction, pd.Series):
190194
prediction = pd.Series(prediction)
191195

192-
correct = pd.Series(np.where(truth.values == prediction.values, 'Correct', 'Wrong'))
196+
correct = pd.Series(np.where(truth.values == prediction.values, "Correct", "Wrong"))
193197

194198
truth.sort_index(inplace=True)
195199
prediction.sort_index(inplace=True)

0 commit comments

Comments
 (0)