Skip to content

bpo-37798: Add C fastpath for statistics.NormalDist.inv_cdf() #15266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 82 additions & 73 deletions Lib/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,81 @@ def pstdev(data, mu=None):

## Normal Distribution #####################################################


def _normal_dist_inv_cdf(p, mu, sigma):
# There is no closed-form solution to the inverse CDF for the normal
# distribution, so we use a rational approximation instead:
# Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
# Normal Distribution". Applied Statistics. Blackwell Publishing. 37
# (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
q = p - 0.5
if fabs(q) <= 0.425:
r = 0.180625 - q * q
# Hash sum: 55.88319_28806_14901_4439
num = (((((((2.50908_09287_30122_6727e+3 * r +
3.34305_75583_58812_8105e+4) * r +
6.72657_70927_00870_0853e+4) * r +
4.59219_53931_54987_1457e+4) * r +
1.37316_93765_50946_1125e+4) * r +
1.97159_09503_06551_4427e+3) * r +
1.33141_66789_17843_7745e+2) * r +
3.38713_28727_96366_6080e+0) * q
den = (((((((5.22649_52788_52854_5610e+3 * r +
2.87290_85735_72194_2674e+4) * r +
3.93078_95800_09271_0610e+4) * r +
2.12137_94301_58659_5867e+4) * r +
5.39419_60214_24751_1077e+3) * r +
6.87187_00749_20579_0830e+2) * r +
4.23133_30701_60091_1252e+1) * r +
1.0)
x = num / den
return mu + (x * sigma)
r = p if q <= 0.0 else 1.0 - p
r = sqrt(-log(r))
if r <= 5.0:
r = r - 1.6
# Hash sum: 49.33206_50330_16102_89036
num = (((((((7.74545_01427_83414_07640e-4 * r +
2.27238_44989_26918_45833e-2) * r +
2.41780_72517_74506_11770e-1) * r +
1.27045_82524_52368_38258e+0) * r +
3.64784_83247_63204_60504e+0) * r +
5.76949_72214_60691_40550e+0) * r +
4.63033_78461_56545_29590e+0) * r +
1.42343_71107_49683_57734e+0)
den = (((((((1.05075_00716_44416_84324e-9 * r +
5.47593_80849_95344_94600e-4) * r +
1.51986_66563_61645_71966e-2) * r +
1.48103_97642_74800_74590e-1) * r +
6.89767_33498_51000_04550e-1) * r +
1.67638_48301_83803_84940e+0) * r +
2.05319_16266_37758_82187e+0) * r +
1.0)
else:
r = r - 5.0
# Hash sum: 47.52583_31754_92896_71629
num = (((((((2.01033_43992_92288_13265e-7 * r +
2.71155_55687_43487_57815e-5) * r +
1.24266_09473_88078_43860e-3) * r +
2.65321_89526_57612_30930e-2) * r +
2.96560_57182_85048_91230e-1) * r +
1.78482_65399_17291_33580e+0) * r +
5.46378_49111_64114_36990e+0) * r +
6.65790_46435_01103_77720e+0)
den = (((((((2.04426_31033_89939_78564e-15 * r +
1.42151_17583_16445_88870e-7) * r +
1.84631_83175_10054_68180e-5) * r +
7.86869_13114_56132_59100e-4) * r +
1.48753_61290_85061_48525e-2) * r +
1.36929_88092_27358_05310e-1) * r +
5.99832_20655_58879_37690e-1) * r +
1.0)
x = num / den
if q < 0.0:
x = -x
return mu + (x * sigma)


class NormalDist:
"Normal distribution of a random variable"
# https://en.wikipedia.org/wiki/Normal_distribution
Expand Down Expand Up @@ -882,79 +957,7 @@ def inv_cdf(self, p):
raise StatisticsError('p must be in the range 0.0 < p < 1.0')
if self._sigma <= 0.0:
raise StatisticsError('cdf() not defined when sigma at or below zero')

