Skip to content

Commit ecf3c7d

Browse files
thomasjpfanogrisel
authored andcommitted
FIX MultiOutputRegressor correctly ducktypes fitted estimators (#19308)
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
1 parent 70d1324 commit ecf3c7d

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

doc/whats_new/v0.24.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ Changelog
4848
`'use_encoded_value'` strategies.
4949
:pr:`19234` by `Guillaume Lemaitre <glemaitre>`.
5050

51+
:mod:`sklearn.multioutput`
52+
..........................
53+
54+
- |Fix| :class:`multioutput.MultiOutputRegressor` now works with estimators
55+
that dynamically define `predict` during fitting, such as
56+
:class:`ensemble.StackingRegressor`. :pr:`19308` by `Thomas Fan`_.
57+
5158
:mod:`sklearn.semi_supervised`
5259
..............................
5360

sklearn/multioutput.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def predict(self, X):
198198
Note: Separate models are generated for each predictor.
199199
"""
200200
check_is_fitted(self)
201-
if not hasattr(self.estimator, "predict"):
201+
if not hasattr(self.estimators_[0], "predict"):
202202
raise ValueError("The base estimator should implement"
203203
" a predict method")
204204

sklearn/tests/test_multioutput.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sklearn import datasets
1414
from sklearn.base import clone
1515
from sklearn.datasets import make_classification
16+
from sklearn.datasets import load_linnerud
1617
from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
1718
from sklearn.exceptions import NotFittedError
1819
from sklearn.linear_model import Lasso
@@ -33,6 +34,7 @@
3334
from sklearn.dummy import DummyRegressor, DummyClassifier
3435
from sklearn.pipeline import make_pipeline
3536
from sklearn.impute import SimpleImputer
37+
from sklearn.ensemble import StackingRegressor
3638

3739

3840
def test_multi_target_regression():
@@ -648,3 +650,19 @@ def test_classifier_chain_tuple_invalid_order():
648650

649651
with pytest.raises(ValueError, match='invalid order'):
650652
chain.fit(X, y)
653+
654+
655+
def test_multioutputregressor_ducktypes_fitted_estimator():
656+
"""Test that MultiOutputRegressor checks the fitted estimator for
657+
predict. Non-regression test for #16549."""
658+
X, y = load_linnerud(return_X_y=True)
659+
stacker = StackingRegressor(
660+
estimators=[("sgd", SGDRegressor(random_state=1))],
661+
final_estimator=Ridge(),
662+
cv=2
663+
)
664+
665+
reg = MultiOutputRegressor(estimator=stacker).fit(X, y)
666+
667+
# Does not raise
668+
reg.predict(X)

0 commit comments

Comments
 (0)