
get_weights()/set_weights() take too long

See original GitHub issue

System information.

  • TensorFlow version: 2.3.0
  • CUDA version: 10.1
  • GPU compute capability: 5.2

Describe the problem.

I am running code which calls layer.get_weights() and layer.set_weights() on every training batch. The callback containing these calls takes 0.009s, compared to the 0.003s needed to run the batch itself, and as such more than triples the training time. I assumed these operations simply move tensors around (ideally staying on the GPU) and thus should not take time comparable to the large matrix multiplications occurring during the batch step; having reviewed the source code, that is my best understanding of what happens. In practice, however, they take an extraordinarily long time. Does anyone have any idea why, or any approaches to reduce the time spent in set_weights() and get_weights()? The abnormally long runtime may be due to the structure of the get_weights()/set_weights() functions themselves, which is why I am raising this as a bug.

My intuition is that the cost comes from data being copied to the CPU and back, or from tensors being converted to NumPy arrays. Or perhaps, upon calling set_weights(), TensorFlow rebuilds the entire graph from scratch, or something similar.
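
As a quick illustration of that intuition, here is a minimal sketch: get_weights() always hands back NumPy arrays (host-memory copies), while the layer's underlying tf.Variables can be read and assigned in place without leaving TensorFlow.

import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build((None, 8))

# get_weights() returns host-side NumPy copies of every variable
print(type(layer.get_weights()[0]))   # <class 'numpy.ndarray'>

# layer.weights exposes the tf.Variables themselves (layer.kernel for Dense);
# assign() updates them in place, with no tensor -> numpy -> tensor round trip
layer.kernel.assign(layer.kernel * 0.5)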

One thing I noticed is that Keras has its own pruning functionality (shown here), and that functionality incidentally also triggers the slow-callback warning (see below). Perhaps this is related?

3/422 [..............................] - ETA: 12s - loss: 0.0628 - accuracy: 0.9896  
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0075s vs `on_train_batch_end` time: 0.0076s). Check your callbacks.
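
For reference, a minimal sketch of that built-in path, assuming the tensorflow_model_optimization package (where the Keras pruning wrapper lives) is installed:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

base = tf.keras.Sequential([
    tf.keras.layers.Dense(300, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Wrap the model so low-magnitude weights are masked during training
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    base,
    pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=0.5, begin_step=0),
)
pruned.compile(optimizer='adam', loss='categorical_crossentropy')

# UpdatePruningStep() must be passed to fit(); it applies the masks each step
# pruned.fit(x, y, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])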

Describe the current behavior.

The callback to on_train_batch_end() in the code below calls get_weights() twice and set_weights() once per layer, and takes three times as long to run as the batch update itself:

Epoch 1/40
  1/629 [..............................] - ETA: 0s - loss: 2.3555 - accuracy: 0.0000e+00
  WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0030s vs `on_train_batch_end` time: 0.0090s). Check your callbacks.

This is explicitly due to calling get_weights() and set_weights(): removing them from the callback reduces its runtime to a negligible amount.
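
One rough way to confirm where the time goes is to benchmark the get_weights()/set_weights() round trip against an in-place tf.Variable.assign(); a microbenchmark sketch:

import time
import tensorflow as tf

layer = tf.keras.layers.Dense(300)
layer.build((None, 784))
mask = tf.ones_like(layer.kernel)

def via_weights():
    # tensor -> numpy -> tensor round trip through host memory
    w, b = layer.get_weights()
    layer.set_weights([w * mask.numpy(), b])

def via_assign():
    # stays inside TensorFlow; no NumPy conversion
    layer.kernel.assign(layer.kernel * mask)

for fn in (via_weights, via_assign):
    start = time.perf_counter()
    for _ in range(100):
        fn()
    print(fn.__name__, (time.perf_counter() - start) / 100, "s per call")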

Describe the expected behavior.

Ideally, I would like to achieve iterative magnitude pruning with the lowest possible per-batch runtime overhead.
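
Concretely, something like the following sketch is the behavior I am after: the same masking step written against tf.Variable.assign() instead of get_weights()/set_weights(). This assumes every layer is a Dense layer (so it exposes .kernel) and that the masks are float32 tensors; the class name is mine, not an existing API.

class AssignPruneCallback(tf.keras.callbacks.Callback):
    def __init__(self, mask_dict=None):
        super().__init__()
        self.mask_dict = mask_dict

    def on_train_batch_end(self, batch, logs=None):
        if self.mask_dict is None:
            return
        for layer_i, layer in enumerate(self.model.layers):
            mask = self.mask_dict['w_' + str(layer_i + 1)]
            # in-place update on the variable's current device;
            # avoids the tensor -> numpy -> tensor round trip
            layer.kernel.assign(layer.kernel * mask)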

Standalone code to reproduce the issue.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.datasets import mnist

### PRUNE WEIGHTS CALLBACK ###
class pruneModelCallback(Callback):
    def __init__(self, init_weight_dict=None, mask_dict=None):
        super().__init__()  # let Keras wire up self.model
        self.n_batches = 0
        self.init_weight_dict = init_weight_dict
        self.mask_dict = mask_dict
    
    def on_train_batch_begin(self, batch, logs=None):
        # save weights at initialization
        if self.n_batches == 0:
            if self.init_weight_dict is not None:
                for layer_i in range(len(self.model.layers)):
                    w = self.init_weight_dict['w_'+str(layer_i+1)]
                    b = self.init_weight_dict['b_'+str(layer_i+1)]

                    self.model.layers[layer_i].set_weights([w,b])
            else:
                self.init_weight_dict = {}
                for layer_i in range(len(self.model.layers)):
                    w = self.model.layers[layer_i].get_weights()[0]
                    b = self.model.layers[layer_i].get_weights()[1]

                    self.init_weight_dict['w_'+str(layer_i+1)] = w
                    self.init_weight_dict['b_'+str(layer_i+1)] = b
        
        self.n_batches = self.n_batches + 1
        
    # This is the problematic function; it runs after every training batch
    def on_train_batch_end(self, batch, logs=None):
        # zero out pruned weights
        if self.mask_dict is not None:
            for layer_i in range(len(self.model.layers)):
                # removing these slightly improves runtime
                w = self.model.layers[layer_i].get_weights()[0]
                b = self.model.layers[layer_i].get_weights()[1]

                w_mask = self.mask_dict['w_'+str(layer_i+1)]

                # this multiplication takes no time comparably and removing it 
                # does not influence time taken
                w_pruned = w * w_mask

                # removing this function call significantly speeds up the runtime
                self.model.layers[layer_i].set_weights([w_pruned,b])

class pruneWeights():
    def __init__(self, model, percentile, pruning_type="IMP"):
        # generate the pruning mask for each layer
        # (__init__ must not return a value, so just call the helper)
        if pruning_type == "IMP":
            self._IMP(model, percentile)
        else:
            raise ValueError("Unknown pruning_type {}".format(pruning_type))
    
    def _IMP(self, model, percentile):
        mask_dict = {}
        w_list = None
        for layer_i in range(len(model.layers)):
            w = model.layers[layer_i].get_weights()[0]
            w_flat = tf.reshape(w, [-1])  # tf.reshape needs a 1-D shape, not a scalar size
            if w_list is None:
                w_list = w_flat
            else:
                w_list = tf.concat([w_list, w_flat], axis=0)
        w_list = tf.math.abs(w_list)
        thresh = tfp.stats.percentile(w_list, percentile*100)
        for layer_i in range(len(model.layers)):
            w = model.layers[layer_i].get_weights()[0]
            w = tf.math.abs(w)
            mask = tf.cast(tf.math.greater(w, thresh), tf.float32)
            mask_dict['w_'+str(layer_i+1)] = mask
        self.mask_dict = mask_dict

def pruning_breakdown(mask_dict, model):
    for layer_i in range(len(model.layers)):
        mask = mask_dict['w_'+str(layer_i+1)]
        print("w_"+str(layer_i+1)+": "+str(tf.math.reduce_mean(mask)))

def main():
    ### LOAD MNIST DATASET ###

    (x_train , y_train), (x_test , y_test) = mnist.load_data()
    x = np.concatenate((x_train, x_test), axis=0)
    y = np.concatenate((y_train, y_test), axis=0)
    x = x.astype("float32") / 255
    x = np.reshape(x, (np.shape(x)[0], 784))
    scaler = StandardScaler(with_std=False)
    scaler.fit(x)
    x_t = scaler.transform(x)
    ohe = OneHotEncoder()
    y_t = ohe.fit_transform(y.reshape(-1, 1)).toarray()
    input_dim = np.shape(x_t)[1:]
    output_dim = np.shape(y_t)[1:]

    del x_train, x_test, y_train, y_test, x, y

    ### SPLIT TRAINING DATA ###

    X_train, X_test, y_train, y_test = train_test_split(x_t, y_t, test_size=0.33, random_state=42)

    ### MODEL INIT ###

    model = Sequential([
        Dense(300, input_dim=input_dim[0], activation='relu'),
        Dense(100, activation='relu'),
        Dense(50, activation='relu'),
        Dense(output_dim[0], activation='softmax')
    ])

    model.compile(
        optimizer = keras.optimizers.Adam(learning_rate=1.2e-4),
        loss = tf.keras.losses.CategoricalCrossentropy(),
        metrics = ['accuracy']
    )

    ### TRAIN MODEL ###

    epochs = 20

    pr = pruneModelCallback()
    es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)

    history = model.fit(
        x=X_train,
        y=y_train,
        batch_size=64,
        epochs=epochs,
        validation_data=(X_test,y_test),
        callbacks = [pr, es],
    )

    ### ITERATIVELY PRUNE MODEL ###

    percentile_per_it = 0.20
    it = 14

    for i in range(it):
        # cumulative fraction pruned after i+1 rounds of 20% per round
        total_percentile = 1 - tf.math.pow(1 - percentile_per_it, i + 1)
        print(total_percentile)
        pruned_weights = pruneWeights(model, total_percentile)
        pruning_breakdown(pruned_weights.mask_dict, model)

        model = Sequential([
            Dense(300, input_dim=input_dim[0], activation='relu'),
            Dense(100, activation='relu'),
            Dense(50, activation='relu'),
            Dense(output_dim[0], activation='softmax')
        ])

        model.compile(
            optimizer = keras.optimizers.Adam(learning_rate=1.2e-4),
            loss = tf.keras.losses.CategoricalCrossentropy(),
            metrics = ['accuracy']
        )

        ### TRAIN MODEL ###

        epochs = 40

        pr_next = pruneModelCallback(init_weight_dict=pr.init_weight_dict, mask_dict=pruned_weights.mask_dict)
        es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)

        history = model.fit(
            x=X_train,
            y=y_train,
            batch_size=64,
            epochs=epochs,
            validation_data=(X_test,y_test),
            callbacks = [pr_next, es],
        )

        pr = pr_next

if __name__ == "__main__":
    main()
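
A further direction that would avoid per-batch callbacks entirely (a sketch I have not profiled): Keras re-applies weight constraints after each optimizer update, so a fixed binary mask could be baked into each layer as a kernel_constraint. MaskConstraint is a hypothetical name, not an existing API.

class MaskConstraint(tf.keras.constraints.Constraint):
    def __init__(self, mask):
        # fixed binary mask; zeros mark pruned weights
        self.mask = tf.constant(mask, dtype=tf.float32)

    def __call__(self, w):
        # re-applied by the optimizer after every weight update
        return w * self.mask

# hypothetical usage when rebuilding the model for the next pruning round:
# Dense(300, activation='relu', input_dim=784,
#       kernel_constraint=MaskConstraint(mask_dict['w_1']))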


Top GitHub Comments

MatthewWiens101 commented, Sep 9, 2022 (2 reactions)

@sachinprasadhs any update on this issue? Still looking for a faster workaround.

