get_weights()/set_weights() take too long
See original GitHub issueSystem information.
- TensorFlow version: 2.3.0
- CUDA version: 10.1
- GPU compute capability: 5.2
Describe the problem.
I am running some code which repeatedly (every training iteration) calls layer.get_weights() and layer.set_weights(). The callback operation containing these calls takes 0.009ms compared to the 0.003ms taken to run the batch and as such more than triples the training time required. I assume that this operation is simply moving tensors around (should be only on GPU) and thus should not take time comparable to the large matrix multiplications occurring during the batch iteration. I have reviewed the source code and to the best of my understanding this is what is happening. However, it is obviously taking an extraordinarily long time. Does anyone have any idea why this happens, or any approaches to reduce the time taken to call set_weights() and get_weights()? This abnormally long runtime may be due to the structure of the get_weights()/set_weights() functions, which is why I am raising this issue as a bug.
My intuition is that it may be due to data being sent to the CPU and back, or converted from tensors to numpy. Or, perhaps, upon calling set_weights, tensorflow rebuilds the entire graph from scratch or something similar.
One thing I noticed is that keras has their own pruning functionality shown here and this functionality incidentally also has a long callback runtime (see below). Perhaps this is related?
3/422 [..............................] - ETA: 12s - loss: 0.0628 - accuracy: 0.9896
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0075s vs `on_train_batch_end` time: 0.0076s). Check your callbacks.
Describe the current behavior.
The callback to on_train_batch_end() in the code below calls get_weights() twice and set_weights() once, and takes twice as long to run as the batch update:
Epoch 1/40
1/629 [..............................] - ETA: 0s - loss: 2.3555 - accuracy: 0.0000e+00
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0030s vs `on_train_batch_end` time: 0.0090s). Check your callbacks.
This is explicitly due to calling get_weights() and set_weights(), as their removal from the callback reduces runtime of the callback to negligible amounts.
Describe the expected behavior.
Ideally, I would like to achieve iterative magnitude pruning with the lowest possible runtime.
Standalone code to reproduce the issue.
import sys
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.datasets import mnist
### PRUNE WEIGHTS CALLBACK ###
class pruneModelCallback(Callback):
def __init__(self, init_weight_dict=None, mask_dict=None):
self.n_batches = 0
self.init_weight_dict = init_weight_dict
self.mask_dict = mask_dict
def on_train_batch_begin(self, batch, logs=None):
# save weights at initialization
if self.n_batches == 0:
if self.init_weight_dict is not None:
for layer_i in range(len(self.model.layers)):
w = self.init_weight_dict['w_'+str(layer_i+1)]
b = self.init_weight_dict['b_'+str(layer_i+1)]
self.model.layers[layer_i].set_weights([w,b])
else:
self.init_weight_dict = {}
for layer_i in range(len(self.model.layers)):
w = self.model.layers[layer_i].get_weights()[0]
b = self.model.layers[layer_i].get_weights()[1]
self.init_weight_dict['w_'+str(layer_i+1)] = w
self.init_weight_dict['b_'+str(layer_i+1)] = b
self.n_batches = self.n_batches + 1
# This is the problematic function, runs every training iteration batch
def on_train_batch_end(self, batch, logs=None):
# zero out pruned weights
if self.mask_dict is not None:
for layer_i in range(len(self.model.layers)):
# removing these slightly improves runtime
w = self.model.layers[layer_i].get_weights()[0]
b = self.model.layers[layer_i].get_weights()[1]
w_mask = self.mask_dict['w_'+str(layer_i+1)]
# this multiplication takes no time comparably and removing it
# does not influence time taken
w_pruned = w * w_mask
# removing this function call significantly speeds up the runtime
self.model.layers[layer_i].set_weights([w_pruned,b])
class pruneWeights():
def __init__(self, model, percentile, pruning_type="IMP"):
# generate pruned mask
if pruning_type == "IMP":
return self._IMP(model, percentile)
else:
raise ValueError("Unknown pruning_type {}".format(pruning_type))
def _IMP(self, model, percentile):
mask_dict = {}
w_list = None
for layer_i in range(len(model.layers)):
w = model.layers[layer_i].get_weights()[0]
w_shape = tf.shape(w)
full_shape = tf.math.reduce_prod(w_shape)
w_flat = tf.reshape(w, full_shape)
if w_list is None:
w_list = w_flat
else:
w_list = tf.concat([w_list, w_flat], axis=0)
w_list = tf.math.abs(w_list)
thresh = tfp.stats.percentile(w_list, percentile*100)
test_mask = tf.cast(tf.math.greater(w_list, thresh), tf.float32)
for layer_i in range(len(model.layers)):
w = model.layers[layer_i].get_weights()[0]
w = tf.math.abs(w)
mask = tf.cast(tf.math.greater(w, thresh), tf.float32)
mask_dict['w_'+str(layer_i+1)] = mask
self.mask_dict = mask_dict
def pruning_breakdown(mask_dict, model):
for layer_i in range(len(model.layers)):
mask = mask_dict['w_'+str(layer_i+1)]
print("w_"+str(layer_i+1)+": "+str(tf.math.reduce_mean(mask)))
def main():
### LOAD MNIST DATASET ###
(x_train , y_train), (x_test , y_test) = mnist.load_data()
x = np.concatenate((x_train, x_test), axis=0)
y = np.concatenate((y_train, y_test), axis=0)
x= x.astype("float32") / 255
x= np.reshape(x, (np.shape(x)[0], 784))
scaler = StandardScaler(with_std=False)
scaler.fit(x)
x_t= scaler.transform(x)
ohe = OneHotEncoder()
y_t= ohe.fit_transform(y.reshape(-1, 1)).toarray()
input_dim = np.shape(x_t)[1:]
output_dim = np.shape(y_t)[1:]
del x_train, x_test, y_train, y_test, x, y
### SPLIT TRAINING DATA ###
X_train, X_test, y_train, y_test = train_test_split(x_t, y_t, test_size=0.33, random_state=42)
idxs = tf.range(tf.shape(X_train)[0])
### MODEL INIT ###
model = Sequential([
Dense(300, input_dim=input_dim[0], activation='relu'),
Dense(100, activation='relu'),
Dense(50, activation='relu'),
Dense(output_dim[0], activation='softmax')
])
model.compile(
optimizer = keras.optimizers.Adam(lr=1.2e-4),
loss = tf.keras.losses.CategoricalCrossentropy(),
metrics = ['accuracy']
)
### TRAIN MODEL ###
epochs = 20
pr = pruneModelCallback()
es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)
history = model.fit(
x=X_train,
y=y_train,
batch_size=64,
epochs=epochs,
validation_data=(X_test,y_test),
callbacks = [pr, es],
)
### ITERATIVELY PRUNE MODEL ###
percentile_per_it = 0.20
it = 14
for i in range(it):
total_percentile = 1-tf.math.pow(1-percentile_per_it, i+1)
print(total_percentile)
pruned_weights = pruneWeights(model, total_percentile)
pruning_breakdown(pruned_weights.mask_dict, model)
model = Sequential([
Dense(300, input_dim=input_dim[0], activation='relu'),
Dense(100, activation='relu'),
Dense(50, activation='relu'),
Dense(output_dim[0], activation='softmax')
])
model.compile(
optimizer = keras.optimizers.Adam(lr=1.2e-4),
loss = tf.keras.losses.CategoricalCrossentropy(),
metrics = ['accuracy']
)
### TRAIN MODEL ###
epochs = 40
pr_next = pruneModelCallback(init_weight_dict=pr.init_weight_dict, mask_dict=pruned_weights.mask_dict)
es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)
history = model.fit(
x=X_train,
y=y_train,
batch_size=64,
epochs=epochs,
validation_data=(X_test,y_test),
callbacks = [pr_next, es],
)
pr = pr_next
if __name__ == "__main__":
main()
Issue Analytics
- State:
- Created a year ago
- Comments:11 (1 by maintainers)
@sachinprasadhs any update on this issue? Still looking for a faster workaround.
Are you satisfied with the resolution of your issue? Yes No