
[Bug] loss_fn argument for Trainer must not be a function since 1.2.4

See original GitHub issue

Description

As shown in previous examples, loss_fn is supposed to be a callable, passed like this:

trainer = trax.supervised.Trainer(
    model=eval(train_model.selector),
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(train_stream),
    output_dir=output_dir,
)

However, since the upgrade to 1.2.4 this no longer works.

In trainer_lib, the loss_fn gets passed to a Serial constructor:

https://github.com/google/trax/blob/93f2bd47f5f17aacafe3f312ae56ce6f98d93ee7/trax/supervised/trainer_lib.py#L130

Serial in turn runs _ensure_flat in its constructor:

https://github.com/google/trax/blob/5b1565910a53d0d1175f647cc67db48e334d8f90/trax/layers/combinators.py#L47

However, all objects in layers have to be instances of base.Layer:

def _ensure_flat(layers):
  """Ensures that layers is a single flat list of Layer instances."""
  if len(layers) == 1 and layers[0] is None:
    layers = ()
  else:
    layers = _deep_flatten(layers)
  for obj in layers:
    if not isinstance(obj, base.Layer):
      raise ValueError(
          f'Found nonlayer object ({obj}) in layers: {layers}')
  return layers

See

https://github.com/google/trax/blob/5b1565910a53d0d1175f647cc67db48e334d8f90/trax/layers/combinators.py#L775

Thus we’ll see an exception:

ValueError: Found nonlayer object (<function CrossEntropyLoss at 0x7fc5be59a9e0>) in layers:
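
A minimal reproduction sketch under the 1.2.4 semantics described above. The assumption (consistent with the error message) is that CrossEntropyLoss is a function that returns a Layer instance when called, so calling it before passing it is one possible workaround:

import trax.layers as tl

# The Serial constructor runs _ensure_flat, which rejects plain functions.
try:
    tl.Serial(tl.CrossEntropyLoss)       # function object, not a Layer
except ValueError as e:
    print(e)                             # Found nonlayer object (<function ...>) ...

# Calling the loss first yields a Layer instance, which Serial accepts,
# so loss_fn=trax.layers.CrossEntropyLoss() should get past this check.
loss_layer = tl.CrossEntropyLoss()       # assumed to return a Layer in 1.2.4
tl.Serial(loss_layer)                    # constructs without raising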

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

1 reaction
lukaszkaiser commented, Apr 28, 2020

I think the way we’re trying to go is trax.layers.Fn('MyLoss', loss_fn) - Fn is just Lambda. Thanks for your help as we try to make it understandable - we need to improve the docs too!

Closing for now as the immediate issue is resolved.
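
A sketch of what that suggestion looks like for a custom loss. The wrapped function and its two-argument signature are illustrative assumptions; only the Fn('MyLoss', loss_fn) call itself comes from the comment above:

import jax.numpy as jnp
import trax.layers as tl

# Hypothetical plain-function loss; the (y_hat, y) signature is illustrative.
def my_l2_loss(y_hat, y):
    return jnp.mean((y_hat - y) ** 2)

# Fn wraps the function into a named Layer, so it passes _ensure_flat
# and can be handed to Trainer as loss_fn.
loss_layer = tl.Fn('MyLoss', my_l2_loss)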

0 reactions
stefan-falk commented, Apr 28, 2020

You’re welcome. No big deal - in cases like this the exception is rather clear and one can figure out what to do 😃 Regarding the argument: I know too little about trax at this point to say what makes sense, but one idea would be to do it like Keras does when compiling the model (see the docs):

Arguments:

  • loss: String (name of objective function), objective function or tf.keras.losses.Loss instance. See tf.keras.losses. An objective function is any callable with the signature scalar_loss = fn(y_true, y_pred). If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.
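
For concreteness, the three forms that quote describes look like this in standard tf.keras (the one-layer model is just a stand-in):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# 1. String name of a built-in objective function.
model.compile(optimizer='adam', loss='mse')

# 2. Any callable with signature scalar_loss = fn(y_true, y_pred).
def my_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))

model.compile(optimizer='adam', loss=my_loss)

# 3. A tf.keras.losses.Loss instance.
model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())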

But of course that will make the implementation of Trainer a bit more complex.

One advantage of plain functions is that I could just provide a function in cases where I already have one in my code. It would also spare me writing a wrapper object for such a function, or for any function that does not require data or state during execution.

After all, I would say this is a design question. We could always have a trax.layers.LambdaLoss, as in LambdaLoss(loss_fn=my_loss_fn), to provide a simple wrapper (see the sketch below).
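
A sketch of that hypothetical wrapper, built on the tl.Fn call mentioned in the comment above (LambdaLoss is not actual trax API, just the shape the proposal could take):

import trax.layers as tl

# Hypothetical helper, not trax API: wrap a plain loss function into a
# Layer so Trainer's Serial-based plumbing accepts it.
def LambdaLoss(loss_fn, name='LambdaLoss'):
    return tl.Fn(name, loss_fn)

# Usage: trainer = trax.supervised.Trainer(..., loss_fn=LambdaLoss(loss_fn=my_loss_fn))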

Thanks for bringing this to our attention!

I say thanks for sharing t2t and trax with us 😃


