15
15
import numpy as np
16
16
17
17
from pandas import DataFrame , Series
18
+ from scipy .special import expit
18
19
19
20
from pymc3 .distributions .distribution import NoDistribution
20
21
from pymc3 .distributions .tree import LeafNode , SplitNode , Tree
@@ -30,7 +31,6 @@ def __init__(
30
31
m = 200 ,
31
32
alpha = 0.25 ,
32
33
split_prior = None ,
33
- scale = None ,
34
34
inv_link = None ,
35
35
jitter = False ,
36
36
* args ,
@@ -63,22 +63,32 @@ def __init__(
63
63
)
64
64
self .m = m
65
65
self .alpha = alpha
66
- self .y_std = Y .std ()
67
-
68
- if scale is None :
69
- self .leaf_scale = NormalSampler (sigma = None )
70
- elif isinstance (scale , (int , float )):
71
- self .leaf_scale = NormalSampler (sigma = Y .std () / self .m ** scale )
72
66
73
67
if inv_link is None :
74
- self .inv_link = lambda x : x
68
+ self .inv_link = self .link = lambda x : x
69
+ elif isinstance (inv_link , str ):
70
+ # The link function is just a rough approximation in order to allow the PGBART sampler
71
+ # to propose reasonable values for the leaf nodes.
72
+ if inv_link == "logistic" :
73
+ self .inv_link = expit
74
+ self .link = lambda x : (x - 0.5 ) * 10
75
+ elif inv_link == "exp" :
76
+ self .inv_link = np .exp
77
+ self .link = np .log
78
+ self .Y [self .Y == 0 ] += 0.0001
79
+ else :
80
+ raise ValueError ("Accepted strings are 'logistic' or 'exp'" )
75
81
else :
76
- self .inv_link = inv_link
82
+ self .inv_link , self .link = inv_link
83
+
84
+ self .init_mean = self .link (self .Y .mean ())
85
+ self .Y_un = self .link (self .Y )
77
86
78
87
self .num_observations = X .shape [0 ]
79
88
self .num_variates = X .shape [1 ]
80
89
self .available_predictors = list (range (self .num_variates ))
81
90
self .ssv = SampleSplittingVariable (split_prior , self .num_variates )
91
+ self .initial_value_leaf_nodes = self .init_mean / self .m
82
92
self .trees = self .init_list_of_trees ()
83
93
self .all_trees = []
84
94
self .mean = fast_mean ()
@@ -96,7 +106,7 @@ def preprocess_XY(self, X, Y):
96
106
return X , Y , missing_data
97
107
98
108
def init_list_of_trees (self ):
99
- initial_value_leaf_nodes = self .Y . mean () / self . m
109
+ initial_value_leaf_nodes = self .initial_value_leaf_nodes
100
110
initial_idx_data_points_leaf_nodes = np .array (range (self .num_observations ), dtype = "int32" )
101
111
list_of_trees = []
102
112
for i in range (self .m ):
@@ -110,7 +120,7 @@ def init_list_of_trees(self):
110
120
# bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
111
121
# The sum_trees_output will contain the sum of the predicted output for all trees.
112
122
# When R_j is needed we subtract the current predicted output for tree T_j.
113
- self .sum_trees_output = np .full_like (self .Y , self .Y . mean () )
123
+ self .sum_trees_output = np .full_like (self .Y , self .init_mean )
114
124
115
125
return list_of_trees
116
126
@@ -181,14 +191,13 @@ def get_new_idx_data_points(self, current_split_node, idx_data_points):
181
191
182
192
def get_residuals (self ):
183
193
"""Compute the residuals."""
184
- R_j = self .Y - self .inv_link (self .sum_trees_output )
185
-
194
+ R_j = self .Y_un - self .sum_trees_output
186
195
return R_j
187
196
188
197
def draw_leaf_value (self , idx_data_points ):
189
198
"""Draw the residual mean."""
190
199
R_j = self .get_residuals ()[idx_data_points ]
191
- draw = self .mean (R_j ) + self . leaf_scale . random ()
200
+ draw = self .mean (R_j )
192
201
return draw
193
202
194
203
def predict (self , X_new ):
@@ -278,24 +287,6 @@ def rvs(self):
278
287
return i
279
288
280
289
281
- class NormalSampler :
282
- def __init__ (self , sigma ):
283
- self .size = 5000
284
- self .cache = []
285
- self .sigma = sigma
286
-
287
- def random (self ):
288
- if self .sigma is None :
289
- return 0
290
- else :
291
- if not self .cache :
292
- self .update ()
293
- return self .cache .pop ()
294
-
295
- def update (self ):
296
- self .cache = np .random .normal (loc = 0.0 , scale = self .sigma , size = self .size ).tolist ()
297
-
298
-
299
290
class BART (BaseBART ):
300
291
"""
301
292
BART distribution.
@@ -317,23 +308,17 @@ class BART(BaseBART):
317
308
Each element of split_prior should be in the [0, 1] interval and the elements should sum
318
309
to 1. Otherwise they will be normalized.
319
310
Defaults to None, i.e. all variables have the same prior probability.
320
- scale : float
321
- Controls the variance of the proposed leaf value. The leaf values are computed as a
322
- Gaussian with mean equal to the conditional residual mean and variance proportional to
323
- the variance of the response variable, and inversely proportional to the number of trees
324
- and the scale parameter. Defaults to None, i.e the variance is 0.
325
- inv_link : numpy function
326
- Inverse link function defaults to None, i.e. the identity function.
311
+ inv_link : str or tuple of functions
312
+ Inverse link function. Defaults to None, i.e. the identity function. Accepted strings are
313
+ ``logistic`` or ``exp``. A tuple must be given as ``(inv_link, link)``.
327
314
jitter : bool
328
315
Whether to jitter the X values or not. Defaults to False. When values of X are repeated,
329
316
jittering X has the effect of increasing the number of effective splitting variables,
330
317
otherwise it does not have any effect.
331
318
"""
332
319
333
- def __init__ (
334
- self , X , Y , m = 200 , alpha = 0.25 , split_prior = None , scale = None , inv_link = None , jitter = False
335
- ):
336
- super ().__init__ (X , Y , m , alpha , split_prior , scale , inv_link )
320
+ def __init__ (self , X , Y , m = 200 , alpha = 0.25 , split_prior = None , inv_link = None , jitter = False ):
321
+ super ().__init__ (X , Y , m , alpha , split_prior , inv_link )
337
322
338
323
def _str_repr (self , name = None , dist = None , formatting = "plain" ):
339
324
if dist is None :
0 commit comments