
flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/MLP_0/Dense_0" when attempting to use Jax2TF with a pre-trained JAX NeRF Model


Redirected from the JAX repo (https://github.com/google/jax/issues/9139#issue-1096888310)

Versions: TensorFlow 2.7; JAX 0.2.24; jaxlib 0.1.72+cuda111; Flax 0.3.6

The following code is based on the Flax MNIST jax2tf example, which I adapted for JAX NeRF:

```python
import collections
from os import path
from absl import app
from absl import flags
from flax.training import checkpoints
from jax import random
from jax.experimental.jax2tf.examples import saved_model_lib
from nerf import models
from nerf import utils
import tensorflow as tf

FLAGS = flags.FLAGS

utils.define_flags()

def main(unused_argv):
    rng = random.PRNGKey(20200823)
    rng, key = random.split(rng)
    utils.update_flags(FLAGS)
    utils.check_flags(FLAGS)
    model, state = models.get_model_state(key, FLAGS, restore=False)
    print('Loading model')
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    params = state.optimizer.target
    predict_fn = lambda params, input: model.apply({"params": params}, input)
    Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))
    input_signatures = [Rays(origins=tf.TensorSpec((3,), tf.float32),
                             directions=tf.TensorSpec((3,), tf.float32),
                             viewdirs=tf.TensorSpec((3,), tf.float32))]
    saved_model_lib.convert_and_save_model(
        predict_fn,
        params,
        '/any/path/',
        input_signatures=input_signatures)

if __name__ == "__main__":
    app.run(main)
```

To simplify the network inputs, and since I only need to run inference in TF, I default the RNG keys to None and the randomized flag to False, so that only the rays have to be passed in. This is the only change from the original JAX NeRF code:

```python
def __call__(self, rays, rng_0=None, rng_1=None, randomized=False,
             depth_gt=None, rgb_only=False, depth_sampling=False):
    """Nerf Model.

    Args:
      rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
      rng_0: jnp.ndarray, random number generator for coarse model sampling.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
      randomized: bool, use randomized stratified sampling.
      rgb_only: bool, return only rgb

    Returns:
      ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]
    """
    # Stratified sampling along rays
    if randomized:
        key, rng_0 = random.split(rng_0)
    else:
        key = None
    # ... (rest of __call__ unchanged)
```

(Also, every call to model.apply() has its argument order changed to match this new signature, with the rays first.)

The error is raised while tracing the TF graph defined at this line of saved_model_lib.py:

```python
tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
                       autograph=False,
                       experimental_compile=compile_model)
```

Full traceback:

```
Traceback (most recent call last):
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/jorge/jaxnerf/nerf/save_jax_as_tf.py", line 45, in <module>
    app.run(main)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/jorge/jaxnerf/nerf/save_jax_as_tf.py", line 38, in main
    saved_model_lib.convert_and_save_model(
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py", line 114, in convert_and_save_model
    tf_graph.get_concrete_function(input_signatures[0])
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1259, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1239, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3157, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3557, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3392, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1143, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py", line 107, in <lambda>
    tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/jax2tf.py", line 418, in converted_fun
    out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/jax2tf.py", line 486, in _interpret_fun
    fun.call_wrapped(*in_vals)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/jax2tf.py", line 272, in fun_no_kwargs
    return fun(*args, **kwargs)
  File "/home/jorge/jaxnerf/nerf/save_jax_as_tf.py", line 35, in <lambda>
    predict_fn = lambda params, input: model.apply({"params": params}, input)
  File "/home/jorge/jaxnerf/nerf/nerf/models.py", line 268, in __call__
    raw_rgb, raw_sigma = self.MLP_0(samples_enc)
  File "/home/jorge/jaxnerf/nerf_sh/nerf/model_utils.py", line 70, in __call__
    x = dense_layer(self.net_width)(x)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/flax/linen/linear.py", line 171, in __call__
    kernel = self.param('kernel',
flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/MLP_0/Dense_0". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)
```

Has anyone else attempted to save a JAX NeRF model using jax2tf and encountered any such issue?

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 9

Top GitHub Comments

2 reactions
levskaya commented, Jan 10, 2022

@Arcanous98 - thanks for providing the extra info!

Actually I just noticed something from your first response that I should have noticed immediately:

If the output of the inserted printout:

```python
input_signatures = [Rays(origin....
print(jax.tree_map(jnp.shape, params))  # inserted debug line
saved_model_lib.convert_and_save_model(...
```

has this structure:

```
FrozenDict({
    params: {
        MLP_0: { ... }
        ...
    }
})
```

then the params object shouldn't itself contain an extra params: level, because your predict_fn function is written as:

```python
predict_fn = lambda params, input: model.apply({"params": params}, input)
```

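which wraps the already-complete variables dict under a second "params" key. Flax then looks for MLP_0/Dense_0/kernel directly inside the "params" collection, finds only the inner "params" key there, and fails with precisely the error you're seeing.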

In Flax, the init function returns, and the apply function takes, a (frozen) variables dictionary structured at the top level like:

```
{
    "params": nested_param_dict,
    "some_stateful_collection": some_stateful_collection_dict,
    "some_other_stateful_collection": some_other_stateful_collection_dict,
    ...
}
```

where each of those nested dicts shares the same module-defined nesting structure.

If you try to remove the extra {"params": ...} nesting, does your code run correctly?
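Given the printout above, the direct fix for the script in the question is to pass params through unchanged, e.g. predict_fn = lambda params, inputs: model.apply(params, inputs). If a script may receive either a bare parameter tree or a full variables dict, one option is to normalize before calling apply. The helper below, as_variables, is hypothetical (not a Flax API) and relies on a heuristic: a top-level "params" key is taken to mean the tree is already a variables dict, which would misfire if a submodule were itself named "params".

```python
from collections.abc import Mapping
from typing import Any


def as_variables(tree: Mapping[str, Any]) -> Mapping[str, Any]:
    """Hypothetical helper: return a Flax-style variables dict for `tree`.

    Heuristic: a top-level "params" key means `tree` is already a full
    variables dict (as returned by init), so it is returned unchanged.
    Otherwise `tree` is assumed to be a bare parameter tree and is wrapped
    under "params" exactly once.
    """
    if "params" in tree:
        return tree
    return {"params": tree}
```

With it, the lambda becomes predict_fn = lambda params, inputs: model.apply(as_variables(params), inputs), which works whether params came from state.optimizer.target or from variables["params"].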

1 reaction
levskaya commented, Jan 10, 2022

Great! Happy to Help! Happy NeRFing. 😉


