Skip to content

Commit 8cb87fe

Browse files
authored
Robust generalized bart inference (#4709)
* robust generalized bart inference * update docstring * update test * clarify role of link function * revert few changes * raise a ValueError if inv_link string is not valid * update release notes
1 parent fc31e00 commit 8cb87fe

File tree

4 files changed

+61
-74
lines changed

4 files changed

+61
-74
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
+ Fix bug in the computation of the log pseudolikelihood values (SMC-ABC). (see [#4672](https://github.com/pymc-devs/pymc3/pull/4672)).
88

99
### New Features
10-
+ BART with non-gaussian likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675)).
10+
+ BART with non-gaussian likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675) and [#4709](https://github.com/pymc-devs/pymc3/pull/4709)).
1111

1212
## PyMC3 3.11.2 (14 March 2021)
1313

pymc3/distributions/bart.py

Lines changed: 28 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616

1717
from pandas import DataFrame, Series
18+
from scipy.special import expit
1819

1920
from pymc3.distributions.distribution import NoDistribution
2021
from pymc3.distributions.tree import LeafNode, SplitNode, Tree
@@ -30,7 +31,6 @@ def __init__(
3031
m=200,
3132
alpha=0.25,
3233
split_prior=None,
33-
scale=None,
3434
inv_link=None,
3535
jitter=False,
3636
*args,
@@ -63,22 +63,32 @@ def __init__(
6363
)
6464
self.m = m
6565
self.alpha = alpha
66-
self.y_std = Y.std()
67-
68-
if scale is None:
69-
self.leaf_scale = NormalSampler(sigma=None)
70-
elif isinstance(scale, (int, float)):
71-
self.leaf_scale = NormalSampler(sigma=Y.std() / self.m ** scale)
7266

7367
if inv_link is None:
74-
self.inv_link = lambda x: x
68+
self.inv_link = self.link = lambda x: x
69+
elif isinstance(inv_link, str):
70+
# The link function is just a rough approximation in order to allow the PGBART sampler
71+
# to propose reasonable values for the leaf nodes.
72+
if inv_link == "logistic":
73+
self.inv_link = expit
74+
self.link = lambda x: (x - 0.5) * 10
75+
elif inv_link == "exp":
76+
self.inv_link = np.exp
77+
self.link = np.log
78+
self.Y[self.Y == 0] += 0.0001
79+
else:
80+
raise ValueError("Accepted strings are 'logistic' or 'exp'")
7581
else:
76-
self.inv_link = inv_link
82+
self.inv_link, self.link = inv_link
83+
84+
self.init_mean = self.link(self.Y.mean())
85+
self.Y_un = self.link(self.Y)
7786

7887
self.num_observations = X.shape[0]
7988
self.num_variates = X.shape[1]
8089
self.available_predictors = list(range(self.num_variates))
8190
self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
91+
self.initial_value_leaf_nodes = self.init_mean / self.m
8292
self.trees = self.init_list_of_trees()
8393
self.all_trees = []
8494
self.mean = fast_mean()
@@ -96,7 +106,7 @@ def preprocess_XY(self, X, Y):
96106
return X, Y, missing_data
97107

98108
def init_list_of_trees(self):
99-
initial_value_leaf_nodes = self.Y.mean() / self.m
109+
initial_value_leaf_nodes = self.initial_value_leaf_nodes
100110
initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
101111
list_of_trees = []
102112
for i in range(self.m):
@@ -110,7 +120,7 @@ def init_list_of_trees(self):
110120
# bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
111121
# The sum_trees_output will contain the sum of the predicted output for all trees.
112122
# When R_j is needed we subtract the current predicted output for tree T_j.
113-
self.sum_trees_output = np.full_like(self.Y, self.Y.mean())
123+
self.sum_trees_output = np.full_like(self.Y, self.init_mean)
114124

115125
return list_of_trees
116126

@@ -181,14 +191,13 @@ def get_new_idx_data_points(self, current_split_node, idx_data_points):
181191

182192
def get_residuals(self):
183193
"""Compute the residuals."""
184-
R_j = self.Y - self.inv_link(self.sum_trees_output)
185-
194+
R_j = self.Y_un - self.sum_trees_output
186195
return R_j
187196

188197
def draw_leaf_value(self, idx_data_points):
189198
"""Draw the residual mean."""
190199
R_j = self.get_residuals()[idx_data_points]
191-
draw = self.mean(R_j) + self.leaf_scale.random()
200+
draw = self.mean(R_j)
192201
return draw
193202

194203
def predict(self, X_new):
@@ -278,24 +287,6 @@ def rvs(self):
278287
return i
279288

280289

281-
class NormalSampler:
282-
def __init__(self, sigma):
283-
self.size = 5000
284-
self.cache = []
285-
self.sigma = sigma
286-
287-
def random(self):
288-
if self.sigma is None:
289-
return 0
290-
else:
291-
if not self.cache:
292-
self.update()
293-
return self.cache.pop()
294-
295-
def update(self):
296-
self.cache = np.random.normal(loc=0.0, scale=self.sigma, size=self.size).tolist()
297-
298-
299290
class BART(BaseBART):
300291
"""
301292
BART distribution.
@@ -317,23 +308,17 @@ class BART(BaseBART):
317308
Each element of split_prior should be in the [0, 1] interval and the elements should sum
318309
to 1. Otherwise they will be normalized.
319310
Defaults to None, all variable have the same a prior probability
320-
scale : float
321-
Controls the variance of the proposed leaf value. The leaf values are computed as a
322-
Gaussian with mean equal to the conditional residual mean and variance proportional to
323-
the variance of the response variable, and inversely proportional to the number of trees
324-
and the scale parameter. Defaults to None, i.e the variance is 0.
325-
inv_link : numpy function
326-
Inverse link function defaults to None, i.e. the identity function.
311+
inv_link : str or tuple of functions
312+
Inverse link function defaults to None, i.e. the identity function. Accepted strings are
313+
``logistic`` or ``exp``.
327314
jitter : bool
328315
Whether to jitter the X values or not. Defaults to False. When values of X are repeated,
329316
jittering X has the effect of increasing the number of effective spliting variables,
330317
otherwise it does not have any effect.
331318
"""
332319

333-
def __init__(
334-
self, X, Y, m=200, alpha=0.25, split_prior=None, scale=None, inv_link=None, jitter=False
335-
):
336-
super().__init__(X, Y, m, alpha, split_prior, scale, inv_link)
320+
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, inv_link=None, jitter=False):
321+
super().__init__(X, Y, m, alpha, split_prior, inv_link)
337322

338323
def _str_repr(self, name=None, dist=None, formatting="plain"):
339324
if dist is None:

pymc3/step_methods/pgbart.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,29 +75,29 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
7575
self.log_num_particles = np.log(num_particles)
7676
self.indices = list(range(1, num_particles))
7777
self.max_stages = max_stages
78-
self.old_trees_particles_list = []
79-
for i in range(self.bart.m):
80-
p = ParticleTree(self.bart.trees[i], self.bart.prior_prob_leaf_node)
81-
self.old_trees_particles_list.append(p)
8278

8379
shared = make_shared_replacements(vars, model)
8480
self.likelihood_logp = logp([model.datalogpt], vars, shared)
81+
self.init_leaf_nodes = self.bart.initial_value_leaf_nodes
82+
self.init_likelihood = self.likelihood_logp(self.bart.inv_link(self.bart.sum_trees_output))
83+
self.init_log_weight = self.init_likelihood - self.log_num_particles
84+
self.old_trees_particles_list = []
85+
for i in range(self.bart.m):
86+
p = ParticleTree(
87+
self.bart.trees[i],
88+
self.bart.prior_prob_leaf_node,
89+
self.init_log_weight,
90+
self.init_likelihood,
91+
)
92+
self.old_trees_particles_list.append(p)
8593
super().__init__(vars, shared)
8694

8795
def astep(self, _):
8896
bart = self.bart
97+
8998
inv_link = bart.inv_link
90-
num_observations = bart.num_observations
9199
variable_inclusion = np.zeros(bart.num_variates, dtype="int")
92100

93-
# For the tunning phase we restrict max_stages to a low number, otherwise it is almost sure
94-
# we will reach max_stages given that our first set of m trees is not good at all.
95-
# Can set max_stages as a function of the number of variables/dimensions? XXX
96-
if self.tune:
97-
max_stages = 5
98-
else:
99-
max_stages = self.max_stages
100-
101101
if self.idx == bart.m:
102102
self.idx = 0
103103

@@ -110,25 +110,28 @@ def astep(self, _):
110110
bart.sum_trees_output -= old_prediction
111111
# Generate an initial set of SMC particles
112112
# at the end of the algorithm we return one of these particles as the new tree
113-
particles = self.init_particles(tree.tree_id, num_observations, inv_link)
113+
particles = self.init_particles(tree.tree_id)
114114

115-
for t in range(1, max_stages):
115+
for t in range(1, self.max_stages):
116116
# Get old particle at stage t
117117
particles[0] = self.get_old_tree_particle(tree.tree_id, t)
118118
# sample each particle (try to grow each tree)
119119
for c in range(1, self.num_particles):
120120
particles[c].sample_tree_sequential(bart)
121121
# Update weights. Since the prior is used as the proposal,the weights
122122
# are updated additively as the ratio of the new and old log_likelihoods
123-
for p_idx, p in enumerate(particles):
124-
new_likelihood = self.likelihood_logp(inv_link(p.tree.predict_output()))
123+
for p in particles:
124+
new_likelihood = self.likelihood_logp(
125+
inv_link(bart.sum_trees_output + p.tree.predict_output())
126+
)
125127
p.log_weight += new_likelihood - p.old_likelihood_logp
126128
p.old_likelihood_logp = new_likelihood
127129

128130
# Normalize weights
129131
W, normalized_weights = self.normalize(particles)
130132
# Resample all but first particle
131133
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
134+
132135
new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
133136
particles[1:] = particles[new_indices]
134137

@@ -149,8 +152,7 @@ def astep(self, _):
149152
new_tree = np.random.choice(particles, p=normalized_weights)
150153
self.old_trees_particles_list[tree.tree_id] = new_tree
151154
bart.trees[idx] = new_tree.tree
152-
new_prediction = new_tree.tree.predict_output()
153-
bart.sum_trees_output += new_prediction
155+
bart.sum_trees_output += new_tree.tree.predict_output()
154156

155157
if not self.tune:
156158
self.iter += 1
@@ -194,26 +196,28 @@ def get_old_tree_particle(self, tree_id, t):
194196
old_tree_particle.set_particle_to_step(t)
195197
return old_tree_particle
196198

197-
def init_particles(self, tree_id, num_observations, inv_link):
199+
def init_particles(self, tree_id):
198200
"""
199201
Initialize particles
200202
"""
201-
# The first particle is from the tree we are trying to replace
202203
prev_tree = self.get_old_tree_particle(tree_id, 0)
203-
likelihood = self.likelihood_logp(inv_link(prev_tree.tree.predict_output()))
204+
likelihood = self.likelihood_logp(self.bart.inv_link(prev_tree.tree.predict_output()))
204205
prev_tree.old_likelihood_logp = likelihood
205206
prev_tree.log_weight = likelihood - self.log_num_particles
206207
particles = [prev_tree]
207208

208-
# The rest of the particles are identically initialized
209-
initial_idx_data_points_leaf_nodes = np.arange(num_observations, dtype="int32")
209+
initial_idx_data_points_leaf_nodes = np.arange(self.bart.num_observations, dtype="int32")
210210
new_tree = Tree.init_tree(
211211
tree_id=tree_id,
212-
leaf_node_value=0,
212+
leaf_node_value=self.init_leaf_nodes,
213213
idx_data_points=initial_idx_data_points_leaf_nodes,
214214
)
215-
for i in range(1, self.num_particles):
216-
particles.append(ParticleTree(new_tree, self.bart.prior_prob_leaf_node, 0, 0))
215+
216+
prior_prob = self.bart.prior_prob_leaf_node
217+
for _ in range(1, self.num_particles):
218+
particles.append(
219+
ParticleTree(new_tree, prior_prob, self.init_log_weight, self.init_likelihood)
220+
)
217221

218222
return np.array(particles)
219223

pymc3/tests/test_bart.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import numpy as np
22

3-
from scipy.special import expit
4-
53
import pymc3 as pm
64

75

@@ -106,7 +104,7 @@ def test_model():
106104

107105
Y = np.repeat([0, 1], 50)
108106
with pm.Model() as model:
109-
mu = pm.BART("mu", X, Y, m=50, inv_link=expit, scale=0.25)
107+
mu = pm.BART("mu", X, Y, m=50, inv_link="logistic")
110108
y = pm.Bernoulli("y", mu, observed=Y)
111109
trace = pm.sample(1000, random_seed=212480)
112110

0 commit comments

Comments
 (0)