Skip to content

Commit 69569ce

Browse files
committed
basic truncated normal dist
1 parent 9f3d9aa commit 69569ce

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

paramnormal/dist.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,39 @@ def fit(cls, data, **guesses):
904904
return cls.param_template(R=b*sigma, loc=loc, sigma=sigma)
905905

906906

907+
class truncated_normal(BaseDist_Mixin):
908+
dist = stats.truncnorm
909+
param_template = namedtuple('params', ['lower', 'upper', 'mu', 'sigma'])
910+
name = 'truncated normal'
911+
912+
@staticmethod
913+
@utils.greco_deco
914+
def _process_args(lower=None, upper=None, mu=None, sigma=None, fit=False):
915+
a = None
916+
b = None
917+
if lower is not None and mu is not None and sigma is not None:
918+
a = (lower - mu) / sigma
919+
920+
if upper is not None and mu is not None and sigma is not None:
921+
b = (upper - mu) / sigma
922+
923+
loc_key, scale_key = utils._get_loc_scale_keys(fit=fit)
924+
if fit:
925+
akey = 'f0'
926+
bkey = 'f1'
927+
else:
928+
akey = 'a'
929+
bkey = 'b'
930+
return {akey: a, bkey: b, loc_key: mu, scale_key: sigma}
931+
932+
@classmethod
933+
def fit(cls, data, **guesses):
934+
a, b, mu, sigma = cls._fit(data, **guesses)
935+
lower = a * sigma + mu
936+
upper = b * sigma + mu
937+
return cls.param_template(lower=lower, upper=upper, mu=mu, sigma=sigma)
938+
939+
907940
__all__ = [
908941
'normal',
909942
'lognormal',
@@ -915,4 +948,5 @@ def fit(cls, data, **guesses):
915948
'pareto',
916949
'exponential',
917950
'rice',
951+
'truncated_normal',
918952
]

0 commit comments

Comments
 (0)