[Bug] loss_fn argument for Trainer must not be a function since 1.2.4
Description
As shown in previous examples, `loss_fn` is passed as a callable like this:

```python
trainer = trax.supervised.Trainer(
    model=eval(train_model.selector),
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(train_stream),
    output_dir=output_dir,
)
```
However, since the upgrade to 1.2.4 this no longer works. In `trainer_lib` the `loss_fn` gets passed to a `Serial` constructor, which in turn runs `_ensure_flat` in its constructor, and `_ensure_flat` requires every object in `layers` to be an instance of `base.Layer`:
```python
def _ensure_flat(layers):
  """Ensures that layers is a single flat list of Layer instances."""
  if len(layers) == 1 and layers[0] is None:
    layers = ()
  else:
    layers = _deep_flatten(layers)
  for obj in layers:
    if not isinstance(obj, base.Layer):
      raise ValueError(
          f'Found nonlayer object ({obj}) in layers: {layers}')
  return layers
```
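The check is easy to reproduce in isolation. A quick demonstration (assuming a trax 1.2.4-era environment, where `CrossEntropyLoss` is a layer-returning function):

```python
import trax
from trax.layers import base

# The bare factory function is not a Layer, so _ensure_flat rejects it:
print(isinstance(trax.layers.CrossEntropyLoss, base.Layer))    # False
# Calling the factory returns an actual Layer instance:
print(isinstance(trax.layers.CrossEntropyLoss(), base.Layer))  # True
```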
Thus we'll see an exception:

```
ValueError: Found nonlayer object (<function CrossEntropyLoss at 0x7fc5be59a9e0>) in layers:
```
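Accordingly, a sketch of the adjusted call that avoids the error, assuming `CrossEntropyLoss()` returns a `Layer` when called:

```python
trainer = trax.supervised.Trainer(
    model=eval(train_model.selector),
    loss_fn=trax.layers.CrossEntropyLoss(),  # instantiated: a Layer, not a function
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(train_stream),
    output_dir=output_dir,
)
```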
I think the way we're trying to go is `trax.layers.Fn('MyLoss', loss_fn)`; `Fn` is just a `Lambda`. Thanks for your help as we try to make it understandable; we need to improve the docs too! Closing for now as the immediate issue is resolved.
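A minimal sketch of that workaround, with `my_loss_fn` as a hypothetical stand-in for a user-defined loss function:

```python
import trax

def my_loss_fn(model_output, targets):
    # Hypothetical example loss; real losses operate on batched arrays.
    return ((model_output - targets) ** 2).mean()

# Fn wraps the plain function into a proper base.Layer instance,
# so it passes the _ensure_flat check inside Serial.
loss_layer = trax.layers.Fn('MyLoss', my_loss_fn)
```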
You're welcome. No big deal; in cases like this the exceptions are rather clear and one can figure out what to do 😃 Regarding the argument: I know too little about trax at this point to say what makes sense, but one idea would be to do it like Keras does when compiling a model (see the docs):
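For illustration (this snippet is mine, not from the original comment), Keras' `compile()` accepts a string identifier, a plain function, or a `Loss` instance interchangeably:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

model.compile(optimizer='adam', loss='mse')                               # string id
model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())  # Loss object

def my_loss(y_true, y_pred):                                              # plain function
    return tf.reduce_mean(tf.square(y_true - y_pred))

model.compile(optimizer='adam', loss=my_loss)
```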
But of course that would make the implementation of `Trainer` a bit more complex. One advantage of plain functions is that I can simply provide a function in cases where I already have one in my code, without writing a wrapper object for a function that needs no data or state during execution.
After all, I would say this is a design question. We could always have a `trax.layers.LambdaLoss`, as in `LambdaLoss(loss_fn=my_loss_fn)`, to provide a simple wrapper. Thanks for sharing t2t and trax with us 😃
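Such a wrapper could be a thin layer over the `Fn` mechanism mentioned above. A hypothetical sketch (`LambdaLoss` does not exist in trax):

```python
import trax

def LambdaLoss(loss_fn):
    """Hypothetical wrapper: turns a plain loss function into a loss layer."""
    return trax.layers.Fn('LambdaLoss', loss_fn)

# Usage as proposed above, with my_loss_fn a user-defined function:
loss_layer = LambdaLoss(loss_fn=my_loss_fn)
```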