Skip to content

Commit ad0cc0c

Browse files
BenHewinsBenHewins
authored andcommitted
Added truncated normal tests
Added some tests for TruncatedNormal, mainly for the unbounded case Corrected mathematical and precision errors in TruncatedNormal
1 parent a75d884 commit ad0cc0c

File tree

4 files changed

+258
-32
lines changed

4 files changed

+258
-32
lines changed

src/Numerics/Distributions/TruncatedNormal.cs

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,31 @@ namespace MathNet.Numerics.Distributions {
4040
/// For more details about this distribution, see
4141
/// <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">Wikipedia - Truncated normal distribution</a>
4242
/// </summary>
43-
public class TruncatedNormal : IContinuousDistribution {
43+
public class TruncatedNormal : IContinuousDistribution
44+
{
4445

4546
System.Random _random;
4647

47-
readonly double _mean;
48-
readonly double _stdDev;
48+
/// <summary>
49+
/// Mean of the untruncated normal distribution.
50+
/// </summary>
51+
readonly double _mu;
52+
/// <summary>
53+
/// Standard deviation of the uncorrected normal distribution.
54+
/// </summary>
55+
readonly double _sigma;
4956
readonly double _lowerBound;
5057
readonly double _upperBound;
51-
readonly Normal _uncorrectedNormal;
58+
readonly Normal _standardNormal = new Normal(0.0, 1.0);
59+
/// <summary>
60+
/// Position in the standard normal distribution of the lower bound.
61+
/// </summary>
62+
readonly double _alpha;
63+
/// <summary>
64+
/// Position in the standard normal distribution of the upper bound.
65+
/// </summary>
66+
readonly double _beta;
67+
5268
/// <summary>
5369
/// The total density of the uncorrected normal distribution which is within the lower and upper bounds.
5470
/// Referred to as "Z" in the wikipedia equations. Z = Φ(UpperBound) - Φ(LowerBound).
@@ -61,7 +77,7 @@ public class TruncatedNormal : IContinuousDistribution {
6177
/// normal distribution.
6278
/// </summary>
6379
/// <param name="mean">The mean (μ) of the untruncated distribution.</param>
64-
/// <param name="stddev">The standard deviation (σ) of the untruncated distribution. Range: σ 0.</param>
80+
/// <param name="stddev">The standard deviation (σ) of the untruncated distribution. Range: σ > 0.</param>
6581
/// <param name="lowerBound">The inclusive lower bound of the truncated distribution. Default is double.NegativeInfinity.</param>
6682
/// <param name="upperBound">The inclusive upper bound of the truncated distribution. Must be larger than <paramref name="lowerBound"/>.
6783
/// Default is double.PositiveInfinity.</param>
@@ -76,7 +92,7 @@ public TruncatedNormal(double mean, double stddev, double lowerBound = double.Ne
7692
/// be initialized with the default <seealso cref="System.Random"/> random number generator.
7793
/// </summary>
7894
/// <param name="mean">The mean (μ) of the normal distribution.</param>
79-
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ 0.</param>
95+
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ > 0.</param>
8096
/// <param name="randomSource">The random number generator which is used to draw random samples.</param>
8197
public TruncatedNormal(double mean, double stddev, System.Random randomSource, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
8298
{
@@ -86,28 +102,30 @@ public TruncatedNormal(double mean, double stddev, System.Random randomSource, d
86102
}
87103

88104
_random = randomSource ?? SystemRandomSource.Default;
89-
_mean = mean;
90-
_stdDev = stddev;
105+
_mu = mean;
106+
_sigma = stddev;
91107
_lowerBound = lowerBound;
92108
_upperBound = upperBound;
93-
_uncorrectedNormal = Normal.WithMeanStdDev(_mean, _stdDev);
94-
_cumulativeDensityWithinBounds = _uncorrectedNormal.CumulativeDistribution(_upperBound) - _uncorrectedNormal.CumulativeDistribution(_lowerBound);
109+
_alpha = (_lowerBound - _mu) / _sigma;
110+
_beta = (_upperBound - _mu) / _sigma;
111+
112+
_cumulativeDensityWithinBounds = _standardNormal.CumulativeDistribution(_beta) - _standardNormal.CumulativeDistribution(_alpha);
95113
}
96114

97115
/// <summary>
98116
/// Tests whether the provided values are valid parameters for this distribution.
99117
/// </summary>
100118
/// <param name="mean">The mean (μ) of the normal distribution.</param>
101-
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ 0.</param>
119+
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ > 0.</param>
102120
public static bool IsValidParameterSet(double mean, double stddev, double lowerBound, double upperBound)
103121
{
104-
bool normalRequirements = Normal.IsValidParameterSet(mean, stddev);
122+
bool normalRequirements = Normal.IsValidParameterSet(mean, stddev) && stddev > 0;
105123
bool boundsAreOrdered = lowerBound < upperBound;
106124
return normalRequirements && boundsAreOrdered;
107125
}
108126

109127
public override string ToString() {
110-
return "TruncatedNormal(μ = " + _mean + ", σ = " + _stdDev +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")";
128+
return "TruncatedNormal(μ = " + _mu + ", σ = " + _sigma +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")";
111129
}
112130

113131
/// <summary>
@@ -117,11 +135,11 @@ public double Mode
117135
{
118136
get
119137
{
120-
if (_mean < _lowerBound)
138+
if (_mu < _lowerBound)
121139
return _lowerBound;
122-
if (_mean > _upperBound)
140+
if (_mu > _upperBound)
123141
return _upperBound;
124-
return _mean;
142+
return _mu;
125143
}
126144
}
127145

@@ -148,9 +166,9 @@ public double Mean
148166
{
149167
get
150168
{
151-
var pdfDifference = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound);
152-
var diffFromUncorrected = pdfDifference * _stdDev / _cumulativeDensityWithinBounds;
153-
return _mean + diffFromUncorrected;
169+
var pdfDifference = _standardNormal.Density(_alpha) - _standardNormal.Density(_beta);
170+
var diffFromUncorrected = pdfDifference * _sigma / _cumulativeDensityWithinBounds;
171+
return _mu + diffFromUncorrected;
154172
}
155173
}
156174

@@ -161,24 +179,31 @@ public double Variance
161179
{
162180
get
163181
{
182+
//Apparently "Barr and Sherrill (1999)" has a simpler expression for one sided truncations, if anyone has access...
183+
164184
//TODO might need special handling for cases where either or both bounds are infinity
185+
var densityAtLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _standardNormal.Density(_alpha);
186+
var densityAtUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _standardNormal.Density(_beta);
187+
188+
var standardisedLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _alpha;
189+
var standardisedUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _beta;
165190

166191
//Second term
167-
var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound);
192+
var secondNumerator = standardisedLower * densityAtLower - standardisedUpper * densityAtUpper;
168193
var secordTerm = secondNumerator / _cumulativeDensityWithinBounds;
169194

170195
//Third term
171-
var thirdNumerator = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound);
196+
var thirdNumerator = densityAtLower - densityAtUpper;
172197
var thirdTerm = (thirdNumerator / _cumulativeDensityWithinBounds) * (thirdNumerator / _cumulativeDensityWithinBounds);
173198

174199
var sumOfTerms = 1 + secordTerm + thirdTerm;
175200

176-
return _stdDev * _stdDev * sumOfTerms;
201+
return _sigma * _sigma * sumOfTerms;
177202
}
178203
}
179204

180205
/// <summary>
181-
/// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ 0.
206+
/// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ > 0.
182207
/// </summary>
183208
public double StdDev
184209
{
@@ -192,9 +217,9 @@ public double Entropy
192217
{
193218
get
194219
{
195-
var firstTerm = Constants.LogSqrt2PiE + Math.Log(_stdDev + _cumulativeDensityWithinBounds);
220+
var firstTerm = Constants.LogSqrt2PiE + Math.Log(_sigma + _cumulativeDensityWithinBounds);
196221

197-
var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound);
222+
var secondNumerator = _lowerBound * _standardNormal.Density(_alpha) - _upperBound * _standardNormal.Density(_beta);
198223
var secondTerm = secondNumerator / (2 * _cumulativeDensityWithinBounds);
199224

200225
return firstTerm + secondTerm;
@@ -240,7 +265,7 @@ public double Density(double x)
240265
if (x < _lowerBound || _upperBound < x)
241266
return 0d;
242267

243-
return _uncorrectedNormal.Density(x) / (_stdDev * _cumulativeDensityWithinBounds);
268+
return _standardNormal.Density((x - _mu) / _sigma) / (_sigma * _cumulativeDensityWithinBounds);
244269
}
245270

246271
/// <summary>
@@ -251,11 +276,11 @@ public double Density(double x)
251276
/// <seealso cref="PDFLn"/>
252277
public double DensityLn(double x)
253278
{
254-
return Math.Log(Density(x));
255-
}
279+
return _standardNormal.DensityLn((x - _mu) / _sigma) - Math.Log(_sigma) - Math.Log(_cumulativeDensityWithinBounds);
280+
}
256281

257282
//TODO: implement sampling, use method described by Mazet here: http://miv.u-strasbg.fr/mazet/rtnorm/
258-
// see implmentations listed on that page for examples.
283+
// see implementations listed on that page for examples.
259284

260285
public double Sample()
261286
{
@@ -285,7 +310,7 @@ public double CumulativeDistribution(double x)
285310
if (x > _upperBound)
286311
return 1d;
287312

288-
double cumulative = _uncorrectedNormal.CumulativeDistribution(x) - _uncorrectedNormal.CumulativeDistribution(_lowerBound);
313+
double cumulative = _standardNormal.CumulativeDistribution((x - _mu) / _sigma) - _standardNormal.CumulativeDistribution(_alpha);
289314
return cumulative / _cumulativeDensityWithinBounds;
290315
}
291316

@@ -299,9 +324,9 @@ public double CumulativeDistribution(double x)
299324
public double InverseCumulativeDistribution(double p)
300325
{
301326
//TODO check that this is correct with someone.
302-
var pUntruncated = p * _cumulativeDensityWithinBounds + _uncorrectedNormal.CumulativeDistribution(_lowerBound);
327+
var pUntruncated = p * _cumulativeDensityWithinBounds + _standardNormal.CumulativeDistribution(_alpha);
303328

304-
return _uncorrectedNormal.InverseCumulativeDistribution(pUntruncated);
329+
return _standardNormal.InverseCumulativeDistribution(pUntruncated) * _sigma + _mu;
305330
}
306331

307332
}

src/UnitTests/DistributionTests/CommonDistributionTests.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ public class CommonDistributionTests
8989
new StudentT(0.0, 1.0, 5.0),
9090
new Triangular(0, 1, 0.7),
9191
new Weibull(1.0, 1.0),
92+
new TruncatedNormal(0, 1.0, -5.0, 5.0), //Finite
93+
new TruncatedNormal(0, 1.0, -5.0), //Semi-finite
9294
};
9395

9496
[Test]

0 commit comments

Comments
 (0)