flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/MLP_0/Dense_0" when attempting to use Jax2TF with a pre-trained JAX NeRF Model
Redirected from the JAX repo (https://github.com/google/jax/issues/9139#issue-1096888310)
TensorFlow version: 2.7; JAX version: 0.2.24; jaxlib version: 0.1.72+cuda111; Flax version: 0.3.6
The following code is based on the MNIST Flax jax2tf example, which I adapted for JAX NeRF:
```python
import collections
from os import path

from absl import app
from absl import flags
from flax.training import checkpoints
from jax import random
from jax.experimental.jax2tf.examples import saved_model_lib
from nerf import models
from nerf import utils
import tensorflow as tf

FLAGS = flags.FLAGS

utils.define_flags()


def main(unused_argv):
    rng = random.PRNGKey(20200823)
    rng, key = random.split(rng)

    utils.update_flags(FLAGS)
    utils.check_flags(FLAGS)

    model, state = models.get_model_state(key, FLAGS, restore=False)
    print('Loading model')
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    params = state.optimizer.target

    predict_fn = lambda params, input: model.apply({"params": params}, input)

    Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))
    input_signatures = [Rays(origins=tf.TensorSpec((3,), tf.float32),
                             directions=tf.TensorSpec((3,), tf.float32),
                             viewdirs=tf.TensorSpec((3,), tf.float32))]

    saved_model_lib.convert_and_save_model(
        predict_fn,
        params,
        '/any/path/',
        input_signatures=input_signatures)


if __name__ == "__main__":
    app.run(main)
```
To simplify the network inputs, and since I am only interested in running inference in TF, I default the RNG keys and the `randomized` argument of the NeRF model to `None` and `False` respectively, so that only the `rays` are passed in. This is the only change from the original JAX NeRF code:
```python
def __call__(self, rays, rng_0=None, rng_1=None, randomized=False,
             depth_gt=None, rgb_only=False, depth_sampling=False):
    """Nerf Model.

    Args:
      rng_0: jnp.ndarray, random number generator for coarse model sampling.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
      rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
      randomized: bool, use randomized stratified sampling.
      rgb_only: bool, return only rgb

    Returns:
      ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]
    """
    # Stratified sampling along rays
    if randomized:
        key, rng_0 = random.split(rng_0)
    else:
        key = None
```
(Also, every call to `model.apply()` has its argument order changed to match this signature.)
The error is raised while building the TF graph at this line of `saved_model_lib.py`:
```python
tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
                       autograph=False,
                       experimental_compile=compile_model)
```
Full error stack:
```
Traceback (most recent call last):
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/jorge/jaxnerf/nerf/save_jax_as_tf.py", line 45, in <module>
    app.run(main)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/jorge/jaxnerf/nerf/save_jax_as_tf.py", line 38, in main
    saved_model_lib.convert_and_save_model(
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py", line 114, in convert_and_save_model
    tf_graph.get_concrete_function(input_signatures[0])
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1259, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1239, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3157, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3557, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3392, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1143, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py", line 107, in <lambda>
    tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/jax2tf.py", line 418, in converted_fun
    out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/jax2tf.py", line 486, in _interpret_fun
    fun.call_wrapped(*in_vals)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/jax/experimental/jax2tf/jax2tf.py", line 272, in fun_no_kwargs
    return fun(*args, **kwargs)
  File "/home/jorge/jaxnerf/nerf/save_jax_as_tf.py", line 35, in <lambda>
    predict_fn = lambda params, input: model.apply({"params": params}, input)
  File "/home/jorge/jaxnerf/nerf/nerf/models.py", line 268, in __call__
    raw_rgb, raw_sigma = self.MLP_0(samples_enc)
  File "/home/jorge/jaxnerf/nerf_sh/nerf/model_utils.py", line 70, in __call__
    x = dense_layer(self.net_width)(x)
  File "/home/jorge/anaconda3/envs/jaxnerf/lib/python3.8/site-packages/flax/linen/linear.py", line 171, in __call__
    kernel = self.param('kernel',
flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/MLP_0/Dense_0". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)
```
Has anyone else attempted to save a JAX NeRF model using jax2tf and encountered any such issue?
Top GitHub Comments
@Arcanous98 - thanks for providing the extra info!

Actually, I just noticed something in your first response that I should have spotted immediately: if the output of the printout you inserted shows a top-level `params:` key, then that is the problem. The `params` object shouldn't itself have an extra `params:` layer inside it, because in your `predict_fn` function you write `model.apply({"params": params}, input)`, which adds another nesting layer under a second `"params"` key. That leads to precisely the error you're seeing.

In Flax, the `init` function returns, and the `apply` function takes, a (frozen) variable dictionary keyed at the top level by collection name, e.g. `{"params": nested_dict, "batch_stats": nested_dict, ...}`, where each of those nested dicts shares the same module-defined nesting structure.
If you remove the extra `{"params": ...}` nesting, does your code run correctly?

Great! Happy to help! Happy NeRFing. 😉