# There is no closed-form solution to the inverse CDF for the normal
# distribution, so we use a rational approximation instead:
# Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
# Normal Distribution". Applied Statistics. Blackwell Publishing. 37
# (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.

q = p - 0.5
if fabs(q) <= 0.425:
r = 0.180625 - q * q
# Hash sum: 55.88319_28806_14901_4439
num = (((((((2.50908_09287_30122_6727e+3 * r +
3.34305_75583_58812_8105e+4) * r +
6.72657_70927_00870_0853e+4) * r +
4.59219_53931_54987_1457e+4) * r +
1.37316_93765_50946_1125e+4) * r +
1.97159_09503_06551_4427e+3) * r +
1.33141_66789_17843_7745e+2) * r +
3.38713_28727_96366_6080e+0) * q
den = (((((((5.22649_52788_52854_5610e+3 * r +
2.87290_85735_72194_2674e+4) * r +
3.93078_95800_09271_0610e+4) * r +
2.12137_94301_58659_5867e+4) * r +
5.39419_60214_24751_1077e+3) * r +
6.87187_00749_20579_0830e+2) * r +
4.23133_30701_60091_1252e+1) * r +
1.0)
x = num / den
return self._mu + (x * self._sigma)
r = p if q <= 0.0 else 1.0 - p
r = sqrt(-log(r))
if r <= 5.0:
r = r - 1.6
# Hash sum: 49.33206_50330_16102_89036
num = (((((((7.74545_01427_83414_07640e-4 * r +
2.27238_44989_26918_45833e-2) * r +
2.41780_72517_74506_11770e-1) * r +
1.27045_82524_52368_38258e+0) * r +
3.64784_83247_63204_60504e+0) * r +
5.76949_72214_60691_40550e+0) * r +
4.63033_78461_56545_29590e+0) * r +
1.42343_71107_49683_57734e+0)
den = (((((((1.05075_00716_44416_84324e-9 * r +
5.47593_80849_95344_94600e-4) * r +
1.51986_66563_61645_71966e-2) * r +
1.48103_97642_74800_74590e-1) * r +
6.89767_33498_51000_04550e-1) * r +
1.67638_48301_83803_84940e+0) * r +
2.05319_16266_37758_82187e+0) * r +
1.0)
else:
r = r - 5.0
# Hash sum: 47.52583_31754_92896_71629
num = (((((((2.01033_43992_92288_13265e-7 * r +
2.71155_55687_43487_57815e-5) * r +
1.24266_09473_88078_43860e-3) * r +
2.65321_89526_57612_30930e-2) * r +
2.96560_57182_85048_91230e-1) * r +
1.78482_65399_17291_33580e+0) * r +
5.46378_49111_64114_36990e+0) * r +
6.65790_46435_01103_77720e+0)
den = (((((((2.04426_31033_89939_78564e-15 * r +
1.42151_17583_16445_88870e-7) * r +
1.84631_83175_10054_68180e-5) * r +
7.86869_13114_56132_59100e-4) * r +
1.48753_61290_85061_48525e-2) * r +
1.36929_88092_27358_05310e-1) * r +
5.99832_20655_58879_37690e-1) * r +
1.0)
x = num / den
if q < 0.0:
x = -x
return self._mu + (x * self._sigma)
return _normal_dist_inv_cdf(p, self._mu, self._sigma)

