Dealing with variable-length inputs
Let's assume we are working with variable-length inputs. One of the strongest features of tf.data.Dataset is the ability to pad batches as they come.
But since scikit-learn's API is mainly built around dataframes and arrays, incorporating this is hard. Obviously, you can pad everything up front, but that can be a huge waste of memory. I'm trying to work with the sklearn.pipeline.Pipeline object, and I thought to myself: "alright, I'll just create a custom transformer at the end of my pipeline, just before the model, and make it return a tf.data.Dataset object to later plug into my model." But this is not possible, since the .transform signature only accepts X and not y, while you need both to build a tf.data.Dataset.
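To illustrate the limitation, here is a minimal, hypothetical sketch (ToTFDataset is not a real class; it just shows the scikit-learn contract):

import tensorflow as tf
from sklearn.base import BaseEstimator, TransformerMixin

class ToTFDataset(BaseEstimator, TransformerMixin):
    """Hypothetical final pipeline step that emits a tf.data.Dataset."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # Only X ever arrives here; scikit-learn does not pass y to transform,
        # so there is no way to pack (features, labels) into the dataset.
        return tf.data.Dataset.from_tensor_slices(X)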
So assume we have 4 features for each data point, each with its own sequence length. For example, a data point might look like this:
sample_features = {'a': [1,2,3], 'b': [1,2,3,4,5], 'c': 1, 'd': [1,2]}
sample_label = 0
How will I be able to manage this kind of dataset under scikit-learn + SciKeras?
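For reference, here is a rough sketch of the padding behaviour I mean (the second sample, the generator, and the use of output_signature from TF 2.4+ are just illustrative): padded_batch pads each variable-length feature only up to the longest example in that batch, so nothing has to be padded to a global maximum length up front.

import tensorflow as tf

# Two toy samples; each feature has its own, varying sequence length.
samples = [
    ({'a': [1, 2, 3], 'b': [1, 2, 3, 4, 5], 'c': 1, 'd': [1, 2]}, 0),
    ({'a': [4, 5], 'b': [6], 'c': 2, 'd': [3, 4, 5, 6]}, 1),
]

def gen():
    for features, label in samples:
        yield features, label

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        {
            'a': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'b': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'c': tf.TensorSpec(shape=(), dtype=tf.int32),
            'd': tf.TensorSpec(shape=(None,), dtype=tf.int32),
        },
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)

# Each batch is padded independently, only to its own longest sequences.
for features, labels in ds.padded_batch(2):
    print({k: v.numpy() for k, v in features.items()}, labels.numpy())
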
@adriangb I've put myself a reminder and I'll try to look at this next weekend.
EDIT
Just to not leave my end open: in the end I did not have time to see this through… though I'm still interested in the incorporation of tf.data.Dataset in SciKeras 😄

Yep. Please enjoy your weekend! No rush.
Sure thing. The gist of it is that these are dependency injection points for users to insert custom data transformations. Calling BaseWrapper.fit instantiates and fits the transformers here. Adding another transformer just consists of adding some default transformers (sklearn.preprocessing.FunctionTransformer) and a couple of lines to instantiate and fit the new transformer. I think the hardest part is going to be figuring out the signature of the transformer, since it'll be non-standard (scikit-learn's transform accepts only 1 parameter; we need 2, or a tuple).
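To make that concrete, here is a rough sketch of the pattern, with made-up names (MyWrapper and dataset_transformer are illustrative, not the actual SciKeras internals):

from sklearn.preprocessing import FunctionTransformer

class MyWrapper:  # illustrative stand-in for BaseWrapper
    @property
    def dataset_transformer(self):
        # Default injection point: an identity FunctionTransformer that
        # users can override with their own data transformation.
        return FunctionTransformer()

    def fit(self, X, y):
        transformer = self.dataset_transformer
        # The non-standard part: a scikit-learn transform takes a single
        # argument, so X and y would have to travel together as a tuple.
        X, y = transformer.fit_transform((X, y))
        # ...then build and train the Keras model on the transformed data.
        return self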