Print `Trainable` as a column
🚀 Feature
Add a new column, `Trainable`, to the summary, indicating whether gradients need to be computed for each layer's parameters.
This information is easy to obtain from the model's parameters:
```python
for p in model.parameters():
    print(p.requires_grad)
```
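A minimal sketch of how a per-layer `Trainable` value could be derived from those flags, assuming a layer is reported as `True` only when all of its own (non-recursive) parameters require gradients, and as `-` when it owns no parameters. `layer_trainable` is a hypothetical helper, not part of any summary tool's API.

```python
import torch.nn as nn


def layer_trainable(module: nn.Module) -> str:
    """Hypothetical helper: trainable status of a layer's own parameters.

    Returns "-" for parameter-free layers (e.g. ReLU), "True" if every
    direct parameter requires gradients, and "False" otherwise.
    """
    # recurse=False restricts the check to this module's own parameters
    params = list(module.parameters(recurse=False))
    if not params:
        return "-"
    return str(all(p.requires_grad for p in params))


if __name__ == "__main__":
    conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    print(layer_trainable(conv))       # True
    conv.weight.requires_grad_(False)  # freeze only the weight
    print(layer_trainable(conv))       # False
    print(layer_trainable(nn.ReLU()))  # -
```

Because it uses `recurse=False`, containers such as `Sequential` come out as `-` with this rule; how to label them is the open question in the comments at the end.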
In short, the expected output is:
```text
_________________________________________________________________________________________________________
Layer           Type        Output Shape          Param #       Trainable
=========================================================================================================
vgg             VGG         (-1, 1000)            0
├─features      Sequential  (-1, 512, 7, 7)       0
| └─0           Conv2d      (-1, 64, 224, 224)    1,792         True
| └─1           ReLU        (-1, 64, 224, 224)    0             -
| └─2           Conv2d      (-1, 64, 224, 224)    36,928        True
| └─3           ReLU        (-1, 64, 224, 224)    0             -
| └─4           MaxPool2d   (-1, 64, 112, 112)    0
| └─5           Conv2d      (-1, 128, 112, 112)   73,856        True
| └─6           ReLU        (-1, 128, 112, 112)   0             -
...
├─classifier    Sequential  (-1, 1000)            0
| └─0           Linear      (-1, 4096)            102,764,544   False
| └─1           ReLU        (-1, 4096)            0             -
| └─2           Dropout     (-1, 4096)            0             -
| └─3           Linear      (-1, 4096)            16,781,312    False
| └─4           ReLU        (-1, 4096)            0             -
| └─5           Dropout     (-1, 4096)            0             -
| └─6           Linear      (-1, 1000)            4,097,000     False
```
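A rough sketch of how rows like these could be produced by walking `named_modules()`, assuming `Param #` counts only the parameters a module owns directly and `Trainable` follows the "all own parameters require gradients" rule. `trainable_rows` and `print_rows` are hypothetical names; the `Output Shape` column is omitted because it needs a forward pass with hooks.

```python
import torch.nn as nn


def trainable_rows(model: nn.Module) -> list[tuple[str, str, int, str]]:
    """Hypothetical helper: one (layer, type, param #, trainable) row per module.

    Output shapes are left out because they need a forward pass with hooks.
    """
    rows = []
    for name, module in model.named_modules():
        # Count and check only the parameters owned directly by this module
        params = list(module.parameters(recurse=False))
        n_params = sum(p.numel() for p in params)
        if params:
            status = str(all(p.requires_grad for p in params))
        else:
            status = "-"  # e.g. ReLU, MaxPool2d, or a pure container
        # The root module has the empty name ""; indent children by depth
        depth = name.count(".") + 1 if name else 0
        label = name.rsplit(".", 1)[-1] if name else type(model).__name__.lower()
        rows.append(("  " * depth + label, type(module).__name__, n_params, status))
    return rows


def print_rows(rows: list[tuple[str, str, int, str]]) -> None:
    print(f"{'Layer':<20}{'Type':<14}{'Param #':>12}  Trainable")
    for layer, kind, n_params, status in rows:
        print(f"{layer:<20}{kind:<14}{n_params:>12,}  {status}")


if __name__ == "__main__":
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
    model[2].requires_grad_(False)  # freeze the last layer
    print_rows(trainable_rows(model))
```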
Motivation & pitch
I have been experimenting with transfer learning on DenseNet and printed a summary:
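```python
import torch.nn as nn
import torchvision
from torchscan import summary  # assumed: the summary tool this issue is filed against

model = torchvision.models.densenet201(pretrained=True)
# Replace the ImageNet head with a 10-class classifier
model.classifier = nn.Sequential(
    nn.Linear(1920, 10)
)
# Freeze the new classifier head
for p in model.classifier.parameters():
    p.requires_grad = False
summary(model, (3, 224, 224))
```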
However, the summary gives no indication of which layers are trainable. This is the tail of the output:
```text
| | | └─conv2   Conv2d       (-1, 32, 7, 7)     36,864
| └─norm5       BatchNorm2d  (-1, 1920, 7, 7)   7,681
├─classifier    Sequential   (-1, 10)           0
| └─0           Linear       (-1, 10)           19,210
==========================================================================================
Trainable params: 18,092,928
Non-trainable params: 19,210
Total params: 18,112,138
```
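The totals at the bottom appear to split parameters by their `requires_grad` flag: the 19,210 non-trainable parameters are exactly the frozen `Linear(1920, 10)` head (1920 × 10 + 10), and 18,092,928 + 19,210 = 18,112,138. A minimal sketch of that split, using a hypothetical `count_params` helper and a small stand-in model with the same frozen 1920 → 10 head so it runs without downloading DenseNet weights:

```python
import torch.nn as nn


def count_params(model: nn.Module) -> tuple[int, int]:
    """Hypothetical helper: return (trainable, non_trainable) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen


if __name__ == "__main__":
    # Stand-in for the DenseNet example: a small backbone plus a 1920 -> 10 head
    model = nn.Sequential(nn.Linear(32, 1920), nn.Linear(1920, 10))
    for p in model[1].parameters():
        p.requires_grad = False  # freeze the head, as in the issue
    trainable, frozen = count_params(model)
    print(f"Trainable params: {trainable:,}")           # Trainable params: 63,360
    print(f"Non-trainable params: {frozen:,}")          # Non-trainable params: 19,210
    print(f"Total params: {trainable + frozen:,}")      # Total params: 82,570
```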
Alternatives
No response
Additional context
I look forward to your response and would like to hear what you think about this.
Top GitHub Comments
Well, that will get hairy; I honestly don't want to spread it over multiple lines. The only suggestion I can see is:
Good, I totally agree with you.
One thing I would suggest is that this be noted in the documentation, for example: “False; contains partial mixed-trainable parameters”.
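For container rows such as `features` or `classifier`, whose children may mix frozen and trainable parameters, one option in line with the convention suggested above is to aggregate over every parameter in the subtree and report `False` whenever anything is frozen, while flagging the mixed case. A minimal sketch; `subtree_trainable` and the `"False (mixed)"` label are hypothetical choices, not an existing option of any summary tool.

```python
import torch.nn as nn


def subtree_trainable(module: nn.Module) -> str:
    """Hypothetical helper: trainable status over a module and all its children.

    "-"             no parameters anywhere in the subtree
    "True"          every parameter requires gradients
    "False"         no parameter requires gradients
    "False (mixed)" some parameters are trainable and some are frozen
    """
    flags = [p.requires_grad for p in module.parameters()]  # recursive by default
    if not flags:
        return "-"
    if all(flags):
        return "True"
    if not any(flags):
        return "False"
    return "False (mixed)"


if __name__ == "__main__":
    model = nn.Sequential(
        nn.Sequential(nn.Linear(4, 4), nn.ReLU()),  # plays the role of "features"
        nn.Sequential(nn.Linear(4, 2)),            # plays the role of "classifier"
    )
    model[1].requires_grad_(False)  # freeze the classifier only
    print(subtree_trainable(model[0]))   # True
    print(subtree_trainable(model[1]))   # False
    print(subtree_trainable(model))      # False (mixed)
    print(subtree_trainable(nn.ReLU()))  # -
```

With this rule, the top-level `vgg` row in the expected output above would read `False (mixed)`, since `features` is trainable and `classifier` is frozen.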