
Warmstart for Native Keras Model in TFX

See original GitHub issue

Thank you for submitting a TFX documentation issue. Per our GitHub policy, we only address code/doc bugs, performance issues, feature requests, and build/installation issues on GitHub and we welcome external contributions!

The TFX docs are open source! To get involved, please read the documentation contributor guide: https://www.tensorflow.org/community/contribute/docs

URL(s) with the issue:

[1] https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_warmstart.py
[2] https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_native_keras.py

Description of issue (what needs changing): an example of warmstart for a native Keras model TFX pipeline.

Clear description:

I am looking for an example of a native Keras model TFX pipeline that allows warmstart of the model, so that for future training on a new dataset the model can continue training instead of being re-compiled and retrained from the beginning.

The link in [1] above shows how a TensorFlow Estimator pipeline with warmstart can be set up; however, the only Keras model TFX pipeline example, in [2], does not cover this component.

Does it follow the same configuration as the Estimator component? To be more specific, can the code below (taken from the Estimator example) be used in a Keras model pipeline?


  # Get the latest model so that we can warm start from the model.
  latest_model_resolver = ResolverNode(
      instance_name='latest_model_resolver',
      resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
      latest_model=Channel(type=Model))

Correct links

Is the link to the source code correct? Yes

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

1 reaction
wlee192 commented, Mar 23, 2021

@chongkong ,

Thanks, that’s the clarification I was after! So, just to confirm my understanding, below is how it would look in the pipeline and processing code.

## Pipeline Code

warmstart_model_resolver = resolver.Resolver(
    instance_name='warmstart_model_resolver',
    strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
    latest_model=Channel(type=Model))

# Here warm_start is simply a bool (True/False)
trainer = Trainer(
    module_file=module_file,
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    base_model=warmstart_model_resolver.outputs['latest_model'] if warm_start else None,
    train_args=trainer_pb2.TrainArgs(num_steps=TRAIN_STEPS),
    eval_args=trainer_pb2.EvalArgs(num_steps=EVAL_STEPS))

## Processing_Util code

def run_fn(fn_args: FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
  # create the train and evaluation datasets
  train_dataset = _input_fn(fn_args.train_files, 
                            fn_args.data_accessor, 
                            tf_transform_output, 
                            _TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(fn_args.eval_files, 
                           fn_args.data_accessor, 
                           tf_transform_output, 
                           _EVAL_BATCH_SIZE)

  if fn_args.base_model:
    # restore the Keras model object for incremental training instead of re-compiling the model and training from the beginning
    model = tf.keras.models.load_model(fn_args.base_model)
  else:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = _build_keras_model()

  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=fn_args.model_run_dir, update_freq='epoch')


  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      epochs = TRAIN_EPOCH,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps,
      callbacks=[tensorboard_callback])

  ## model signature code goes here...
  
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
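
For reference, the "## model signature code goes here..." placeholder above is typically the serving-signature boilerplate used in taxi_utils_native_keras.py (linked in the next comment). Below is a minimal sketch of that block, assuming _LABEL_KEY names the label feature defined elsewhere in the module; it builds the signatures dict consumed by model.save above.

import tensorflow as tf  # already imported in the module

def _get_serve_tf_examples_fn(model, tf_transform_output):
  """Returns a function that parses serialized tf.Example records and applies the TFT graph."""
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    feature_spec = tf_transform_output.raw_feature_spec()
    feature_spec.pop(_LABEL_KEY)  # the raw label is not an input at serving time
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    transformed_features = model.tft_layer(parsed_features)
    return model(transformed_features)

  return serve_tf_examples_fn


# Inside run_fn, before model.save:
signatures = {
    'serving_default':
        _get_serve_tf_examples_fn(model, tf_transform_output).get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')),
}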
0 reactions
wlee192 commented, Mar 24, 2021

@chongkong ,

I’ve run a quick test of this today in the Google Colab environment. As far as getting the previous model is concerned, the pipeline picks up the previously trained model fine for any warmstart.

However, I did notice that training for the warmstart produced a warning message:

First model training (no base model at this stage, so training starts from scratch):

[screenshot: training log]

Warmstart model training (warmstart toggled to True, so the pipeline locates the previously trained model and trains further from there):

[screenshot: training log]

[screenshot: training log]

The loss for the warmstart model looks to be higher than the first model’s last-epoch loss, but it then reduces further than the first model’s as the epochs progress, which is a good sign that it is learning further from the warmstart.

Not too sure if the warning message "Model failed to serialize as JSON. Ignoring..." is something of concern? model.fit and model.predict are still working for the warmstart model.
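
For what it's worth, that message appears to come from the Keras TensorBoard callback failing to serialize the model config to JSON after load_model; it only affects the model summary written to TensorBoard, not training. Below is a minimal standalone sketch (outside TFX; the path, layer sizes, and data are made up) of one way to confirm that a model restored with tf.keras.models.load_model keeps its optimizer state, which is what makes this a warm start rather than a fresh compile.

import numpy as np
import tensorflow as tf

# Hypothetical check, not part of the pipeline: train a tiny model, save it
# in SavedModel format, reload it, and confirm that the optimizer's step
# counter survives the round trip.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1, verbose=0)

model.save('/tmp/base_model', save_format='tf')
restored = tf.keras.models.load_model('/tmp/base_model')

# A non-zero iteration count shows the optimizer state was restored, so a
# subsequent fit() continues training instead of starting over.
print(int(restored.optimizer.iterations))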

FYI, I was using the taxi example for Keras in this exercise:

https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_native_keras.py
https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils_native_keras.py

Read more comments on GitHub >

