11import numpy as np
22from scipy .stats import kde , mode
3- from numpy .linalg import LinAlgError
43import matplotlib .pyplot as plt
54import pymc3 as pm
65from .stats import quantiles , hpd
@@ -122,26 +121,18 @@ def histplot_op(ax, data, alpha=.35):
122121
123122
124123def kdeplot_op (ax , data , prior = None , prior_alpha = 1 , prior_style = '--' ):
125- errored = []
126124 for i in range (data .shape [1 ]):
127125 d = data [:, i ]
128- try :
129- density , l , u = fast_kde (d )
130- x = np .linspace (l , u , len (density ))
131-
132- if prior is not None :
133- p = prior .logp (x ).eval ()
134- ax .plot (x , np .exp (p ), alpha = prior_alpha , ls = prior_style )
126+ density , l , u = fast_kde (d )
127+ x = np .linspace (l , u , len (density ))
135128
136- ax .plot (x , density )
129+ if prior is not None :
130+ p = prior .logp (x ).eval ()
131+ ax .plot (x , np .exp (p ), alpha = prior_alpha , ls = prior_style )
137132
138- except LinAlgError :
139- errored .append (i )
133+ ax .plot (x , density )
140134
141135 ax .set_ylim (ymin = 0 )
142- if errored :
143- ax .text (.27 , .47 , 'WARNING: KDE plot failed for: ' + str (errored ), style = 'italic' ,
144- bbox = {'facecolor' : 'red' , 'alpha' : 0.5 , 'pad' : 10 })
145136
146137
147138def make_2d (a ):
@@ -793,6 +784,7 @@ def get_trace_dict(tr, varnames):
793784
794785 fig .tight_layout ()
795786 return ax
787+
796788
797789def fast_kde (x ):
798790 """
@@ -813,14 +805,16 @@ def fast_kde(x):
813805 xmax: maximum value of x
814806
815807 """
808+ # add small jitter in case input values are the same
809+ x = np .random .normal (x , 1e-12 )
816810
817811 xmin , xmax = x .min (), x .max ()
818812
819813 n = len (x )
820814 nx = 256
821815
822816 # compute histogram
823- bins = np .linspace (x . min (), x . max () , nx )
817+ bins = np .linspace (xmin , xmax , nx )
824818 xyi = np .digitize (x , bins )
825819 dx = (xmax - xmin ) / (nx - 1 )
826820 grid = np .histogram (x , bins = nx )[0 ]
0 commit comments