[KED-1271] Tensorflow Model Dataset
Description
In order to prevent pickling errors with TensorFlow models, it would be great to include a TensorFlow dataset which would invoke the API's preferred save and load methods.
Context
Saving and loading TensorFlow models via pickling can result in errors. It is possible to write a new dataset which invokes the API's preferred save and load methods and prevents pickling errors.
Possible Implementation
I’ve been using this in my work. I haven’t tried the exists functionality (I assume this won’t work); however, for regular pipeline operation it works well.
from typing import Any, Dict
import os

import tensorflow as tf
from kedro.io import AbstractDataSet


class TensorflowModelDataset(AbstractDataSet):
    """A dataset to save and load TensorFlow models via the native Keras API."""

    def __init__(self, filepath: str, load_args: Dict[str, Any] = None, save_args: Dict[str, Any] = None):
        """
        Args:
            filepath: Path the model is saved to and loaded from.
            load_args: Overrides passed to ``tf.keras.models.load_model``.
            save_args: Overrides passed to ``tf.keras.models.save_model``.
        """
        self.filepath = filepath
        self._load_args = dict(custom_objects=None, compile=True)
        self._save_args = dict(
            overwrite=True,
            include_optimizer=True,
            save_format='tf',
            signatures=None,
            options=None,
        )
        if load_args is not None:
            self._load_args.update(load_args)
        if save_args is not None:
            self._save_args.update(save_args)

    def _load(self):
        return tf.keras.models.load_model(self.filepath, **self._load_args)

    def _save(self, data):
        tf.keras.models.save_model(data, self.filepath, **self._save_args)

    def _exists(self):
        # Check the filesystem; the original checked a ``self.data``
        # attribute that is never set on the class.
        return os.path.exists(self.filepath)

    def _describe(self):
        return dict(filepath=self.filepath, load_args=self._load_args, save_args=self._save_args)
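In a Kedro project, such a dataset would normally be registered in the data catalog rather than instantiated directly. A hedged sketch of a catalog entry, assuming the class lives in a local `datasets.tensorflow_model_dataset` module (the module path, entry name, and filepath are illustrative, not part of the issue):

```yaml
# conf/base/catalog.yml -- illustrative entry
my_model:
  type: datasets.tensorflow_model_dataset.TensorflowModelDataset
  filepath: data/06_models/my_model
  save_args:
    save_format: tf
```

With this entry, any node that returns `my_model` would have its output saved via `tf.keras.models.save_model` instead of pickle.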
Issue Analytics
- Created 4 years ago
- Comments: 10 (8 by maintainers)
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
If you check the function signature of the save method, it expects a tf model class, not a sklearn pipeline.
If you want to save an object that isn’t covered by an existing dataset class, you’ll need to create a new dataset class. This is a good guide for it: https://kedro.readthedocs.io/en/stable/07_extend_kedro/03_custom_datasets.html
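The linked guide boils down to subclassing `kedro.io.AbstractDataSet` and implementing `_load`, `_save`, and `_describe` (optionally `_exists`). A minimal sketch of that shape, using JSON in place of a TensorFlow model so it runs without Kedro or TensorFlow installed; the `AbstractDataSet` base class below is a simplified stand-in for Kedro's (which adds error handling around these hooks):

```python
import json
import os


class AbstractDataSet:  # stand-in for kedro.io.AbstractDataSet
    """Public methods delegate to the underscore hooks subclasses implement."""

    def load(self):
        return self._load()

    def save(self, data):
        self._save(data)

    def exists(self):
        return self._exists()


class JSONDataSet(AbstractDataSet):
    """Custom dataset skeleton: delegate to the format's native save/load
    instead of pickling the object."""

    def __init__(self, filepath: str):
        self._filepath = filepath

    def _load(self):
        with open(self._filepath) as f:
            return json.load(f)

    def _save(self, data):
        with open(self._filepath, "w") as f:
            json.dump(data, f)

    def _exists(self):
        # Existence is a filesystem check, not an attribute check.
        return os.path.exists(self._filepath)

    def _describe(self):
        return dict(filepath=self._filepath)
```

A real TensorFlow dataset would follow the same structure, swapping `json.load`/`json.dump` for `tf.keras.models.load_model`/`save_model`.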
FYI both, I am going to take a look at this as part of today’s pysprint, basically transplanting @williamashfordQB’s code and adding test cases.