Skip to content

Commit 3f0ec0c

Browse files
authored Jul 27, 2024
[ENH] Truncated Normal distribution (#421)
Interfaces truncated normal distribution from scipy.
1 parent 12fb8a4 commit 3f0ec0c

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed
 

‎docs/source/api_reference/distributions.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Continuous support - full reals
3939
Logistic
4040
Normal
4141
TDistribution
42+
TruncatedNormal
4243

4344

4445
Continuous support - non-negative reals
@@ -77,7 +78,6 @@ Integer support
7778
Binomial
7879
Poisson
7980

80-
8181
Non-parametric and empirical distributions
8282
------------------------------------------
8383

‎skpro/distributions/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"QPD_U",
3535
"QPD_Johnson",
3636
"TDistribution",
37+
"TruncatedNormal",
3738
"Uniform",
3839
"Weibull",
3940
]
@@ -65,5 +66,6 @@
6566
from skpro.distributions.qpd import QPD_B, QPD_S, QPD_U, QPD_Johnson
6667
from skpro.distributions.qpd_empirical import QPD_Empirical
6768
from skpro.distributions.t import TDistribution
69+
from skpro.distributions.truncated_normal import TruncatedNormal
6870
from skpro.distributions.uniform import Uniform
6971
from skpro.distributions.weibull import Weibull
+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
2+
"""Truncated Normal probability distribution."""
3+
4+
__author__ = ["ShreeshaM07"]
5+
6+
import pandas as pd
7+
from scipy.stats import rv_continuous, truncnorm
8+
9+
from skpro.distributions.adapters.scipy import _ScipyAdapter
10+
11+
12+
class TruncatedNormal(_ScipyAdapter):
    """Normal probability distribution truncated on both sides.

    Most methods wrap ``scipy.stats.truncnorm``.
    The underlying normal distribution is truncated at the
    abscissae ``l_trunc`` (left) and ``r_trunc`` (right).

    Note: the truncation points are passed on the original scale.
    Internally they are shifted by ``mu`` and divided by ``sigma``,
    i.e., ``(l_trunc - mu) / sigma`` and ``(r_trunc - mu) / sigma``,
    before being handed to scipy as its ``a`` and ``b`` arguments.

    Parameters
    ----------
    mu : float or array of float (1D or 2D)
        mean of the normal distribution
    sigma : float or array of float (1D or 2D), must be positive
        standard deviation of the normal distribution
    l_trunc : float or array of float (1D or 2D)
        Left truncation abscissa.
    r_trunc : float or array of float (1D or 2D)
        Right truncation abscissa.
    index : pd.Index, optional, default = RangeIndex
    columns : pd.Index, optional, default = RangeIndex

    Examples
    --------
    >>> from skpro.distributions.truncated_normal import TruncatedNormal

    >>> d = TruncatedNormal(
    ...     mu=[[0, 1], [2, 3], [4, 5]],
    ...     sigma=1,
    ...     l_trunc=[[-0.1, 0.5], [1.5, 2.4], [4.1, 5]],
    ...     r_trunc=[[0.8, 2], [4, 5], [5, 7]],
    ... )
    """

    # capability and type tags read by the distribution framework
    _tags = {
        "capabilities:approx": ["energy", "pdfnorm"],
        "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"],
        "distr:measuretype": "continuous",
        "distr:paramtype": "parametric",
        "broadcast_init": "on",
    }

    def __init__(self, mu, sigma, l_trunc, r_trunc, index=None, columns=None):
        # constructor arguments are stored unmodified under their own names
        self.mu = mu
        self.sigma = sigma
        self.l_trunc = l_trunc
        self.r_trunc = r_trunc

        super().__init__(index=index, columns=columns)

    def _get_scipy_object(self) -> rv_continuous:
        """Return the scipy distribution object that this class adapts."""
        return truncnorm

    def _get_scipy_param(self):
        """Assemble the arguments passed to ``scipy.stats.truncnorm``.

        Parameter values are read from ``self._bc_params``
        (presumably the broadcast constructor arguments, set up by the
        parent class - confirm there).

        Returns
        -------
        args : list
            positional arguments, always empty
        kwargs : dict
            keyword arguments ``loc``, ``scale``, ``a``, ``b``, where ``a`` and
            ``b`` are the truncation points standardized by ``mu`` and ``sigma``
        """
        bc_params = self._bc_params
        loc = bc_params["mu"]
        scale = bc_params["sigma"]
        lower = bc_params["l_trunc"]
        upper = bc_params["r_trunc"]

        # scipy expects the bounds relative to loc, in units of scale
        scipy_kwargs = {
            "loc": loc,
            "scale": scale,
            "a": (lower - loc) / scale,
            "b": (upper - loc) / scale,
        }
        return [], scipy_kwargs

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the parameter set; not used, the same sets are
            returned for every value.

        Returns
        -------
        params : list of dict
            Three constructor keyword dictionaries: a 2D array case, an
            array case with explicit ``index`` and ``columns``, and an
            all-scalar case.
        """
        # 2D array parameters, scalar sigma
        array_case = {
            "mu": [[0, 1], [2, 3], [4, 5]],
            "sigma": 1,
            "l_trunc": [[-0.1, 0.5], [1.5, 2.4], [4.1, 5]],
            "r_trunc": [[0.8, 2], [4, 5], [5, 7]],
        }
        # 1D truncation points together with explicit index and columns
        indexed_case = {
            "mu": 0,
            "sigma": 1,
            "l_trunc": [-10, -5],
            "r_trunc": [5, 10],
            "index": pd.Index([1, 2, 5]),
            "columns": pd.Index(["a", "b"]),
        }
        # all parameters scalar
        scalar_case = {"mu": 1, "sigma": 2, "l_trunc": -3, "r_trunc": 5}

        return [array_case, indexed_case, scalar_case]

0 commit comments

Comments
 (0)
Please sign in to comment.