Commit 94ce9b8

vadim-ushtanit, Ushtanit, NicolasHug, and glemaitre committed
FIX index sample_weight in least_absolute_deviation loss in HistGradientBoosting (#19407)
Co-authored-by: Vadim Ushtanit <vadim.ushtanit@gmail.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 688eb29 commit 94ce9b8

3 files changed: 26 additions and 4 deletions

doc/whats_new/v0.24.rst

Lines changed: 7 additions & 0 deletions
@@ -12,6 +12,13 @@ Version 0.24.2
 Changelog
 ---------

+:mod:`sklearn.ensemble`
+.......................
+
+- |Fix| Fixed a bug in :class:`ensemble.HistGradientBoostingRegressor` `fit`
+  with `sample_weight` parameter and `least_absolute_deviation` loss function.
+  :pr:`19407` by :user:`Vadim Ushtanit <vadim-ushtanit>`.
+
 :mod:`sklearn.preprocessing`
 ............................

sklearn/ensemble/_hist_gradient_boosting/loss.py

Lines changed: 5 additions & 4 deletions
@@ -261,10 +261,11 @@ def update_leaves_values(self, grower, y_true, raw_predictions,
                 median_res = np.median(y_true[indices]
                                        - raw_predictions[indices])
             else:
-                median_res = _weighted_percentile(y_true[indices]
-                                                  - raw_predictions[indices],
-                                                  sample_weight=sample_weight,
-                                                  percentile=50)
+                median_res = _weighted_percentile(
+                    y_true[indices] - raw_predictions[indices],
+                    sample_weight=sample_weight[indices],
+                    percentile=50
+                )
             leaf.value = grower.shrinkage * median_res
             # Note that the regularization is ignored here
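
The hunk above changes which weights are paired with the per-leaf residuals. As a rough illustration of why that matters, here is a minimal pure-Python sketch. The helpers `weighted_median` and `leaf_median` are hypothetical and are not scikit-learn's `_weighted_percentile`; the data is made up. The sketch assumes a "lower" weighted median (the smallest value whose cumulative weight reaches half of the total), which may differ from the interpolation scikit-learn uses.

```python
def weighted_median(values, weights):
    """Lower weighted median (hypothetical helper, not scikit-learn's)."""
    if len(values) == 0:
        raise ValueError("empty input")
    if len(values) != len(weights):
        # The pre-fix call site passed weights for ALL samples together
        # with residuals for ONE leaf, so the lengths did not match.
        raise ValueError("values and weights must have the same length")
    pairs = sorted(zip(values, weights))
    half = 0.5 * sum(w for _, w in pairs)
    cumulative = 0.0
    for value, weight in pairs:
        cumulative += weight
        if cumulative >= half:
            return value
    return pairs[-1][0]  # only reached through floating-point rounding


def leaf_median(residuals, sample_weight, indices):
    """Weighted median of the residuals that fall in one leaf."""
    leaf_residuals = [residuals[i] for i in indices]
    # Mirrors the fix: the weights are restricted to the same indices.
    leaf_weights = [sample_weight[i] for i in indices]
    return weighted_median(leaf_residuals, leaf_weights)


# Made-up residuals (y_true - raw_predictions) and weights for 5 samples.
residuals = [0.5, -1.0, 2.0, 0.1, -0.3]
sample_weight = [1.0, 2.0, 1.0, 1.0, 5.0]
indices = [1, 2, 4]  # samples assumed to land in one leaf

result = leaf_median(residuals, sample_weight, indices)
print(result)  # -0.3
```

In this sketch, passing the full `sample_weight` list alongside the three leaf residuals raises a `ValueError` because of the length check; the exact failure mode inside scikit-learn before the fix may have been different.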

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 14 additions & 0 deletions
@@ -203,6 +203,20 @@ def test_least_absolute_deviation():
     assert gbdt.score(X, y) > .9


+def test_least_absolute_deviation_sample_weight():
+    # non regression test for issue #19400
+    # make sure no error is thrown during fit of
+    # HistGradientBoostingRegressor with least_absolute_deviation loss function
+    # and passing sample_weight
+    rng = np.random.RandomState(0)
+    n_samples = 100
+    X = rng.uniform(-1, 1, size=(n_samples, 2))
+    y = rng.uniform(-1, 1, size=n_samples)
+    sample_weight = rng.uniform(0, 1, size=n_samples)
+    gbdt = HistGradientBoostingRegressor(loss='least_absolute_deviation')
+    gbdt.fit(X, y, sample_weight=sample_weight)
+
+
 @pytest.mark.parametrize('y', [([1., -2., 0.]), ([0., 0., 0.])])
 def test_poisson_y_positive(y):
     # Test that ValueError is raised if either one y_i < 0 or sum(y_i) <= 0.
