
[MRG] ENH Consistent loss name for squared error #19310

Merged
merged 28 commits on Mar 19, 2021
Changes from all commits
28 commits
820b7b7
MNT deprecate mse criterion in tree module
lorentzenchr Jan 31, 2021
627c343
MNT deprecate mse criterion for RandomForestRegressor
lorentzenchr Jan 31, 2021
bc3b7f8
MNT deprecate criterion mse and loss ls in GradientBoosting
lorentzenchr Jan 31, 2021
2fbc6ee
MNT deprecate loss least_squares in HistGradientBoostingRegressor
lorentzenchr Jan 31, 2021
fdd21f6
MNT deprecate loss squared_loss in linear_model SGD
lorentzenchr Jan 31, 2021
590f2f6
MNT/TST replace criterion 'mse' by 'squared_error' in PDP tests
lorentzenchr Jan 31, 2021
fa7f8bd
MNT/TST forgot a few deprecated 'ls' in gradient boosting tests
lorentzenchr Jan 31, 2021
ab4c861
MNT/TST replace squared_loss in test_sgd.py
lorentzenchr Jan 31, 2021
7d3d2bd
MNT deprecate loss squared_loss in RANSACRegressor
lorentzenchr Jan 31, 2021
67ceac9
MNT internally rename squared_loss to squared_error in neural_network
lorentzenchr Jan 31, 2021
83bb09a
MNT replace losses in benchmarks
lorentzenchr Jan 31, 2021
baec17d
DOC replace losses in docs
lorentzenchr Jan 31, 2021
cb0c4e4
EXA replace losses in examples
lorentzenchr Jan 31, 2021
fded6f7
MNT replace least_squares in HGBT utils
lorentzenchr Jan 31, 2021
0777251
CLN correct directive deprecated
lorentzenchr Jan 31, 2021
68e1f9b
CLN filter FutureWarning for squared_loss in SGD tests
lorentzenchr Jan 31, 2021
8692240
CLN hiccups in SGD tests due to param checks in init of BaseSGD
lorentzenchr Jan 31, 2021
1d570bd
Merge branch 'main' into consistent_squared_error
lorentzenchr Feb 1, 2021
1bf1c0b
Merge branch 'main' into consistent_squared_error
lorentzenchr Feb 18, 2021
1e9683a
CLN fix double import of pytest
lorentzenchr Feb 18, 2021
91ec366
address review comments 1st round
lorentzenchr Feb 27, 2021
cc94841
Merge branch 'main' into consistent_squared_error
lorentzenchr Feb 27, 2021
e3e92d7
FIX test_export.py
lorentzenchr Mar 2, 2021
0bfc742
Merge branch 'main' into consistent_squared_error
lorentzenchr Mar 2, 2021
0179ac9
DOC add whatsnew entry
lorentzenchr Mar 2, 2021
b50bd75
DOC use |API| tag in whatsnew
lorentzenchr Mar 15, 2021
e288a6a
FIX criterion="mse" test in forest
lorentzenchr Mar 15, 2021
7d220b8
FIX check for DecisionTreeRegressor ExtraTreeRegressor in ensemble base
lorentzenchr Mar 15, 2021
2 changes: 1 addition & 1 deletion benchmarks/bench_hist_gradient_boosting.py
@@ -110,7 +110,7 @@ def one_run(n_samples):
else:
# regression
if loss == 'default':
loss = 'least_squares'
loss = 'squared_error'
est.set_params(loss=loss)
est.fit(X_train, y_train, sample_weight=sample_weight_train)
sklearn_fit_duration = time() - tic
2 changes: 1 addition & 1 deletion benchmarks/bench_hist_gradient_boosting_threading.py
@@ -112,7 +112,7 @@ def get_estimator_and_data():
else:
# regression
if loss == 'default':
loss = 'least_squares'
loss = 'squared_error'
sklearn_est.set_params(loss=loss)


15 changes: 9 additions & 6 deletions doc/modules/ensemble.rst
@@ -537,7 +537,8 @@ Regression
:class:`GradientBoostingRegressor` supports a number of
:ref:`different loss functions <gradient_boosting_loss>`
for regression which can be specified via the argument
``loss``; the default loss function for regression is least squares (``'ls'``).
``loss``; the default loss function for regression is squared error
(``'squared_error'``).

::

@@ -549,8 +550,10 @@ for regression which can be specified via the argument
>>> X, y = make_friedman1(n_samples=1200, random_state=0, noise=1.0)
>>> X_train, X_test = X[:200], X[200:]
>>> y_train, y_test = y[:200], y[200:]
>>> est = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1,
... max_depth=1, random_state=0, loss='ls').fit(X_train, y_train)
>>> est = GradientBoostingRegressor(
... n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0,
... loss='squared_error'
... ).fit(X_train, y_train)
>>> mean_squared_error(y_test, est.predict(X_test))
5.00...

@@ -741,8 +744,8 @@ the parameter ``loss``:

* Regression

* Least squares (``'ls'``): The natural choice for regression due
to its superior computational properties. The initial model is
* Squared error (``'squared_error'``): The natural choice for regression
due to its superior computational properties. The initial model is
given by the mean of the target values.
* Least absolute deviation (``'lad'``): A robust loss function for
regression. The initial model is given by the median of the
@@ -950,7 +953,7 @@ controls the number of iterations of the boosting process::
>>> clf.score(X_test, y_test)
0.8965

Available losses for regression are 'least_squares',
Available losses for regression are 'squared_error',
'least_absolute_deviation', which is less sensitive to outliers, and
'poisson', which is well suited to model counts and frequencies. For
classification, 'binary_crossentropy' is used for binary classification and
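
For readers following along, a minimal sketch (not part of this PR's diff) of the renamed loss in HistGradientBoostingRegressor; it assumes scikit-learn >= 1.0, where 'squared_error' replaces 'least_squares':

from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor

# Toy regression data; any numeric X, y works here.
X, y = make_regression(n_samples=500, n_features=5, noise=1.0, random_state=0)

# New spelling; the old 'least_squares' still works but emits a FutureWarning.
est = HistGradientBoostingRegressor(loss='squared_error', max_iter=100,
                                    random_state=0)
est.fit(X, y)
print(est.score(X, y))  # R^2 on the training data
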
8 changes: 4 additions & 4 deletions doc/modules/sgd.rst
@@ -30,7 +30,7 @@ For example, using `SGDClassifier(loss='log')` results in logistic regression,
i.e. a model equivalent to :class:`~sklearn.linear_model.LogisticRegression`
which is fitted via SGD instead of being fitted by one of the other solvers
in :class:`~sklearn.linear_model.LogisticRegression`. Similarly,
`SGDRegressor(loss='squared_loss', penalty='l2')` and
`SGDRegressor(loss='squared_error', penalty='l2')` and
:class:`~sklearn.linear_model.Ridge` solve the same optimization problem, via
different means.

@@ -211,7 +211,7 @@ samples (> 10.000), for other problems we recommend :class:`Ridge`,
The concrete loss function can be set via the ``loss``
parameter. :class:`SGDRegressor` supports the following loss functions:

* ``loss="squared_loss"``: Ordinary least squares,
* ``loss="squared_error"``: Ordinary least squares,
* ``loss="huber"``: Huber loss for robust regression,
* ``loss="epsilon_insensitive"``: linear Support Vector Regression.

@@ -362,9 +362,9 @@ Different choices for :math:`L` entail different classifiers or regressors:

- Hinge (soft-margin): equivalent to Support Vector Classification.
:math:`L(y_i, f(x_i)) = \max(0, 1 - y_i f(x_i))`.
- Perceptron:
- Perceptron:
:math:`L(y_i, f(x_i)) = \max(0, - y_i f(x_i))`.
- Modified Huber:
- Modified Huber:
:math:`L(y_i, f(x_i)) = \max(0, 1 - y_i f(x_i))^2` if :math:`y_i f(x_i) >
1`, and :math:`L(y_i, f(x_i)) = -4 y_i f(x_i)` otherwise.
- Log: equivalent to Logistic Regression.
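
A minimal sketch of the SGDRegressor/Ridge correspondence noted above, under the new loss name (an illustration assuming scikit-learn >= 1.0; the two estimators parameterize the L2 penalty differently, so the coefficients only roughly agree here):

import numpy as np
from sklearn.linear_model import Ridge, SGDRegressor

rng = np.random.RandomState(0)
X = rng.normal(size=(1000, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=1000)

# 'squared_error' is the new name for the former 'squared_loss'.
sgd = SGDRegressor(loss='squared_error', penalty='l2', alpha=1e-4,
                   max_iter=2000, tol=1e-6, random_state=0).fit(X, y)
ridge = Ridge(alpha=1e-4).fit(X, y)

# Both minimize an L2-penalized squared error; the fits should be close,
# though not identical because SGD is stochastic and the penalty scaling
# differs between the two estimators.
print(sgd.coef_)
print(ridge.coef_)
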
31 changes: 31 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -45,6 +45,37 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

- |API| The option for using the squared error via ``loss`` and
``criterion`` parameters was made more consistent. The preferred way is by
setting the value to `"squared_error"`. Old option names are still valid and
produce the same models, but they are deprecated and will be removed in
version 1.2.
:pr:`19310` by :user:`Christian Lorentzen <lorentzenchr>`.

- For :class:`ensemble.ExtraTreesRegressor`, `criterion="mse"` is deprecated,
use `"squared_error"` instead which is now the default.

- For :class:`ensemble.GradientBoostingRegressor`, `loss="ls"` is deprecated,
use `"squared_error"` instead which is now the default.

- For :class:`ensemble.RandomForestRegressor`, `criterion="mse"` is deprecated,
use `"squared_error"` instead which is now the default.

- For :class:`ensemble.HistGradientBoostingRegressor`, `loss="least_squares"`
is deprecated, use `"squared_error"` instead which is now the default.

- For :class:`linear_model.RANSACRegressor`, `loss="squared_loss"` is
deprecated, use `"squared_error"` instead.

- For :class:`linear_model.SGDRegressor`, `loss="squared_loss"` is
deprecated, use `"squared_error"` instead which is now the default.

- For :class:`tree.DecisionTreeRegressor`, `criterion="mse"` is deprecated,
use `"squared_error"` instead which is now the default.

- For :class:`tree.ExtraTreeRegressor`, `criterion="mse"` is deprecated,
use `"squared_error"` instead which is now the default.

:mod:`sklearn.cluster`
......................

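
In practice, the deprecations listed above behave as in the following sketch (illustrative only, not part of the diff; assumes scikit-learn 1.0 with this change merged):

import warnings
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=100, n_features=4, random_state=0)

# Old spelling: still fits the same model, but warns about the rename.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DecisionTreeRegressor(criterion="mse", random_state=0).fit(X, y)
print(any(issubclass(w.category, FutureWarning) for w in caught))  # True

# New spelling, which is also the new default.
DecisionTreeRegressor(criterion="squared_error", random_state=0).fit(X, y)
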
2 changes: 1 addition & 1 deletion examples/applications/plot_model_complexity_influence.py
@@ -177,7 +177,7 @@ def _count_nonzero_coefficients(estimator):
'prediction_performance_label': 'MSE',
'n_samples': 30},
{'estimator': GradientBoostingRegressor,
'tuned_params': {'loss': 'ls'},
'tuned_params': {'loss': 'squared_error'},
'changing_param': 'n_estimators',
'changing_param_values': [10, 50, 100, 200, 500],
'complexity_label': 'n_trees',
24 changes: 12 additions & 12 deletions examples/ensemble/plot_gradient_boosting_quantile.py
@@ -71,24 +71,24 @@ def f(x):
all_models["q %1.2f" % alpha] = gbr.fit(X_train, y_train)

# %%
# For the sake of comparison, also fit a baseline model trained with the usual
# least squares loss (ls), also known as the mean squared error (MSE).
gbr_ls = GradientBoostingRegressor(loss='ls', **common_params)
all_models["ls"] = gbr_ls.fit(X_train, y_train)
# For the sake of comparison, we also fit a baseline model trained with the
# usual (mean) squared error (MSE).
gbr_ls = GradientBoostingRegressor(loss='squared_error', **common_params)
all_models["mse"] = gbr_ls.fit(X_train, y_train)

# %%
# Create an evenly spaced evaluation set of input values spanning the [0, 10]
# range.
xx = np.atleast_2d(np.linspace(0, 10, 1000)).T

# %%
# Plot the true conditional mean function f, the prediction of the conditional
# mean (least squares loss), the conditional median and the conditional 90%
# interval (from 5th to 95th conditional percentiles).
# Plot the true conditional mean function f, the predictions of the conditional
# mean (loss equals squared error), the conditional median and the conditional
# 90% interval (from 5th to 95th conditional percentiles).
import matplotlib.pyplot as plt


y_pred = all_models['ls'].predict(xx)
y_pred = all_models['mse'].predict(xx)
y_lower = all_models['q 0.05'].predict(xx)
y_upper = all_models['q 0.95'].predict(xx)
y_med = all_models['q 0.50'].predict(xx)
@@ -153,7 +153,7 @@ def highlight_min(x):
#
# Note that because the target distribution is asymmetric, the expected
# conditional mean and conditional median are significantly different and
# therefore one could not use the least squares model to get a good estimation of
# therefore one could not use the squared error model to get a good estimation of
# the conditional median nor the converse.
#
# If the target distribution were symmetric and had no outliers (e.g. with a
@@ -179,9 +179,9 @@ def highlight_min(x):
# shows that the best test metric is obtained when the model is trained by
# minimizing this same metric.
#
# Note that the conditional median estimator is competitive with the least
# squares estimator in terms of MSE on the test set: this can be explained by
# the fact the least squares estimator is very sensitive to large outliers
# Note that the conditional median estimator is competitive with the squared
# error estimator in terms of MSE on the test set: this can be explained by
# the fact the squared error estimator is very sensitive to large outliers
# which can cause significant overfitting. This can be seen on the right hand
# side of the previous plot. The conditional median estimator is biased
# (underestimation for this asymmetric noise) but is also naturally robust to
2 changes: 1 addition & 1 deletion examples/ensemble/plot_gradient_boosting_regression.py
@@ -67,7 +67,7 @@
'max_depth': 4,
'min_samples_split': 5,
'learning_rate': 0.01,
'loss': 'ls'}
'loss': 'squared_error'}

# %%
# Fit regression model
10 changes: 10 additions & 0 deletions sklearn/ensemble/_base.py
@@ -15,6 +15,7 @@
from ..base import is_classifier, is_regressor
from ..base import BaseEstimator
from ..base import MetaEstimatorMixin
from ..tree import DecisionTreeRegressor, ExtraTreeRegressor
from ..utils import Bunch, _print_elapsed_time
from ..utils import check_random_state
from ..utils.metaestimators import _BaseComposition
@@ -151,6 +152,15 @@ def _make_estimator(self, append=True, random_state=None):
estimator.set_params(**{p: getattr(self, p)
for p in self.estimator_params})

# TODO: Remove in v1.2
# criterion "mse" would cause warnings in every call to
# DecisionTreeRegressor.fit(..)
if (
isinstance(estimator, (DecisionTreeRegressor, ExtraTreeRegressor))
and getattr(estimator, "criterion", None) == "mse"
):
estimator.set_params(criterion="squared_error")

if random_state is not None:
_set_random_states(estimator, random_state)

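
A rough check of the intent behind this remapping (a sketch, not from the PR): with `criterion="mse"` rewritten to `"squared_error"` on each sub-estimator, fitting a forest should emit the deprecation warning once at the ensemble level (see `_forest.py` below) rather than once per tree. Assumes scikit-learn 1.0.

import warnings
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=100, n_features=4, random_state=0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    RandomForestRegressor(criterion="mse", n_estimators=5,
                          random_state=0).fit(X, y)

n_future = sum(issubclass(w.category, FutureWarning) for w in caught)
print(n_future)  # expected: 1 (from the forest itself), not one per tree
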
41 changes: 30 additions & 11 deletions sklearn/ensemble/_forest.py
@@ -345,6 +345,17 @@ def fit(self, X, y, sample_weight=None):

# Check parameters
self._validate_estimator()
# TODO: Remove in v1.2
if (
isinstance(self, (RandomForestRegressor, ExtraTreesRegressor))
and self.criterion == "mse"
):
warn(
"Criterion 'mse' was deprecated in v1.0 and will be "
"removed in version 1.2. Use `criterion='squared_error'` "
"which is equivalent.",
FutureWarning
)

if not self.bootstrap and self.oob_score:
raise ValueError("Out of bag estimation only available"
@@ -1310,15 +1321,19 @@ class RandomForestRegressor(ForestRegressor):
The default value of ``n_estimators`` changed from 10 to 100
in 0.22.

criterion : {"mse", "mae"}, default="mse"
criterion : {"squared_error", "mse", "mae"}, default="squared_error"
The function to measure the quality of a split. Supported criteria
are "mse" for the mean squared error, which is equal to variance
reduction as feature selection criterion, and "mae" for the mean
absolute error.
are "squared_error" for the mean squared error, which is equal to
variance reduction as feature selection criterion, and "mae" for the
mean absolute error.

.. versionadded:: 0.18
Mean Absolute Error (MAE) criterion.

.. deprecated:: 1.0
Criterion "mse" was deprecated in v1.0 and will be removed in
version 1.2. Use `criterion="squared_error"` which is equivalent.

max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
all leaves are pure or until all leaves contain less than
@@ -1537,7 +1552,7 @@ class RandomForestRegressor(ForestRegressor):
@_deprecate_positional_args
def __init__(self,
n_estimators=100, *,
criterion="mse",
criterion="squared_error",
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
@@ -1921,15 +1936,19 @@ class ExtraTreesRegressor(ForestRegressor):
The default value of ``n_estimators`` changed from 10 to 100
in 0.22.

criterion : {"mse", "mae"}, default="mse"
criterion : {"squared_error", "mse", "mae"}, default="squared_error"
The function to measure the quality of a split. Supported criteria
are "mse" for the mean squared error, which is equal to variance
reduction as feature selection criterion, and "mae" for the mean
absolute error.
are "squared_error" and "mse" for the mean squared error, which is
equal to variance reduction as feature selection criterion, and "mae"
for the mean absolute error.

.. versionadded:: 0.18
Mean Absolute Error (MAE) criterion.

.. deprecated:: 1.0
Criterion "mse" was deprecated in v1.0 and will be removed in
version 1.2. Use `criterion="squared_error"` which is equivalent.

max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
all leaves are pure or until all leaves contain less than
@@ -2141,7 +2160,7 @@ class ExtraTreesRegressor(ForestRegressor):
@_deprecate_positional_args
def __init__(self,
n_estimators=100, *,
criterion="mse",
criterion="squared_error",
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
@@ -2353,7 +2372,7 @@ class RandomTreesEmbedding(BaseForest):
[0., 1., 1., 0., 1., 0., 0., 1., 1., 0.]])
"""

criterion = 'mse'
criterion = "squared_error"
max_features = 1

@_deprecate_positional_args
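
Finally, a small sanity sketch of the whatsnew claim that the old and new spellings produce the same models (illustrative only; assumes scikit-learn 1.0):

import warnings
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=200, n_features=6, random_state=0)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)  # silence the deprecation
    old = RandomForestRegressor(criterion="mse", random_state=0).fit(X, y)

new = RandomForestRegressor(criterion="squared_error", random_state=0).fit(X, y)

# Same criterion under the hood and same random_state -> identical predictions.
assert np.allclose(old.predict(X), new.predict(X))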