MarginMSE and GPL's MarginDistillationLoss both throw "RuntimeError: Found dtype Long but expected Float"
Please advise.

Code:
```python
import logging
import math

from sentence_transformers import losses

# model, train_dataloader, dev_evaluator, email_callback, num_epochs and
# model_save_path are assumed to be defined earlier in the notebook.

# Our training loss
train_loss = losses.MarginMSELoss(model)

# Configure the training
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of total training steps for warm-up
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=dev_evaluator,
          epochs=num_epochs,
          callback=email_callback,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          checkpoint_save_total_limit=3,
          use_amp=False)
```
Traceback:
```
/usr/local/lib/python3.7/dist-packages/sentence_transformers/SentenceTransformer.py in fit(self, train_objectives, evaluator, epochs, steps_per_epoch, scheduler, warmup_steps, optimizer_class, optimizer_params, weight_decay, evaluation_steps, output_path, save_best_model, max_grad_norm, use_amp, callback, show_progress_bar, checkpoint_path, checkpoint_save_steps, checkpoint_save_total_limit)
711 else:
712 loss_value = loss_model(features, labels)
--> 713 loss_value.backward()
714 torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
715 optimizer.step()

/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
305 create_graph=create_graph,
306 inputs=inputs)
--> 307 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
308
309 def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
154 Variable._execution_engine.run_backward(
155 tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 156 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
157
158

RuntimeError: Found dtype Long but expected Float
```
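The failure at `loss_value.backward()` is consistent with the label tensor being an integer tensor. As far as I can tell, sentence-transformers 2.x collates batch labels with `torch.tensor(...)`, which yields `int64` ("Long") when the labels are Python ints, and `MarginMSELoss` passes them to `nn.MSELoss` against a float predicted margin. The sketch below re-implements the margin-MSE idea in plain PyTorch to show where the dtype matters. `margin_mse` is a hypothetical helper, not the library's code, and the dot-product similarity is an assumption about the default scoring function. Unlike the library version in the trace, it casts the target explicitly.

```python
import torch
import torch.nn as nn


def margin_mse(q, pos, neg, labels):
    # Hypothetical re-implementation: dot-product score per (query, doc) row pair
    scores_pos = (q * pos).sum(dim=-1)
    scores_neg = (q * neg).sum(dim=-1)
    margin_pred = scores_pos - scores_neg
    # MSELoss expects a floating-point target; cast it to match the prediction
    return nn.MSELoss()(margin_pred, labels.to(margin_pred.dtype))


torch.manual_seed(0)
q, pos, neg = (torch.randn(4, 8, requires_grad=True) for _ in range(3))

# torch.tensor on Python ints gives int64 ("Long"); on Python floats, float32 ("Float")
int_labels = torch.tensor([1, 0, 1, 1])
float_labels = torch.tensor([2.5, -0.3, 1.1, 0.7])

loss = margin_mse(q, pos, neg, int_labels)
loss.backward()  # succeeds only because of the explicit cast inside margin_mse
print(int_labels.dtype, float_labels.dtype, loss.dtype)
# prints: torch.int64 torch.float32 torch.float32
```

In practice, the cleaner fix is to give the examples float labels in the first place, rather than patching the loss.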
Looks good
You need triplets (query, positive, negative) and a label that gives the margin between the (query, positive) score and the (query, negative) score.
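Following that advice, each training example carries three texts and a float label equal to the teacher's score difference. The sketch below is pure Python and makes several assumptions. `TripletExample` is a hypothetical stand-in for `sentence_transformers.InputExample` (same `texts` and `label` fields). `toy_teacher` is a made-up word-overlap scorer standing in for a real teacher such as a cross-encoder. `check_float_labels` is a hypothetical pre-flight check you could run before `model.fit`. The explicit `float(...)` matters because an int label would end up in a Long tensor after collation.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class TripletExample:
    # Hypothetical stand-in mirroring InputExample's texts/label fields
    texts: List[str]
    label: float


def make_margin_example(query: str, pos: str, neg: str,
                        teacher_score: Callable[[str, str], float]) -> TripletExample:
    # Label = teacher margin between (query, pos) and (query, neg), forced to float
    margin = teacher_score(query, pos) - teacher_score(query, neg)
    return TripletExample(texts=[query, pos, neg], label=float(margin))


def check_float_labels(examples) -> List[int]:
    # Indices of examples whose label is not a Python float (these would collate to Long)
    return [i for i, ex in enumerate(examples) if not isinstance(ex.label, float)]


def toy_teacher(query: str, doc: str) -> int:
    # Toy scorer: number of shared lowercase words (deliberately returns an int)
    return len(set(query.lower().split()) & set(doc.lower().split()))


ex = make_margin_example("what is python",
                         "python is a programming language",
                         "the sky is blue",
                         toy_teacher)
bad = TripletExample(texts=["q", "p", "n"], label=1)  # int label slips through
print(ex.label, check_float_labels([ex, bad]))
```

With the real library you would build `InputExample(texts=[query, pos, neg], label=margin)` the same way, where `margin` is the teacher's score difference converted with `float()`.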