Visualize the model/params structure

Problem you have encountered:

Currently it is not clear how to inspect a model's structure, for example to see whether an imported model uses dropout or batch norm, or to decide which weights to freeze and which to fine-tune. Similarly, the params returned by model.init are difficult to inspect, because extracting their structure requires writing custom code.

What you expected to happen:

It would be nice if __repr__ displayed a human-readable structure, as in PyTorch:

  1. For params, the shape/dtype, rather than 1000+ lines of weight values:
>>> params
FrozenDict({
    params: {
        Encoder_0: {
            Conv_0: {
                kernel: float32[3, 3, 1, 16],
                bias: float32[16],
            },
            Conv_1: {
                kernel: float32[3, 3, 16, 32],
                bias: float32[32],
            },
            Conv_2: {
                kernel: float32[7, 7, 32, 64],
                bias: float32[64],
            },
        },
        Decoder_0: {
            ConvTranspose_0: {
                kernel: float32[7, 7, 64, 32],
                bias: float32[32],
            },
            ConvTranspose_1: {
                kernel: float32[3, 3, 32, 16],
                bias: float32[16],
            },
            ConvTranspose_2: {
                kernel: float32[3, 3, 16, 1],
                bias: float32[1],
            },
        },
    },
})
  2. For the model, the module names and submodules:

PyTorch, for example, displays the model structure quite clearly, so it is easy to see which operations are used:

import torchvision.models.resnet as resnet

model = resnet.resnet18()
print(model)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
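
The params summary requested in item 1 can be approximated without library support. What follows is a minimal sketch, not Flax's own behavior: `summarize` is a hypothetical helper that walks any nested mapping and renders each leaf exposing `.shape` and `.dtype` as `dtype[shape]`. `FakeArray` is a stand-in so the example runs without JAX installed; in practice the tree returned by model.init would be passed instead.

```python
from collections import namedtuple
from collections.abc import Mapping

# Stand-in for an array type; any object with .shape and .dtype works.
FakeArray = namedtuple("FakeArray", ["shape", "dtype"])


def summarize(tree, indent=0):
    """Render a nested mapping of arrays as indented `name: dtype[shape]` lines."""
    pad = " " * indent
    lines = []
    for name, value in tree.items():
        if isinstance(value, Mapping):
            # Sub-mapping: open a brace, recurse one level deeper, close it.
            lines.append(f"{pad}{name}: {{")
            lines.extend(summarize(value, indent + 4))
            lines.append(f"{pad}}},")
        else:
            # Leaf: show only dtype and shape, never the values.
            dims = ", ".join(str(d) for d in value.shape)
            lines.append(f"{pad}{name}: {value.dtype}[{dims}],")
    return lines


params = {
    "params": {
        "Conv_0": {
            "kernel": FakeArray((3, 3, 1, 16), "float32"),
            "bias": FakeArray((16,), "float32"),
        },
    },
}
print("\n".join(summarize(params)))
```

Because the helper only relies on duck typing, it never touches the weight values, so the output stays short regardless of model size.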

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 7 (2 by maintainers)

Top GitHub Comments

9 reactions
myagues commented, Feb 26, 2021

You can use parameter_overview in clu for params visualization:

Commands
import jax
import numpy as np
from flax import linen as nn
from clu import parameter_overview

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

key = jax.random.PRNGKey(0)
variables = CNN().init(key, np.random.randn(1, 32, 32, 3))
print(parameter_overview.get_parameter_overview(variables))
+-----------------------+----------------+-----------+-----------+--------+
| Name                  | Shape          | Size      | Mean      | Std    |
+-----------------------+----------------+-----------+-----------+--------+
| params/Conv_0/bias    | (32,)          | 32        | 0.0       | 0.0    |
| params/Conv_0/kernel  | (3, 3, 3, 32)  | 864       | 0.00277   | 0.2    |
| params/Conv_1/bias    | (64,)          | 64        | 0.0       | 0.0    |
| params/Conv_1/kernel  | (3, 3, 32, 64) | 18,432    | 0.000202  | 0.0591 |
| params/Dense_0/bias   | (256,)         | 256       | 0.0       | 0.0    |
| params/Dense_0/kernel | (4096, 256)    | 1,048,576 | -1.54e-05 | 0.0156 |
| params/Dense_1/bias   | (10,)          | 10        | 0.0       | 0.0    |
| params/Dense_1/kernel | (256, 10)      | 2,560     | -0.00159  | 0.0622 |
+-----------------------+----------------+-----------+-----------+--------+
Total: 1,070,794
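
A rough, dependency-free approximation of the Name, Shape, and Size columns above can be sketched in a few lines. This is not how clu implements `get_parameter_overview`; `flatten_shapes` and `Leaf` are hypothetical names, and the shapes are copied from the table so the example runs without JAX. Filtering the resulting paths by prefix is one possible way to pick out the weights to freeze or fine-tune.

```python
import math
from collections import namedtuple
from collections.abc import Mapping

# Stand-in leaf; real arrays expose .shape the same way.
Leaf = namedtuple("Leaf", ["shape"])


def flatten_shapes(tree, prefix=""):
    """Map slash-separated parameter paths to their shapes."""
    flat = {}
    for name, value in tree.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(value, Mapping):
            flat.update(flatten_shapes(value, path))
        else:
            flat[path] = tuple(value.shape)
    return flat


# Shapes copied from the table above (assumed, not recomputed from a model).
variables = {
    "params": {
        "Conv_0": {"bias": Leaf((32,)), "kernel": Leaf((3, 3, 3, 32))},
        "Conv_1": {"bias": Leaf((64,)), "kernel": Leaf((3, 3, 32, 64))},
        "Dense_0": {"bias": Leaf((256,)), "kernel": Leaf((4096, 256))},
        "Dense_1": {"bias": Leaf((10,)), "kernel": Leaf((256, 10))},
    },
}

shapes = flatten_shapes(variables)
sizes = {path: math.prod(shape) for path, shape in shapes.items()}
for path, shape in shapes.items():
    print(f"{path:<24}{str(shape):<18}{sizes[path]:>12,}")
print(f"Total: {sum(sizes.values()):,}")

# Select a subset by path prefix, e.g. the dense layers to fine-tune.
dense_paths = [p for p in shapes if p.startswith("params/Dense_")]
```

The printed total agrees with the 1,070,794 reported in the table.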

5 reactions
n2cholas commented, Feb 22, 2021

Haiku has excellent model summary functionality (documentation here). I think equivalents to haiku.experimental.tabulate and haiku.experimental.eval_summary would be very helpful.
