@@ -904,6 +904,39 @@ def fit(cls, data, **guesses):
904
904
return cls .param_template (R = b * sigma , loc = loc , sigma = sigma )
905
905
906
906
907
+ class truncated_normal (BaseDist_Mixin ):
908
+ dist = stats .truncnorm
909
+ param_template = namedtuple ('params' , ['lower' , 'upper' , 'mu' , 'sigma' ])
910
+ name = 'truncated normal'
911
+
912
+ @staticmethod
913
+ @utils .greco_deco
914
+ def _process_args (lower = None , upper = None , mu = None , sigma = None , fit = False ):
915
+ a = None
916
+ b = None
917
+ if lower is not None and mu is not None and sigma is not None :
918
+ a = (lower - mu ) / sigma
919
+
920
+ if upper is not None and mu is not None and sigma is not None :
921
+ b = (upper - mu ) / sigma
922
+
923
+ loc_key , scale_key = utils ._get_loc_scale_keys (fit = fit )
924
+ if fit :
925
+ akey = 'f0'
926
+ bkey = 'f1'
927
+ else :
928
+ akey = 'a'
929
+ bkey = 'b'
930
+ return {akey : a , bkey : b , loc_key : mu , scale_key : sigma }
931
+
932
+ @classmethod
933
+ def fit (cls , data , ** guesses ):
934
+ a , b , mu , sigma = cls ._fit (data , ** guesses )
935
+ lower = a * sigma + mu
936
+ upper = b * sigma + mu
937
+ return cls .param_template (lower = lower , upper = upper , mu = mu , sigma = sigma )
938
+
939
+
907
940
__all__ = [
908
941
'normal' ,
909
942
'lognormal' ,
@@ -915,4 +948,5 @@ def fit(cls, data, **guesses):
915
948
'pareto' ,
916
949
'exponential' ,
917
950
'rice' ,
951
+ 'truncated_normal' ,
918
952
]
0 commit comments