Custom model can't get parameters


Python:

import numpy as np
import torch as th
from torch import nn

class NatureCNN(nn.Module):
    def __init__(self, features_dim: int = 512) -> None:
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        temp = np.zeros((3, 128, 128))  # no batch dim: after Flatten(start_dim=1) the output is (64, 144)

        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(temp).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))

cnn = NatureCNN()
print(cnn.parameters())
for param in cnn.parameters():
    print(param.size())

Console:
<generator object Module.parameters at 0x7f5dec222b20>
torch.Size([32, 3, 8, 8])
torch.Size([32])
torch.Size([64, 32, 4, 4])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([512, 144])
torch.Size([512])
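The `n_flatten` value is taken from one dummy forward pass. As a cross-check, here is a small pure-Python sketch of the usual Conv2d output-size formula (dilation 1; the helper names `conv_out` and `nature_cnn_shapes` are made up for illustration) applied to the three layers above. It suggests where the 144 in the console comes from: the dummy input `(3, 128, 128)` has no batch dimension, so `nn.Flatten()` (default `start_dim=1`) turns the `(64, 12, 12)` output into 64 rows of 144 values.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Length of one spatial dimension after a Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


def nature_cnn_shapes(height: int = 128, width: int = 128):
    # (out_channels, kernel_size, stride) of the three Conv2d layers, padding=0
    layers = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    channels = 3  # input channels
    for out_channels, kernel, stride in layers:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
        channels = out_channels
    return channels, height, width


c, h, w = nature_cnn_shapes()
print((c, h, w))  # (64, 12, 12)
print(h * w)      # 144: shape[1] after flattening an unbatched (64, 12, 12) output
print(c * h * w)  # 9216: values per sample for a batched (N, 3, 128, 128) input
```

If the model is meant to receive batches shaped `(N, 3, 128, 128)`, each sample flattens to 9216 values, so a batched dummy such as `np.zeros((1, 3, 128, 128))` would likely be needed for `nn.Linear` to match.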

C#:

using TorchSharp.Modules;
using static TorchSharp.torch;

namespace TestApp1;

public class NatureCNN : nn.Module<Tensor, Tensor>
{
    private readonly Sequential cnn;
    private readonly Sequential linear;

    public NatureCNN(int featuresDim = 512)
        : base(nameof(NatureCNN))
    {
        cnn = nn.Sequential(
            nn.Conv2d(3, 32, 8, 4, 0),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 0),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 0),
            nn.ReLU(),
            nn.Flatten()
            );


        double[,,] temp = new double[3, 128, 128];

        long nFlatten;
        // Compute shape by doing one forward pass
        using (no_grad())
        {
            nFlatten = cnn.forward(as_tensor(temp).@float()).shape[1];
        }
        linear = nn.Sequential(nn.Linear(nFlatten, featuresDim), nn.ReLU());
    }

    public override Tensor forward(Tensor observations)
        => linear.forward(cnn.forward(observations));

    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            cnn.Dispose();
            linear.Dispose();
        }
        base.Dispose(disposing);
    }
}

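// Program.cs: top-level statements must live in a file apart from the file-scoped namespace above
using System.Linq;
using TestApp1;

NatureCNN cNN = new NatureCNN();
var a = cNN.parameters().ToList();
// a is empty: no parameters are returned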
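Why the two versions differ: PyTorch's `nn.Module.__setattr__` records any module assigned as an attribute, so `self.cnn` and `self.linear` show up in `parameters()` automatically. C# has no such hook, and TorchSharp modules are expected to call `RegisterComponents()` at the end of the constructor, which discovers submodules and parameters by reflecting over the object's fields; the C# `NatureCNN` above never makes that call. The toy classes below are hypothetical (not PyTorch or TorchSharp code) and only sketch the reflection-based strategy in plain Python:

```python
class Param:
    """Hypothetical stand-in for a weight tensor; only tracks its shape."""

    def __init__(self, *shape):
        self.shape = shape


class ReflectiveModule:
    """Hypothetical TorchSharp-like base: nothing is tracked until register_components()."""

    def __init__(self):
        self._children = {}

    def register_components(self):
        # Scan instance fields, roughly like RegisterComponents() does via reflection
        for name, value in vars(self).items():
            if isinstance(value, (ReflectiveModule, Param)):
                self._children[name] = value

    def parameters(self):
        # Yield own params, then recurse into registered submodules
        for child in self._children.values():
            if isinstance(child, Param):
                yield child
            else:
                yield from child.parameters()


class Linear(ReflectiveModule):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight = Param(n_out, n_in)
        self.bias = Param(n_out)
        self.register_components()


class Head(ReflectiveModule):
    def __init__(self, register):
        super().__init__()
        self.linear = Linear(144, 512)
        if register:
            self.register_components()  # the step the C# constructor above omits


print(len(list(Head(register=False).parameters())))        # 0
print([p.shape for p in Head(register=True).parameters()])  # [(512, 144), (512,)]
```

In the real C# class, the corresponding change would presumably be a `RegisterComponents();` call after `linear` is assigned.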

Issue Analytics

  • State: closed
  • Created 8 months ago
  • Comments: 9 (9 by maintainers)

Top GitHub Comments

1 reaction
NiklasGustafsson commented, Jan 18, 2023

Yes, but you can serialize the weights to a MemoryStream using save(), then get the byte array from the stream, and store that with the environment. That should also do fewer copies than the code above.
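The suggestion above refers to TorchSharp's `save()` writing to a `MemoryStream`. For comparison only, a minimal PyTorch sketch of the same idea serializes the `state_dict` into an in-memory `io.BytesIO` buffer and keeps the raw bytes; the small `nn.Linear` is an assumed stand-in for `NatureCNN` to keep the example short.

```python
import io

import torch as th
from torch import nn

model = nn.Linear(144, 512)

# Serialize the weights into an in-memory buffer instead of a file
buffer = io.BytesIO()
th.save(model.state_dict(), buffer)
blob = buffer.getvalue()  # raw bytes that can be stored alongside other state

# Later: rebuild a module with the same shapes and load the bytes back
restored = nn.Linear(144, 512)
restored.load_state_dict(th.load(io.BytesIO(blob)))

print(type(blob).__name__, len(blob) > 0)       # bytes True
print(th.equal(model.weight, restored.weight))  # True
```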

0 reactions
ChengYen-Tang commented, Jan 19, 2023

OK, thank you.
