-
Notifications
You must be signed in to change notification settings - Fork 74.8k
Restructure Keras Scikit-Learn wrappers to better implement Scikit-Learn API #37201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thanks for your effort @adriangb ! I'm sorry I couldn't write back before. Of course, I don't have any problem with you updating the other PR, but I'm not sure if I have to do something to allow you to edit it. I've added you as a collaborator to my tf fork. In any case, this is a nice improvement for tf.keras in my opinion, but I don't know if the development team is interested. My PR has been open for a while and the same changes were proposed in the keras repo about a year ago now. |
I'm hoping that maybe the reviewers were just a bit busy? I feel like the wrappers are used quite often, especially in beginner tutorials. I think it's important that the initial experience be seamless. In addition to your PR, this resolves several issues: #33204, #36074, #34689 and #36137. There are/will be more, like the comment I posted on your PR. I'm hoping we can at least get some feedback from @fchollet or @pavithrasv regarding interest in these changes. |
I was able to add built-in support for all of the multi-output modes that Scikit-Learn supports, as well as a framework to easily support multi-input models. This means this would close #34689 as well. Because of the great modularity of the Functional API, only a limited number of multi-output cases can be automatically supported (those that scikit-learn itself supports with 1-1 mapping of model outputs to Although the original problem statement ('fix compatibility') was quite large, I do feel that this PR has grown very large. I would personally prefer to split it into smaller PRs (even if that means more work for me), but I will leave that up to the reviewers. @gbaned , is there a timeline for this review process? It would be nice to at least get tests running so that I can see if there are issues. |
I realized that we actually need to re-implement not only an R^2 score, but also a classifier accuracy score: Keras does not use the This prompted me to think: does it make sense to make scikit-learn an optional dependency that is only imported within this module? I can see pros and cons, just playing devil's advocate here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving to run the presubmits.
Can you add tests to make sure that saving and loading works with the wrapper?
Thank you for kicking off those tests! There are several tests for pickling/unpickling of Functional API models with and without Callbacks, etc. I think the only thing that is missing is a test for subclassed models. I'll add that in the next few days. |
Quite a few errors:
Also, I added the test suggested by @k-w-w (it's called I guess another approval is needed for tests to run again, it'd be nice to fix the windows build errors before that though. |
Thanks for fixing the bugs! I'm pretty sure the windows tests are unrelated. Running the tests again |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello Adrian,
Thank you for the PR. In general, the scikit-learn API wrapper for Keras is not well-maintained at this time. Because we don't have the resources to maintain it (or even to review a very involved PR like yours), we are considering deprecating it.
Since you are a user of this feature and you've already spent a lot of time developing improvements, I would recommend starting a new repository & pip package hosting an up-to-date version of the API wrapper (effectively, do a fork). We could then redirect users of tf.keras.wrappers.scikit_learn to your repository & pip package, while we deprecate this functionality in tf.keras.
What do you think?
Sequential.evaluate, | ||
Sequential.fit, | ||
Sequential.predict, | ||
Sequential.predict_classes, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
predict_classses
is now deprecated.
@fchollet, thank you for looking at this! I think that could work really well. But I'd like to list out pros and cons to consider Pros:
Cons:
Overall, I think the pros outweigh the cons. I will start working on getting a separate repo with CI/publishing working. In the meantime, if you could do a brief review of this PR as it currently stands, that would be super useful to make sure the initial release is as good as possible. |
…sorflow/tensorflow#37201) Co-authored-by: David Díaz Vico <david.diaz.vico@outlook.com>
Are there any preferences as far as:
I played around with a packaging a bit, everything seems to work as far as CI/testing/releasing. It would have |
…sorflow/tensorflow#37201) Co-authored-by: David Díaz Vico <david.diaz.vico@outlook.com>
…sorflow/tensorflow#37201) Co-authored-by: David Díaz Vico <david.diaz.vico@outlook.com>
A quick update: the package is now fully operational. I settled on the name SciKeras. Some important updates since this PR:
With all of this, estimators created with these wrappers now pass all of scikit-learn's estimator checks, except those that require setting a random state. As far as I understand, it is not possible to easily set a random seed in tf. |
Hi @gbaned, just checking if there are any updates on this proposal/PR? Thanks! |
Thanks for the update. It's great to see that you've already released the new package. We can recommend that people start using it instead of
This should be fixable: https://www.tensorflow.org/api_docs/python/tf/random/set_seed What do you want us to do with the current PR? Should we close it? |
That sounds good.
Will take a look, thank you.
I think let's keep it open for a bit longer. Dask is looking to adopt SciKeras as a wrapper (here), so as they do their testing I expect there to be a couple of issues that crop up in the next couple of weeks that I may need input from the TF team on. Unless the TF team is willing to check the SciKeras repo if they are tagged. |
Ok, sounds good! Please reach out if you need anything from us (over email preferably, so we don't miss it). We'll start recommending your library as soon as it starts getting traction then 👍 |
@adriangb Any update on this PR? Please. Thanks! |
@adriangb Any update on this PR? Please. Thanks! |
Hi @gbaned, as per François' comment above, the plan is to not merge this PR and instead move this part of tf.keras to an external package. I had asked to keep this PR open for communication and help with bringup of the external package, but I've since established communication with François directly and the external package is making good progress, so I think we can close this PR 😄 |
This is a modification of #32533. I am opening a new PR because that one seems stalled and I made a lot of changes/improvements (but keeping the same idea).
A quick summary:
The existing
scikit-learn
wrappers for Keras models are not compatible with manyscikit-learn
functions. Additionally, they require that dataset dimensions be determined before calling thefit
method, which is unlike thescikit-learn
estimators and makes it hard to build dynamically adaptable pipelines.What this PR does:
This PR does not change any API.
By moving the storage of parameters from
self.sk_params
toself.__dict__
, compatibility with a lot of thescikit-learn
functionalities are improved. Additionally, I gave the model building function the ability to request the data that will be fitted (to determine dimensions) as well as any other attributes of the wrapper instance. Finally, I enabled copying/pickling of wrapped models as well as the ability to wrap instances ofModel
, which should allow for greater flexibility in incorporating into an exsitingKeras
workflow.How is it tested:
All existing tests are working and were unchanged. This confirms that there were no API changes. New tests were added for all of the new functionality as well as some of the most common
scikit-learn
operations that were previously broken.I would like to credit @daviddiazvico , the author of the original PR: I borrowed a lot of the tests he had written as well as the original idea for fixing these issues. I will make him a co-author of the final commit if this gets approved.