
How to get a Confusion Matrix


❓ Questions and Help

I made an image classifier with Lightning Flash and now I would like to see the confusion matrix of my classes. I am struggling to get this working properly.

I tried to add the metric to the ImageClassifier, but, as described in the documentation, only scalar metrics are allowed, so this attempt failed.

import torchmetrics
from flash.image import ImageClassifier

model = ImageClassifier(
    num_classes=datamodule.num_classes,
    metrics=torchmetrics.ConfusionMatrix(datamodule.num_classes, compute_on_step=False),
)

Is there a proper way to get a confusion matrix?

If not, I would really appreciate a way to get the raw data (predictions and targets) from trainer.validate and trainer.test so that one can easily calculate their own metrics.

# suggestion (raw_results is a proposed argument, not an existing one)
val_res = trainer.validate(model, datamodule=datamodule, raw_results=True)
# >>> ({"metric": 99, ...}, ['class1', 'class2', ...], ['class2', 'class2', ...])
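Whatever form such raw output ends up taking, a confusion matrix is easy to build from it. The sketch below is a plain-Python illustration, assuming the raw data arrive as two equally long lists of class names like those in the suggestion; `confusion_matrix_from_labels` is a hypothetical helper, not part of Flash. Rows are true classes and columns are predicted classes, the same convention as `sklearn.metrics.confusion_matrix`.

```python
from typing import List, Sequence


def confusion_matrix_from_labels(
    targets: Sequence[str], preds: Sequence[str], classes: Sequence[str]
) -> List[List[int]]:
    """Hypothetical helper: count (true, predicted) label pairs.

    matrix[i][j] is the number of samples whose true class is classes[i]
    and whose predicted class is classes[j].
    """
    if len(targets) != len(preds):
        raise ValueError("targets and preds must have the same length")
    # Map each class name to its row/column index.
    index = {name: i for i, name in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for true_name, pred_name in zip(targets, preds):
        matrix[index[true_name]][index[pred_name]] += 1
    return matrix


classes = ["class1", "class2", "class3"]
targets = ["class1", "class2", "class2", "class3"]
preds = ["class2", "class2", "class2", "class1"]
print(confusion_matrix_from_labels(targets, preds, classes))
# → [[0, 1, 0], [0, 2, 0], [1, 0, 0]]
```

An unknown class name raises `KeyError`, which is usually preferable to silently dropping samples.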

cheers and thanks for the library!

  • OS: macOS, Linux
  • Packaging: pip
  • Version: e.g. 0.5.0

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 7

Top GitHub Comments

1 reaction
dmarx commented, Oct 31, 2021

Adding a minimal working example that reproduces the error I suspect Michl encountered:

# pip install lightning-flash[image] torchmetrics

import flash
import torch
from flash.image import ImageClassifier, ImageClassificationData
from torchvision.datasets import CIFAR10
import torchmetrics

datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10('.', download=True, train=True),
)

model = ImageClassifier(
    num_classes=10,
    metrics=torchmetrics.ConfusionMatrix(10, compute_on_step=False),
)

trainer = flash.Trainer(max_steps=30, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Produces the following error:


/usr/local/lib/python3.7/dist-packages/flash/core/trainer.py in finetune(self, model, train_dataloader, val_dataloaders, datamodule, strategy)
    186         """
    187         self._resolve_callbacks(model, strategy)
--> 188         return super().fit(model, train_dataloader, val_dataloaders, datamodule)
    189 
    190     def _resolve_callbacks(self, model: LightningModule, strategy: Optional[Union[str, BaseFinetuning]] = None) -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader)
    550         self.checkpoint_connector.resume_start()
    551 
--> 552         self._run(model)
    553 
    554         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    920 
    921         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 922         self._dispatch()
    923 
    924         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
    988             self.accelerator.start_predicting(self)
    989         else:
--> 990             self.accelerator.start_training(self)
    991 
    992     def run_stage(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     90 
     91     def start_training(self, trainer: "pl.Trainer") -> None:
---> 92         self.training_type_plugin.start_training(trainer)
     93 
     94     def start_evaluating(self, trainer: "pl.Trainer") -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    159     def start_training(self, trainer: "pl.Trainer") -> None:
    160         # double dispatch to initiate the training loop
--> 161         self._results = trainer.run_stage()
    162 
    163     def start_evaluating(self, trainer: "pl.Trainer") -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    998         if self.predicting:
    999             return self._run_predict()
-> 1000         return self._run_train()
   1001 
   1002     def _pre_training_routine(self):

/usr/local/lib/python3.7/dist-packages/flash/core/trainer.py in _run_train(self)
    123 
    124         self.fit_loop.trainer = self
--> 125         self.fit_loop.run()
    126 
    127     def fit(

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    109             try:
    110                 self.on_advance_start(*args, **kwargs)
--> 111                 self.advance(*args, **kwargs)
    112                 self.on_advance_end()
    113                 self.iteration_count += 1

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/fit_loop.py in advance(self)
    198         with self.trainer.profiler.profile("run_training_epoch"):
    199             # run train epoch
--> 200             epoch_output = self.epoch_loop.run(train_dataloader)
    201 
    202             if epoch_output is None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    116                 break
    117 
--> 118         output = self.on_run_end()
    119         return output
    120 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in on_run_end(self)
    233 
    234         # call train epoch end hooks
--> 235         self._on_train_epoch_end_hook(processed_outputs)
    236         self.trainer.call_hook("on_epoch_end")
    237         self.trainer.logger_connector.on_epoch_end()

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in _on_train_epoch_end_hook(self, processed_epoch_output)
    274             if hasattr(self.trainer, hook_name):
    275                 trainer_hook = getattr(self.trainer, hook_name)
--> 276                 trainer_hook(processed_epoch_output)
    277 
    278             # next call hook in lightningModule

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/callback_hook.py in on_train_epoch_end(self, outputs)
    107                 callback.on_train_epoch_end(self, self.lightning_module, outputs)
    108             else:
--> 109                 callback.on_train_epoch_end(self, self.lightning_module)
    110 
    111     def on_validation_epoch_start(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py in on_train_epoch_end(self, trainer, pl_module, unused)
    308             and (trainer.current_epoch + 1) % self._every_n_epochs == 0
    309         ):
--> 310             self.save_checkpoint(trainer)
    311         trainer.fit_loop.global_step += 1
    312 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py in save_checkpoint(self, trainer, unused)
    371         global_step = trainer.global_step
    372 
--> 373         self._validate_monitor_key(trainer)
    374 
    375         # track epoch when ckpt was last checked

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py in _validate_monitor_key(self, trainer)
    677 
    678     def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
--> 679         metrics = trainer.callback_metrics
    680 
    681         # validate metric

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/properties.py in callback_metrics(self)
    621     @property
    622     def callback_metrics(self) -> dict:
--> 623         return self.logger_connector.callback_metrics
    624 
    625     @property

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in callback_metrics(self)
    307     def callback_metrics(self) -> Dict[str, _METRIC]:
    308         if self.trainer._results:
--> 309             metrics = self.metrics[MetricSource.CALLBACK]
    310             self._callback_metrics.update(metrics)
    311         return self._callback_metrics

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in metrics(self)
    295         """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
    296         on_step = not self._epoch_end_reached
--> 297         return self.trainer._results.metrics(on_step)
    298 
    299     @property

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py in metrics(self, on_step)
    568             # populate progress_bar metrics. convert tensors to numbers
    569             if result_metric.meta.prog_bar:
--> 570                 metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value)
    571 
    572         return metrics

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/metrics.py in metrics_to_scalars(metrics)
     38         return value.item()
     39 
---> 40     return apply_to_collection(metrics, torch.Tensor, to_item)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/apply_func.py in apply_to_collection(data, dtype, function, wrong_dtype, include_none, *args, **kwargs)
     94     # Breaking condition
     95     if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
---> 96         return function(data, *args, **kwargs)
     97 
     98     elem_type = type(data)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/metrics.py in to_item(value)
     34         if value.numel() != 1:
     35             raise MisconfigurationException(
---> 36                 f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
     37             )
     38         return value.item()

MisconfigurationException: The metric `tensor([[6., 0., 1., 0., 3., 0., 0., 1., 3., 1.],
        [1., 0., 0., 1., 2., 0., 0., 5., 3., 2.],
        [1., 0., 0., 2., 2., 0., 0., 4., 0., 0.],
        [1., 0., 1., 0., 5., 0., 0., 1., 1., 0.],
        [2., 1., 2., 0., 2., 0., 0., 3., 3., 2.],
        [1., 0., 0., 1., 3., 1., 0., 1., 1., 2.],
        [1., 0., 0., 0., 3., 0., 0., 0., 2., 0.],
        [3., 0., 0., 0., 2., 0., 0., 3., 5., 2.],
        [6., 0., 0., 0., 1., 0., 0., 0., 1., 5.],
        [3., 0., 2., 0., 2., 0., 0., 1., 0., 6.]])` does not contain a single element, thus it cannot be converted to a scalar.

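The traceback shows why the attempt fails: Lightning's `metrics_to_scalars` calls `to_item` on each logged value, and `to_item` raises unless `value.numel() == 1`, whereas a 10-class confusion matrix has 100 elements. One possible workaround, which is an assumption and not something the thread confirms Flash supports, is to reduce the matrix to one scalar per class before logging, for example per-class recall (the diagonal entry divided by the row sum, taking rows as true classes). The plain-Python sketch below uses a hypothetical `per_class_recall` helper and feeds it the matrix from the error message.

```python
from typing import Dict, Sequence


def per_class_recall(
    matrix: Sequence[Sequence[float]], prefix: str = "recall_class_"
) -> Dict[str, float]:
    """Hypothetical helper: turn a confusion matrix into scalar metrics.

    Assumes rows are true classes and columns are predicted classes, so the
    recall of class i is matrix[i][i] / sum(matrix[i]). A class with no true
    samples gets 0.0 instead of a division by zero.
    """
    scalars = {}
    for i, row in enumerate(matrix):
        total = sum(row)
        scalars[f"{prefix}{i}"] = row[i] / total if total else 0.0
    return scalars


# The matrix printed in the MisconfigurationException above.
confmat = [
    [6, 0, 1, 0, 3, 0, 0, 1, 3, 1],
    [1, 0, 0, 1, 2, 0, 0, 5, 3, 2],
    [1, 0, 0, 2, 2, 0, 0, 4, 0, 0],
    [1, 0, 1, 0, 5, 0, 0, 1, 1, 0],
    [2, 1, 2, 0, 2, 0, 0, 3, 3, 2],
    [1, 0, 0, 1, 3, 1, 0, 1, 1, 2],
    [1, 0, 0, 0, 3, 0, 0, 0, 2, 0],
    [3, 0, 0, 0, 2, 0, 0, 3, 5, 2],
    [6, 0, 0, 0, 1, 0, 0, 0, 1, 5],
    [3, 0, 2, 0, 2, 0, 0, 1, 0, 6],
]
recalls = per_class_recall(confmat)
print(recalls["recall_class_0"])  # → 0.4
```

Each value is a plain float, so in a LightningModule it could be logged under its own name, for example with `self.log_dict(recalls)`; how to wire this into Flash's `ImageClassifier` is not covered in the thread.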
0 reactions
mmmichl commented, Jan 4, 2022

We switched to PyTorch Lightning, where we handle the metrics ourselves.
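Handling a confusion matrix yourself usually means accumulating predictions and targets batch by batch and building the matrix only at the end of the epoch, outside of scalar logging. The sketch below is a framework-free illustration of that pattern using a hypothetical `ConfusionMatrixAccumulator` class and integer class indices; it mirrors the `update`/`compute`/`reset` methods of a torchmetrics `Metric`, but it is not the commenters' actual code. In a LightningModule one would presumably call `update` in `validation_step` and `compute` followed by `reset` in `on_validation_epoch_end`.

```python
from typing import List, Sequence


class ConfusionMatrixAccumulator:
    """Hypothetical accumulator: rows are true classes, columns predicted."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.reset()

    def reset(self) -> None:
        # Start a fresh count, e.g. at the beginning of each validation epoch.
        self._counts = [[0] * self.num_classes for _ in range(self.num_classes)]

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        # Add one batch of predicted and true class indices.
        if len(preds) != len(targets):
            raise ValueError("preds and targets must have the same length")
        for pred, target in zip(preds, targets):
            if not (0 <= pred < self.num_classes and 0 <= target < self.num_classes):
                raise ValueError(f"class index out of range: {pred}, {target}")
            self._counts[target][pred] += 1

    def compute(self) -> List[List[int]]:
        # Return a copy so callers cannot mutate the running state.
        return [row[:] for row in self._counts]


acc = ConfusionMatrixAccumulator(num_classes=3)
acc.update(preds=[0, 2, 1], targets=[0, 1, 1])  # batch 1
acc.update(preds=[2, 2], targets=[2, 0])  # batch 2
print(acc.compute())  # → [[1, 0, 1], [0, 1, 1], [0, 0, 1]]
```

Keeping the counts as state and computing once per epoch avoids the scalar-only logging path entirely; the resulting matrix can then be printed, plotted, or saved.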


