@@ -40,15 +40,31 @@ namespace MathNet.Numerics.Distributions {
40
40
/// For more details about this distribution, see
41
41
/// <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">Wikipedia - Truncated normal distribution</a>
42
42
/// </summary>
43
- public class TruncatedNormal : IContinuousDistribution {
43
+ public class TruncatedNormal : IContinuousDistribution
44
+ {
44
45
45
46
System . Random _random ;
46
47
47
- readonly double _mean ;
48
- readonly double _stdDev ;
48
+ /// <summary>
49
+ /// Mean of the untruncated normal distribution.
50
+ /// </summary>
51
+ readonly double _mu ;
52
+ /// <summary>
53
+ /// Standard deviation of the uncorrected normal distribution.
54
+ /// </summary>
55
+ readonly double _sigma ;
49
56
readonly double _lowerBound ;
50
57
readonly double _upperBound ;
51
- readonly Normal _uncorrectedNormal ;
58
+ readonly Normal _standardNormal = new Normal ( 0.0 , 1.0 ) ;
59
+ /// <summary>
60
+ /// Position in the standard normal distribution of the lower bound.
61
+ /// </summary>
62
+ readonly double _alpha ;
63
+ /// <summary>
64
+ /// Position in the standard normal distribution of the upper bound.
65
+ /// </summary>
66
+ readonly double _beta ;
67
+
52
68
/// <summary>
53
69
/// The total density of the uncorrected normal distribution which is within the lower and upper bounds.
54
70
/// Referred to as "Z" in the wikipedia equations. Z = Φ(UpperBound) - Φ(LowerBound).
@@ -61,7 +77,7 @@ public class TruncatedNormal : IContinuousDistribution {
61
77
/// normal distribution.
62
78
/// </summary>
63
79
/// <param name="mean">The mean (μ) of the untruncated distribution.</param>
64
- /// <param name="stddev">The standard deviation (σ) of the untruncated distribution. Range: σ ≥ 0.</param>
80
+ /// <param name="stddev">The standard deviation (σ) of the untruncated distribution. Range: σ > 0.</param>
65
81
/// <param name="lowerBound">The inclusive lower bound of the truncated distribution. Default is double.NegativeInfinity.</param>
66
82
/// <param name="upperBound">The inclusive upper bound of the truncated distribution. Must be larger than <paramref name="lowerBound"/>.
67
83
/// Default is double.PositiveInfinity.</param>
@@ -76,7 +92,7 @@ public TruncatedNormal(double mean, double stddev, double lowerBound = double.Ne
76
92
/// be initialized with the default <seealso cref="System.Random"/> random number generator.
77
93
/// </summary>
78
94
/// <param name="mean">The mean (μ) of the normal distribution.</param>
79
- /// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
95
+ /// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ > 0.</param>
80
96
/// <param name="randomSource">The random number generator which is used to draw random samples.</param>
81
97
public TruncatedNormal ( double mean , double stddev , System . Random randomSource , double lowerBound = double . NegativeInfinity , double upperBound = double . PositiveInfinity )
82
98
{
@@ -86,28 +102,30 @@ public TruncatedNormal(double mean, double stddev, System.Random randomSource, d
86
102
}
87
103
88
104
_random = randomSource ?? SystemRandomSource . Default ;
89
- _mean = mean ;
90
- _stdDev = stddev ;
105
+ _mu = mean ;
106
+ _sigma = stddev ;
91
107
_lowerBound = lowerBound ;
92
108
_upperBound = upperBound ;
93
- _uncorrectedNormal = Normal . WithMeanStdDev ( _mean , _stdDev ) ;
94
- _cumulativeDensityWithinBounds = _uncorrectedNormal . CumulativeDistribution ( _upperBound ) - _uncorrectedNormal . CumulativeDistribution ( _lowerBound ) ;
109
+ _alpha = ( _lowerBound - _mu ) / _sigma ;
110
+ _beta = ( _upperBound - _mu ) / _sigma ;
111
+
112
+ _cumulativeDensityWithinBounds = _standardNormal . CumulativeDistribution ( _beta ) - _standardNormal . CumulativeDistribution ( _alpha ) ;
95
113
}
96
114
97
115
/// <summary>
98
116
/// Tests whether the provided values are valid parameters for this distribution.
99
117
/// </summary>
100
118
/// <param name="mean">The mean (μ) of the normal distribution.</param>
101
- /// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
119
+ /// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ > 0.</param>
102
120
public static bool IsValidParameterSet ( double mean , double stddev , double lowerBound , double upperBound )
103
121
{
104
- bool normalRequirements = Normal . IsValidParameterSet ( mean , stddev ) ;
122
+ bool normalRequirements = Normal . IsValidParameterSet ( mean , stddev ) && stddev > 0 ;
105
123
bool boundsAreOrdered = lowerBound < upperBound ;
106
124
return normalRequirements && boundsAreOrdered ;
107
125
}
108
126
109
127
public override string ToString ( ) {
110
- return "TruncatedNormal(μ = " + _mean + ", σ = " + _stdDev + ", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")" ;
128
+ return "TruncatedNormal(μ = " + _mu + ", σ = " + _sigma + ", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")" ;
111
129
}
112
130
113
131
/// <summary>
@@ -117,11 +135,11 @@ public double Mode
117
135
{
118
136
get
119
137
{
120
- if ( _mean < _lowerBound )
138
+ if ( _mu < _lowerBound )
121
139
return _lowerBound ;
122
- if ( _mean > _upperBound )
140
+ if ( _mu > _upperBound )
123
141
return _upperBound ;
124
- return _mean ;
142
+ return _mu ;
125
143
}
126
144
}
127
145
@@ -148,9 +166,9 @@ public double Mean
148
166
{
149
167
get
150
168
{
151
- var pdfDifference = _uncorrectedNormal . Density ( _lowerBound ) - _uncorrectedNormal . Density ( _upperBound ) ;
152
- var diffFromUncorrected = pdfDifference * _stdDev / _cumulativeDensityWithinBounds ;
153
- return _mean + diffFromUncorrected ;
169
+ var pdfDifference = _standardNormal . Density ( _alpha ) - _standardNormal . Density ( _beta ) ;
170
+ var diffFromUncorrected = pdfDifference * _sigma / _cumulativeDensityWithinBounds ;
171
+ return _mu + diffFromUncorrected ;
154
172
}
155
173
}
156
174
@@ -161,24 +179,31 @@ public double Variance
161
179
{
162
180
get
163
181
{
182
+ //Apparently "Barr and Sherrill (1999)" has a simpler expression for one sided truncations, if anyone has access...
183
+
164
184
//TODO might need special handling for cases where either or both bounds are infinity
185
+ var densityAtLower = double . IsNegativeInfinity ( _lowerBound ) ? 0.0 : _standardNormal . Density ( _alpha ) ;
186
+ var densityAtUpper = double . IsPositiveInfinity ( _upperBound ) ? 0.0 : _standardNormal . Density ( _beta ) ;
187
+
188
+ var standardisedLower = double . IsNegativeInfinity ( _lowerBound ) ? 0.0 : _alpha ;
189
+ var standardisedUpper = double . IsPositiveInfinity ( _upperBound ) ? 0.0 : _beta ;
165
190
166
191
//Second term
167
- var secondNumerator = _lowerBound * _uncorrectedNormal . Density ( _lowerBound ) - _upperBound * _uncorrectedNormal . Density ( _upperBound ) ;
192
+ var secondNumerator = standardisedLower * densityAtLower - standardisedUpper * densityAtUpper ;
168
193
var secordTerm = secondNumerator / _cumulativeDensityWithinBounds ;
169
194
170
195
//Third term
171
- var thirdNumerator = _uncorrectedNormal . Density ( _lowerBound ) - _uncorrectedNormal . Density ( _upperBound ) ;
196
+ var thirdNumerator = densityAtLower - densityAtUpper ;
172
197
var thirdTerm = ( thirdNumerator / _cumulativeDensityWithinBounds ) * ( thirdNumerator / _cumulativeDensityWithinBounds ) ;
173
198
174
199
var sumOfTerms = 1 + secordTerm + thirdTerm ;
175
200
176
- return _stdDev * _stdDev * sumOfTerms ;
201
+ return _sigma * _sigma * sumOfTerms ;
177
202
}
178
203
}
179
204
180
205
/// <summary>
181
- /// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ ≥ 0.
206
+ /// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ > 0.
182
207
/// </summary>
183
208
public double StdDev
184
209
{
@@ -192,9 +217,9 @@ public double Entropy
192
217
{
193
218
get
194
219
{
195
- var firstTerm = Constants . LogSqrt2PiE + Math . Log ( _stdDev + _cumulativeDensityWithinBounds ) ;
220
+ var firstTerm = Constants . LogSqrt2PiE + Math . Log ( _sigma + _cumulativeDensityWithinBounds ) ;
196
221
197
- var secondNumerator = _lowerBound * _uncorrectedNormal . Density ( _lowerBound ) - _upperBound * _uncorrectedNormal . Density ( _upperBound ) ;
222
+ var secondNumerator = _lowerBound * _standardNormal . Density ( _alpha ) - _upperBound * _standardNormal . Density ( _beta ) ;
198
223
var secondTerm = secondNumerator / ( 2 * _cumulativeDensityWithinBounds ) ;
199
224
200
225
return firstTerm + secondTerm ;
@@ -240,7 +265,7 @@ public double Density(double x)
240
265
if ( x < _lowerBound || _upperBound < x )
241
266
return 0d ;
242
267
243
- return _uncorrectedNormal . Density ( x ) / ( _stdDev * _cumulativeDensityWithinBounds ) ;
268
+ return _standardNormal . Density ( ( x - _mu ) / _sigma ) / ( _sigma * _cumulativeDensityWithinBounds ) ;
244
269
}
245
270
246
271
/// <summary>
@@ -251,11 +276,11 @@ public double Density(double x)
251
276
/// <seealso cref="PDFLn"/>
252
277
public double DensityLn ( double x )
253
278
{
254
- return Math . Log ( Density ( x ) ) ;
255
- }
279
+ return _standardNormal . DensityLn ( ( x - _mu ) / _sigma ) - Math . Log ( _sigma ) - Math . Log ( _cumulativeDensityWithinBounds ) ;
280
+ }
256
281
257
282
//TODO: implement sampling, use method described by Mazet here: http://miv.u-strasbg.fr/mazet/rtnorm/
258
- // see implmentations listed on that page for examples.
283
+ // see implementations listed on that page for examples.
259
284
260
285
public double Sample ( )
261
286
{
@@ -285,7 +310,7 @@ public double CumulativeDistribution(double x)
285
310
if ( x > _upperBound )
286
311
return 1d ;
287
312
288
- double cumulative = _uncorrectedNormal . CumulativeDistribution ( x ) - _uncorrectedNormal . CumulativeDistribution ( _lowerBound ) ;
313
+ double cumulative = _standardNormal . CumulativeDistribution ( ( x - _mu ) / _sigma ) - _standardNormal . CumulativeDistribution ( _alpha ) ;
289
314
return cumulative / _cumulativeDensityWithinBounds ;
290
315
}
291
316
@@ -299,9 +324,9 @@ public double CumulativeDistribution(double x)
299
324
public double InverseCumulativeDistribution ( double p )
300
325
{
301
326
//TODO check that this is correct with someone.
302
- var pUntruncated = p * _cumulativeDensityWithinBounds + _uncorrectedNormal . CumulativeDistribution ( _lowerBound ) ;
327
+ var pUntruncated = p * _cumulativeDensityWithinBounds + _standardNormal . CumulativeDistribution ( _alpha ) ;
303
328
304
- return _uncorrectedNormal . InverseCumulativeDistribution ( pUntruncated ) ;
329
+ return _standardNormal . InverseCumulativeDistribution ( pUntruncated ) * _sigma + _mu ;
305
330
}
306
331
307
332
}
0 commit comments