1212rng = np .random .default_rng (0 )
1313
1414
15- colors = [' xkcd:sky' , ' xkcd:grass' ]
15+ colors = [" xkcd:sky" , " xkcd:grass" ]
1616cmap = ListedColormap (colors )
1717
18+
1819def create_discrete_colormap (n_classes ):
1920 if n_classes == 2 :
2021 return cmap .copy ()
21- return ListedColormap ([f' C{ i } ' for i in range (n_classes )])
22+ return ListedColormap ([f" C{ i } " for i in range (n_classes )])
2223
2324
2425def set_plot_style ():
@@ -34,11 +35,11 @@ def set_plot_style():
3435
3536def twospirals (n_samples , noise = 0.5 , rng = rng ):
3637 """
37- Returns the two spirals dataset.
38+ Returns the two spirals dataset.
3839 """
3940 n = np .sqrt (rng .uniform (size = (n_samples , 1 ))) * 360 * (2 * np .pi ) / 360
40- d1x = - np .cos (n ) * n + rng .uniform ((n_samples , 1 )) * noise
41- d1y = np .sin (n ) * n + rng .uniform ((n_samples , 1 )) * noise
41+ d1x = - np .cos (n ) * n + rng .uniform (size = (n_samples , 1 )) * noise
42+ d1y = np .sin (n ) * n + rng .uniform (size = (n_samples , 1 )) * noise
4243 return (
4344 np .vstack ((np .hstack ((d1x , d1y )), np .hstack ((- d1x , - d1y )))),
4445 np .hstack ((np .zeros (n_samples ), np .ones (n_samples ))),
@@ -67,7 +68,9 @@ def draw_linear_regression_function(reg, ax=None, **kwargs):
6768def plot_3d_views (X , y , cmap = cmap ):
6869 from mpl_toolkits .mplot3d import Axes3D # noqa
6970
70- fig , axs = plt .subplots (2 , 2 , subplot_kw = {'projection' : '3d' }, constrained_layout = False )
71+ fig , axs = plt .subplots (
72+ 2 , 2 , subplot_kw = {"projection" : "3d" }, constrained_layout = False
73+ )
7174
7275 for ax in axs .ravel ():
7376 ax .scatter (X [:, 0 ], X [:, 1 ], X [:, 2 ], c = y , cmap = cmap , lw = 0 )
@@ -83,6 +86,7 @@ def plot_3d_views(X, y, cmap=cmap):
8386 axs [1 , 1 ].view_init (90 , 0 )
8487 fig .subplots_adjust (wspace = 0.005 , hspace = 0.005 )
8588
89+
8690def draw_tree (clf ):
8791 import pydotplus
8892
@@ -176,7 +180,7 @@ def plot_bars_and_confusion(
176180 axes = None ,
177181 vmin = None ,
178182 vmax = None ,
179- cmap = ' inferno' ,
183+ cmap = " inferno" ,
180184 title = None ,
181185 bar_color = None ,
182186):
@@ -189,7 +193,7 @@ def plot_bars_and_confusion(
189193 if not isinstance (prediction , pd .Series ):
190194 prediction = pd .Series (prediction )
191195
192- correct = pd .Series (np .where (truth .values == prediction .values , ' Correct' , ' Wrong' ))
196+ correct = pd .Series (np .where (truth .values == prediction .values , " Correct" , " Wrong" ))
193197
194198 truth .sort_index (inplace = True )
195199 prediction .sort_index (inplace = True )
0 commit comments