[KED-1271] Tensorflow Model Dataset

Description

To prevent pickling errors with TensorFlow models, it would be great to include a TensorFlow model dataset that invokes the API's preferred save and load methods.

Context

Saving and loading TensorFlow models via pickling can result in errors. A new dataset that invokes the API's preferred save and load methods would avoid these pickling errors.
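To make the failure mode concrete without depending on TensorFlow, here is a minimal sketch. `ToyModel`, its `save`/`load` methods and `can_pickle` are hypothetical names invented for this illustration; the class only imitates a model that holds a resource pickle cannot serialise (a thread lock here), while its own save/load methods write just the state needed to rebuild it.

```python
import json
import pickle
import tempfile
import threading
from pathlib import Path


class ToyModel:
    """Hypothetical stand-in for a model holding an unpicklable resource."""

    def __init__(self, weights):
        self.weights = list(weights)
        self._lock = threading.Lock()  # thread locks cannot be pickled

    def save(self, path):
        # "Native" save: persist only the state needed to rebuild the model.
        Path(path).write_text(json.dumps({"weights": self.weights}))

    @classmethod
    def load(cls, path):
        # "Native" load: rebuild the object; the lock is created afresh.
        return cls(json.loads(Path(path).read_text())["weights"])


def can_pickle(obj):
    """Return True if pickle.dumps succeeds on obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


model = ToyModel([0.1, 0.2, 0.3])
pickled_ok = can_pickle(model)  # pickling trips over the lock

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "model.json"
    model.save(target)                # the object's own save method works
    restored = ToyModel.load(target)  # and so does its own load method
```

The proposed dataset applies the same idea to Keras models by delegating to `tf.keras.models.save_model` and `tf.keras.models.load_model` instead of pickling the object.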

Possible Implementation

I’ve been using this in my work. I haven’t tried the exists functionality (I assume it won’t work), but for regular pipeline operation it works well.

from kedro.io import AbstractDataSet
import tensorflow as tf
from typing import Dict, Any


class TensorflowModelDataset(AbstractDataSet):

    def __init__(self, filepath: str, load_args: Dict[str, Any] = None, save_args: Dict[str, Any] = None):
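        """
        A dataset to save and load tensorflow models.

        Args:
            filepath: Path the model is saved to and loaded from.
            load_args: Extra keyword arguments for tf.keras.models.load_model.
            save_args: Extra keyword arguments for tf.keras.models.save_model.

        """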

        self.filepath = filepath

        self._load_args = dict(custom_objects=None,
                               compile=True)

        self._save_args = dict(overwrite=True,
                               include_optimizer=True,
                               save_format='tf',
                               signatures=None,
                               options=None)

        if load_args is not None:
            self._load_args.update(load_args)

        if save_args is not None:
            self._save_args.update(save_args)

    def _load(self):
        return tf.keras.models.load_model(
            self.filepath,
            **self._load_args
        )

    def _save(self, data):
        tf.keras.models.save_model(
            data,
            self.filepath,
            **self._save_args
        )

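    def _exists(self):
        # self.data is never set, so check the saved model's path instead
        return tf.io.gfile.exists(self.filepath)

    def _describe(self):
        # kedro uses this dict when printing the dataset
        return dict(filepath=self.filepath,
                    load_args=self._load_args,
                    save_args=self._save_args)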

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 10 (8 by maintainers)

Top GitHub Comments

1 reaction
willashford commented, Jan 29, 2022

If you check the function signature of the save method, it expects a tf model class, not a sklearn pipeline.

If you want to save an object that isn’t covered by an existing dataset class, you’ll need to create a new dataset class. This is a good guide: https://kedro.readthedocs.io/en/stable/07_extend_kedro/03_custom_datasets.html
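As a rough, dependency-free illustration of what such a custom dataset class looks like: `DataSetBase` below is a stand-in written for this sketch, on the assumption that it mirrors only the public `load`/`save`/`exists` wrappers of `kedro.io.AbstractDataSet`. `PickledObjectDataSet` is likewise a hypothetical name. In a real project you would subclass kedro's own base class and follow the linked guide.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict


class DataSetBase:
    """Stand-in for kedro.io.AbstractDataSet (assumption: the public methods
    simply delegate to the underscore hooks a subclass implements)."""

    def load(self) -> Any:
        return self._load()

    def save(self, data: Any) -> None:
        self._save(data)

    def exists(self) -> bool:
        return self._exists()


class PickledObjectDataSet(DataSetBase):
    """Hypothetical dataset for picklable objects, e.g. a fitted pipeline."""

    def __init__(self, filepath: str):
        self._filepath = Path(filepath)

    def _load(self) -> Any:
        with self._filepath.open("rb") as handle:
            return pickle.load(handle)

    def _save(self, data: Any) -> None:
        with self._filepath.open("wb") as handle:
            pickle.dump(data, handle)

    def _exists(self) -> bool:
        return self._filepath.exists()

    def _describe(self) -> Dict[str, Any]:
        return dict(filepath=str(self._filepath))


with tempfile.TemporaryDirectory() as tmp:
    dataset = PickledObjectDataSet(str(Path(tmp) / "pipeline.pkl"))
    existed_before = dataset.exists()

    # A plain dict stands in for the object you actually want to persist.
    dataset.save({"steps": ["scale", "classify"]})
    existed_after = dataset.exists()
    reloaded = dataset.load()
```

The shape is the same as the TensorFlow dataset proposed in the issue; only the bodies of `_load` and `_save` change depending on how the object prefers to be serialised.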

1 reaction
w0rdsm1th commented, Feb 26, 2020

FYI both, I am going to take a look at this as part of today’s pysprint, basically transplanting @williamashfordQB’s code and adding test cases.
