1414
1515import logging
1616
17- from copy import copy
17+ from copy import copy , deepcopy
1818
1919import aesara
2020import numpy as np
@@ -39,12 +39,12 @@ class PGBART(ArrayStepShared):
3939 vars: list
4040 List of value variables for sampler
4141 num_particles : int
42- Number of particles for the conditional SMC sampler . Defaults to 40
42+ Number of particles. Defaults to 40
4343 max_stages : int
44- Maximum number of iterations of the conditional SMC sampler . Defaults to 100.
44+ Maximum number of iterations. Defaults to 100.
4545 batch : int or tuple
4646 Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
47- during tuning and after tuning. If a tuple is passed the first element is the batch size
47+ during and after tuning. If a tuple is passed the first element is the batch size
4848 during tuning and the second the batch size after tuning.
4949 model: PyMC Model
5050 Optional model for sampling step. Defaults to None (taken from context).
@@ -81,10 +81,10 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
8181 # if data is binary
8282 Y_unique = np .unique (self .Y )
8383 if Y_unique .size == 2 and np .all (Y_unique == [0 , 1 ]):
84- self .mu_std = 6 / (self .k * self .m ** 0.5 )
84+ self .mu_std = 3 / (self .k * self .m ** 0.5 )
8585 # maybe we need to check for count data
8686 else :
87- self .mu_std = ( 2 * self .Y .std () ) / (self .k * self .m ** 0.5 )
87+ self .mu_std = self .Y .std () / (self .k * self .m ** 0.5 )
8888
8989 self .num_observations = self .X .shape [0 ]
9090 self .num_variates = self .X .shape [1 ]
@@ -229,8 +229,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
229229 Initialize particles
230230 """
231231 p = self .all_particles [tree_id ]
232- particles = [p ]
233- particles .append (copy (p ))
232+ particles = [p , p .copy ()]
234233
235234 for _ in self .indices :
236235 particles .append (ParticleTree (self .a_tree ))
@@ -275,6 +274,9 @@ def __init__(self, tree):
275274 self .old_likelihood_logp = 0
276275 self .used_variates = []
277276
277+ def copy (self ):
278+ return deepcopy (self )
279+
278280 def sample_tree (
279281 self ,
280282 ssv ,
0 commit comments