def overlap(self, other):
"""Compute the overlapping coefficient (OVL) between two normal distributions.
Expand Down Expand Up @@ -1078,6 +1081,12 @@ def __hash__(self):
def __repr__(self):
return f'{type(self).__name__}(mu={self._mu!r}, sigma={self._sigma!r})'

# If available, use C implementation
try:
from _statistics import _normal_dist_inv_cdf
except ImportError:
pass


if __name__ == '__main__':

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add C fastpath for statistics.NormalDist.inv_cdf().  Patch by Dong-hee Na.
1 change: 1 addition & 0 deletions Modules/Setup
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ _symtable symtablemodule.c
#_heapq _heapqmodule.c # Heap queue algorithm
#_asyncio _asynciomodule.c # Fast asyncio Future
#_json -I$(srcdir)/Include/internal -DPy_BUILD_CORE_BUILTIN _json.c # _json speedups
#_statistics _statisticsmodule.c # statistics accelerator

#unicodedata unicodedata.c # static Unicode character database

Expand Down
122 changes: 122 additions & 0 deletions Modules/_statisticsmodule.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/* statistics accelerator C extension: _statistics module. */

#include "Python.h"
#include "structmember.h"
#include "clinic/_statisticsmodule.c.h"

/*[clinic input]
module _statistics

[clinic start generated code]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=864a6f59b76123b2]*/


/* Method table for the _statistics module; it is referenced by the
   `statisticsmodule` definition further down in this file. */
static PyMethodDef speedups_methods[] = {
    /* Entry for _normal_dist_inv_cdf().  The macro is presumably supplied by
       the generated clinic header included above (not visible here). */
    _STATISTICS__NORMAL_DIST_INV_CDF_METHODDEF
    {NULL, NULL, 0, NULL}  /* all-NULL entry that ends the table */
};

/*[clinic input]
_statistics._normal_dist_inv_cdf -> double
p: double
mu: double
sigma: double
/
[clinic start generated code]*/

static double
_statistics__normal_dist_inv_cdf_impl(PyObject *module, double p, double mu,
                                      double sigma)
/*[clinic end generated code: output=02fd19ddaab36602 input=24715a74be15296a]*/
{
    /* Compute the inverse CDF at p of a normal distribution with mean mu
       and standard deviation sigma, using the rational approximations of
       Wichura's "Algorithm AS 241: The Percentage Points of the Normal
       Distribution".

       Returns mu + x * sigma, where x is the standard-normal quantile.
       On a domain error, sets ValueError and returns -1.0; the clinic
       `-> double` return converter treats -1.0 with a pending exception
       as failure. */
    double q, num, den, r, x;

    /* The approximation is only defined for 0.0 < p < 1.0.  Without this
       check a direct call with p outside that range would take log() of a
       non-positive number and silently return NaN/inf, whereas the
       pure-Python fallback raises ValueError from math.log(). */
    if (p <= 0.0 || p >= 1.0) {
        PyErr_SetString(PyExc_ValueError,
                        "inv_cdf undefined for these parameters");
        return -1.0;
    }

    q = p - 0.5;
    if (fabs(q) <= 0.425) {
        /* Central region. */
        r = 0.180625 - q * q;
        // Hash sum AB: 55.88319 28806 14901 4439
        num = (((((((2.5090809287301226727e+3 * r +
                     3.3430575583588128105e+4) * r +
                     6.7265770927008700853e+4) * r +
                     4.5921953931549871457e+4) * r +
                     1.3731693765509461125e+4) * r +
                     1.9715909503065514427e+3) * r +
                     1.3314166789178437745e+2) * r +
                     3.3871328727963666080e+0) * q;
        den = (((((((5.2264952788528545610e+3 * r +
                     2.8729085735721942674e+4) * r +
                     3.9307895800092710610e+4) * r +
                     2.1213794301586595867e+4) * r +
                     5.3941960214247511077e+3) * r +
                     6.8718700749205790830e+2) * r +
                     4.2313330701600911252e+1) * r +
                     1.0);
        x = num / den;
        return mu + (x * sigma);
    }

    /* Tail regions: use the probability mass of the nearer tail. */
    r = (q <= 0.0) ? p : (1.0 - p);
    r = sqrt(-log(r));
    if (r <= 5.0) {
        r = r - 1.6;
        // Hash sum CD: 49.33206 50330 16102 89036
        num = (((((((7.74545014278341407640e-4 * r +
                     2.27238449892691845833e-2) * r +
                     2.41780725177450611770e-1) * r +
                     1.27045825245236838258e+0) * r +
                     3.64784832476320460504e+0) * r +
                     5.76949722146069140550e+0) * r +
                     4.63033784615654529590e+0) * r +
                     1.42343711074968357734e+0);
        den = (((((((1.05075007164441684324e-9 * r +
                     5.47593808499534494600e-4) * r +
                     1.51986665636164571966e-2) * r +
                     1.48103976427480074590e-1) * r +
                     6.89767334985100004550e-1) * r +
                     1.67638483018380384940e+0) * r +
                     2.05319162663775882187e+0) * r +
                     1.0);
    }
    else {
        r -= 5.0;
        // Hash sum EF: 47.52583 31754 92896 71629
        num = (((((((2.01033439929228813265e-7 * r +
                     2.71155556874348757815e-5) * r +
                     1.24266094738807843860e-3) * r +
                     2.65321895265761230930e-2) * r +
                     2.96560571828504891230e-1) * r +
                     1.78482653991729133580e+0) * r +
                     5.46378491116411436990e+0) * r +
                     6.65790464350110377720e+0);
        den = (((((((2.04426310338993978564e-15 * r +
                     1.42151175831644588870e-7) * r +
                     1.84631831751005468180e-5) * r +
                     7.86869131145613259100e-4) * r +
                     1.48753612908506148525e-2) * r +
                     1.36929880922735805310e-1) * r +
                     5.99832206555887937690e-1) * r +
                     1.0);
    }
    x = num / den;
    if (q < 0.0) {
        /* The approximation yields a magnitude; the lower tail is negative. */
        x = -x;
    }
    return mu + (x * sigma);
}

static struct PyModuleDef statisticsmodule = {
PyModuleDef_HEAD_INIT,
"_statistics",
_statistics__normal_dist_inv_cdf__doc__,
-1,
speedups_methods,
NULL,
NULL,
NULL,
NULL
};


/* Module initialization entry point: build the _statistics module from
   `statisticsmodule`.  The result of PyModule_Create() is returned as-is,
   so a NULL failure result is passed straight through to the caller. */
PyMODINIT_FUNC
PyInit__statistics(void)
{
    return PyModule_Create(&statisticsmodule);
}
50 changes: 50 additions & 0 deletions Modules/clinic/_statisticsmodule.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions PC/config.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ extern PyObject* PyInit__sha1(void);
extern PyObject* PyInit__sha256(void);
extern PyObject* PyInit__sha512(void);
extern PyObject* PyInit__sha3(void);
extern PyObject* PyInit__statistics(void);
extern PyObject* PyInit__blake2(void);
extern PyObject* PyInit_time(void);
extern PyObject* PyInit__thread(void);
Expand Down Expand Up @@ -103,6 +104,7 @@ struct _inittab _PyImport_Inittab[] = {
{"_blake2", PyInit__blake2},
{"time", PyInit_time},
{"_thread", PyInit__thread},
{"_statistics", PyInit__statistics},
#ifdef WIN32
{"msvcrt", PyInit_msvcrt},
{"_locale", PyInit__locale},
Expand Down
1 change: 1 addition & 0 deletions PCbuild/pythoncore.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@
<ClCompile Include="..\Modules\sha256module.c" />
<ClCompile Include="..\Modules\sha512module.c" />
<ClCompile Include="..\Modules\signalmodule.c" />
<ClCompile Include="..\Modules\_statisticsmodule.c" />
<ClCompile Include="..\Modules\symtablemodule.c" />
<ClCompile Include="..\Modules\_threadmodule.c" />
<ClCompile Include="..\Modules\_tracemalloc.c" />
Expand Down
3 changes: 3 additions & 0 deletions PCbuild/pythoncore.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,9 @@
<ClCompile Include="..\Modules\_sre.c">
<Filter>Modules</Filter>
</ClCompile>
<ClCompile Include="..\Modules\_statisticsmodule.c">
<Filter>Modules</Filter>
</ClCompile>
<ClCompile Include="..\Modules\_struct.c">
<Filter>Modules</Filter>
</ClCompile>
Expand Down
Loading