
Sample_weight isn't overwritten anymore in logistic_regression #18480

Closed · wants to merge 2 commits
doc/whats_new/v0.24.rst (3 additions, 0 deletions)

@@ -341,6 +341,9 @@ Changelog
  efficient leave-one-out cross-validation scheme ``cv=None``. :pr:`6624` by
  :user:`Marijn van Vliet <wmvanvliet>`.

- |Fix|: Fixed a bug in :class:`linear_model.LogisticRegression`: the
  ``sample_weight`` array passed to ``fit`` is no longer modified in place.
  :pr:`18480` by :user:`Bart Van Dosselaer <Ansur>`.

:mod:`sklearn.manifold`
.......................
sklearn/linear_model/_logistic.py (2 additions, 2 deletions)

@@ -665,7 +665,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
    if isinstance(class_weight, dict) or multi_class == 'multinomial':
        class_weight_ = compute_class_weight(class_weight,
                                             classes=classes, y=y)
-        sample_weight *= class_weight_[le.fit_transform(y)]
+        sample_weight = sample_weight * class_weight_[le.fit_transform(y)]
Review comment (maintainer):
@rth instead of doing this, would it be better to have copy=False/True within _check_sample_weight signature?
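
A sketch of what that alternative might look like. The copy flag is the reviewer's proposal, not the helper's actual signature at the time, and this is a simplified stand-in for sklearn.utils.validation._check_sample_weight, not scikit-learn's implementation:

    import numbers
    import numpy as np

    def _check_sample_weight(sample_weight, X, dtype=None, copy=False):
        # Simplified stand-in; `copy` is the reviewer's hypothetical flag.
        n_samples = X.shape[0]
        if dtype is None:
            dtype = np.float64
        if sample_weight is None:
            sample_weight = np.ones(n_samples, dtype=dtype)
        elif isinstance(sample_weight, numbers.Number):
            sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
        else:
            sample_weight = np.asarray(sample_weight, dtype=dtype)
            if copy:
                # A fresh buffer: callers could then keep in-place ops like
                # `sample_weight *= ...` without touching the user's array.
                sample_weight = sample_weight.copy()
        return sample_weight

With copy=True the call sites could keep the in-place *= safely, at the cost of an extra allocation on every fit.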


    # For doing an ovr, we need to mask the labels first. For the
    # multinomial case this is not necessary.
@@ -681,7 +681,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
            class_weight_ = compute_class_weight(class_weight,
                                                 classes=mask_classes,
                                                 y=y_bin)
-            sample_weight *= class_weight_[le.fit_transform(y_bin)]
+            sample_weight = sample_weight * class_weight_[le.fit_transform(y_bin)]

    else:
        if solver not in ['sag', 'saga']:
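
For context on why this one-character change matters: *= writes into the existing buffer, so the array the user passed as sample_weight was silently scaled by the class weights, whereas a = a * b allocates a new array and leaves the caller's one untouched. A minimal plain-NumPy illustration (not scikit-learn code):

    import numpy as np

    w = np.ones(3)          # the user's array, as passed to fit()
    alias = w               # what _logistic_regression_path received

    alias *= 2.0            # in-place multiply: writes into the shared buffer
    print(w)                # [2. 2. 2.] -- the user's array was mutated

    w = np.ones(3)
    alias = w
    alias = alias * 2.0     # allocates a new array, rebinds the local name
    print(w)                # [1. 1. 1.] -- the user's array is untouched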
sklearn/linear_model/tests/test_logistic.py (19 additions, 0 deletions)

@@ -1865,3 +1865,22 @@ def test_multinomial_identifiability_on_iris(fit_intercept):
    assert_allclose(clf.coef_.sum(axis=0), 0, atol=1e-10)
    if fit_intercept:
        assert clf.intercept_.sum(axis=0) == pytest.approx(0, abs=1e-15)


@pytest.mark.parametrize("multi_class", {'ovr', 'multinomial', 'auto'})
def test_sample_weight_not_modified(multi_class):
    X, y = load_iris(return_X_y=True)
    np.random.seed(1234)
    W = np.random.random(len(X)) * 10.0

    for weight in [{0: 1.0, 1: 10.0, 2: 1.0}]:
        for class_weight in (weight, 'balanced'):
            expected = W.sum()

            clf = LogisticRegression(random_state=0,
                                     class_weight=class_weight,
                                     max_iter=200,
                                     multi_class=multi_class)
Review comment (maintainer) on lines +1870 to +1883:
We can take advantage of pytest.parametrize:

@pytest.mark.parametrize("multi_class", {'ovr', 'multinomial', 'auto'})
@pytest.mark.parametrize("class_weight", [
    {0: 1.0, 1: 10.0, 2: 1.0}, 'balanced'
])
def test_sample_weight_not_modified(multi_class, class_weight):
    X, y = load_iris(return_X_y=True)
    n_features = len(X)
    W = np.ones(n_features)
    W[:n_features // 2] = 2

    expected = W.sum()
	...

            clf.fit(X, y, sample_weight=W)
            actual = W.sum()
            assert expected == actual, 'Sum of weight before ({}) should be the same as sum of weight after ({})'.format(expected, actual)
Review comment (maintainer):

We need to keep lines shorter than 80 characters, and also to avoid exact float comparisons.

Suggested change
            assert expected == actual, 'Sum of weight before ({}) should be the same as sum of weight after ({})'.format(expected, actual)
            msg = (
                f'Sum of weight before ({expected}) should be the same as '
                f'sum of weight after ({actual})'
            )
            assert_allclose(expected, actual, err_msg=msg)
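
Putting the two suggestions together, the finished test could look something like the sketch below (not necessarily the version that was merged; n_samples replaces the suggestion's n_features, since len(X) counts samples, and the multi_class values are given as a list for deterministic test IDs):

    import numpy as np
    import pytest
    from numpy.testing import assert_allclose
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression


    @pytest.mark.parametrize("multi_class", ['ovr', 'multinomial', 'auto'])
    @pytest.mark.parametrize("class_weight",
                             [{0: 1.0, 1: 10.0, 2: 1.0}, 'balanced'])
    def test_sample_weight_not_modified(multi_class, class_weight):
        X, y = load_iris(return_X_y=True)
        n_samples = len(X)
        W = np.ones(n_samples)
        W[:n_samples // 2] = 2

        # Record the sum before fitting; it must be unchanged afterwards.
        expected = W.sum()

        clf = LogisticRegression(random_state=0,
                                 class_weight=class_weight,
                                 max_iter=200,
                                 multi_class=multi_class)
        clf.fit(X, y, sample_weight=W)
        actual = W.sum()
        msg = (f'Sum of weight before ({expected}) should be the same as '
               f'sum of weight after ({actual})')
        assert_allclose(expected, actual, err_msg=msg)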