OOM error with custom dataset - Python systematically crashes after a couple of epochs
Context
I am required to train a model to detect anomalies in images coming from a video stream from CCTV cameras. I have already built a dataset in the same format as the MVTec dataset ("good" images in one folder, the "anomalies" class in another, and the "ground truth" segmentation masks in a third). I created my own custom YAML file, which looks like this (I intentionally removed the paths, please ignore those lines):
```yaml
dataset:
  name: brooks
  format: folder
  path: <removed_in_purpose>
  normal_dir: <removed_in_purpose> # name of the folder containing normal images.
  abnormal_dir: <removed_in_purpose> # name of the folder containing abnormal images.
  normal_test_dir: null # name of the folder containing normal test images.
  task: segmentation # classification or segmentation
  mask: <removed_in_purpose> # optional
  extensions: null
  split_ratio: 0.1 # ratio of the normal images that will be used to create a test split
  image_size: [512, 512] # [256,256] # [115, 194] # [1149, 1940]
  train_batch_size: 1
  test_batch_size: 1
  num_workers: 4
  transform_config:
    train: null
    val: null
  create_validation_set: true
  tiling:
    apply: false
    tile_size: null
    stride: null
    remove_border_count: 0
    use_random_tiling: False
    random_tile_count: 16

model:
  name: padim
  backbone: resnet18
  pre_trained: true
  layers:
    - layer1
    - layer2
    - layer3
  normalization_method: min_max # options: [none, min_max, cdf]

metrics:
  image:
    - F1Score
    - AUROC
  pixel:
    - F1Score
    - AUROC
  threshold:
    image_default: 3
    pixel_default: 3
    adaptive: true

visualization:
  show_images: False # show images on the screen
  save_images: True # save images to the file system
  log_images: True # log images to the available loggers (if any)
  image_save_path: null # path to which images will be saved
  mode: full # options: ["full", "simple"]

project:
  seed: 42
  path: <removed_in_purpose>

logging:
  logger: [] # options: [tensorboard, wandb, csv] or combinations.
  log_graph: false # Logs the model graph to respective logger.

# optimization:
#   openvino:
#     apply: false

# PL Trainer Args. Don't add extra parameter here.
trainer:
  accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
  accumulate_grad_batches: 1
  amp_backend: native
  auto_lr_find: false
  auto_scale_batch_size: false
  auto_select_gpus: false
  benchmark: false
  check_val_every_n_epoch: 1 # Don't validate before extracting features.
  default_root_dir: null
  detect_anomaly: false
  deterministic: false
  devices: 1
  enable_checkpointing: true
  enable_model_summary: true
  enable_progress_bar: true
  fast_dev_run: false
  gpus: null # Set automatically
  gradient_clip_val: 0
  ipus: null
  limit_predict_batches: 1.0
  limit_test_batches: 1.0
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  log_every_n_steps: 50
  max_epochs: 4
  max_steps: -1
  max_time: null
  min_epochs: null
  min_steps: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
  num_nodes: 1
  num_processes: null
  num_sanity_val_steps: 0
  overfit_batches: 0.0
  plugins: null
  precision: 32
  profiler: null
  reload_dataloaders_every_n_epochs: 0
  replace_sampler_ddp: true
  sync_batchnorm: false
  tpu_cores: null
  track_grad_norm: -1
  val_check_interval: 1.0 # Don't validate before extracting features.
```
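For completeness, this is roughly how I launch training (a minimal sketch of what anomalib's `tools/train.py` does; the helper names assume anomalib 0.3.x's public API, and the config path below is a placeholder for the file above):

```python
# Minimal launch sketch mirroring anomalib 0.3.x's tools/train.py.
# "path/to/custom_padim.yaml" is a placeholder for the config above.
from pytorch_lightning import Trainer, seed_everything

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.callbacks import get_callbacks

config = get_configurable_parameters(config_path="path/to/custom_padim.yaml")
seed_everything(config.project.seed)

datamodule = get_datamodule(config)  # builds the folder-format datamodule
model = get_model(config)            # instantiates PaDiM from the config
trainer = Trainer(**config.trainer, callbacks=get_callbacks(config))
trainer.fit(model=model, datamodule=datamodule)
```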
Describe the bug
When training a PaDiM model with the configuration file above, the trainer starts but crashes consistently after only one, two, or three epochs (depending on the batch size and the input image size) - see the screenshot below.
As mentioned, I tried different batch sizes (as low as 1), numbers of epochs, and input image sizes; these are some of the tests I ran:
| Test | Image input size | Max number of epochs before crashing | Max batch size | Completed run (✅) or crashed (❌)? | Comments |
|---|---|---|---|---|---|
| Test 0 | [100,100] | 10 | 8 | ✅ | Accuracy too low!! |
| Test 1 | [256,256] | 4 | 1 | ❌ | |
| Test 1 | [200,200] | 4 | 1 | ✅ | Accuracy too low!! |
| Test 2 | [256,256] | 1 | 1 | ✅ | Accuracy too low!! |
I have tested this in three different environments, with the same result: an 80-core Xeon CPU machine with 96 GB of RAM and no GPU; an AWS g5.xlarge instance with 16 GB of RAM and a 24 GB GPU (NVIDIA A10G); and Google Colab. In all of them the code crashes after a couple of epochs. If I monitor the RAM/GPU usage, I can see that the process is killed once a certain maximum usage is reached.
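For reference, the monitoring mentioned above can be reproduced with a small helper along these lines (a hypothetical snippet, using `psutil` and `torch`, both present in the environment listed below; it could be called at the end of every epoch, e.g. from a PyTorch Lightning callback):

```python
import psutil
import torch

def log_memory(tag: str) -> None:
    """Print host RSS and, if a GPU is available, the memory torch has allocated."""
    rss_gib = psutil.Process().memory_info().rss / 1024**3
    msg = f"[{tag}] host RSS: {rss_gib:.2f} GiB"
    if torch.cuda.is_available():
        msg += f" | GPU allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB"
    print(msg)

log_memory("epoch end")  # example call
```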
In summary: the only meaningful results start when I train the model with an input size greater than 256 and for more than one epoch. Even for an image input size of 100 px, I can train it for only 10 epochs before it crashes. So effectively, I cannot train the model to achieve the accuracy I would expect.
Expected behavior
- I would expect to be able to train the model for as many epochs as I need, and for PyTorch (or anomalib itself) to handle the tensors and training in a way that doesn't blow up the memory.
Screenshots
Hardware and Software Configuration
My conda env config:
```
# packages in environment at /home/manuelbv/anaconda3/envs/anomalib_env:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
anomalib 0.3.5 pypi_0 pypi
bcrypt 4.0.0 pypi_0 pypi
ca-certificates 2022.07.19 h06a4308_0
certifi 2022.6.15 py38h06a4308_0
cffi 1.15.1 pypi_0 pypi
click 8.1.3 pypi_0 pypi
cryptography 38.0.1 pypi_0 pypi
idna 3.4 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libffi 3.3 he6710b0_2
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
monotonic 1.6 pypi_0 pypi
ncurses 6.3 h5eee18b_3
numpy 1.23.3 pypi_0 pypi
oauthlib 3.2.1 pypi_0 pypi
openssl 1.1.1q h7f8727e_0
pandas 1.4.4 pypi_0 pypi
paramiko 2.11.0 pypi_0 pypi
pillow 9.2.0 pypi_0 pypi
pip 22.1.2 py38h06a4308_0
protobuf 3.19.5 pypi_0 pypi
psutil 5.9.2 pypi_0 pypi
pycparser 2.21 pypi_0 pypi
pynacl 1.5.0 pypi_0 pypi
python 3.8.13 h12debd9_0
python-dateutil 2.8.2 pypi_0 pypi
pytz 2022.2.1 pypi_0 pypi
pyyaml 6.0 pypi_0 pypi
readline 8.1.2 h7f8727e_1
requests 2.28.1 pypi_0 pypi
scipy 1.9.1 pypi_0 pypi
setproctitle 1.3.2 pypi_0 pypi
setuptools 63.4.1 py38h06a4308_0
six 1.16.0 pypi_0 pypi
sqlite 3.39.2 h5082296_0
tk 8.6.12 h1ccaba5_0
tqdm 4.64.1 pypi_0 pypi
wheel 0.37.1 pyhd3eb1b0_0
xz 5.2.5 h7f8727e_1
zlib 1.2.12 h5eee18b_3
```
And pip freeze (inside the conda environment used):
```
absl-py==1.2.0
aiohttp==3.8.1
aiosignal==1.2.0
albumentations==1.2.1
analytics-python==1.4.0
-e git+https://github.com/openvinotoolkit/anomalib.git@a0e040d445a4f4f4e772cbad4e4630036d82bdc0#egg=anomalib
antlr4-python3-runtime==4.9.3
anyio==3.6.1
async-timeout==4.0.2
attrs==22.1.0
backoff==1.10.0
bcrypt==4.0.0
cachetools==5.2.0
certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
cryptography==38.0.1
cycler==0.11.0
docker-pycreds==0.4.0
docstring-parser==0.15
einops==0.4.1
fastapi==0.83.0
ffmpy==0.3.0
fonttools==4.37.1
frozenlist==1.3.1
fsspec==2022.8.2
gitdb==4.0.9
GitPython==3.1.27
google-auth==2.11.0
google-auth-oauthlib==0.4.6
gradio==3.3
grpcio==1.48.1
h11==0.12.0
httpcore==0.15.0
httpx==0.23.0
idna==3.4
imageio==2.21.3
imgaug==0.4.0
importlib-metadata==4.12.0
Jinja2==3.1.2
joblib==1.1.0
jsonargparse==4.13.3
kiwisolver==1.4.4
kornia==0.6.7
linkify-it-py==1.0.3
Markdown==3.4.1
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.5.3
mdit-py-plugins==0.3.0
mdurl==0.1.2
monotonic==1.6
multidict==6.0.2
networkx==2.8.6
numpy==1.23.3
oauthlib==3.2.1
omegaconf==2.2.3
opencv-python==4.6.0.66
opencv-python-headless==4.6.0.66
orjson==3.8.0
packaging==21.3
pandas==1.4.4
paramiko==2.11.0
pathtools==0.1.2
Pillow==9.2.0
Pmw==2.0.1
promise==2.3
protobuf==3.19.5
psutil==5.9.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pycryptodome==3.15.0
pydantic==1.10.2
pyDeprecate==0.3.2
pydub==0.25.1
PyNaCl==1.5.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-gdsii==0.2.1
python-multipart==0.0.5
pytorch-lightning==1.6.5
pytz==2022.2.1
PyWavelets==1.3.0
PyYAML==6.0
qudida==0.0.4
requests==2.28.1
requests-oauthlib==1.3.1
rfc3986==1.5.0
rsa==4.9
ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.6
scikit-image==0.19.3
scikit-learn==1.1.2
scipy==1.9.1
sentry-sdk==1.9.8
setproctitle==1.3.2
Shapely==1.8.4
shortuuid==1.0.9
shyaml==0.6.2
six==1.16.0
smmap==5.0.0
sniffio==1.3.0
starlette==0.19.1
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.1.0
tifffile==2022.8.12
timm==0.5.4
torch==1.11.0
torchmetrics==0.9.1
torchtext==0.12.0
torchvision==0.12.0
tqdm==4.64.1
typing_extensions==4.3.0
uc-micro-py==1.0.1
urllib3==1.26.12
uvicorn==0.18.3
vext==0.7.6
wandb==0.12.17
websockets==10.3
Werkzeug==2.2.2
yarl==1.8.1
zipp==3.8.1
```
Additional comments
Could you please help me figure out how to train my model for as many epochs as I require to get my accuracy to a decent level, without the program crashing? Thank you!!!
Top GitHub Comments
The error is probably GPU OOM in both cases, and there is not much you can do about it besides increasing your GPU VRAM or reducing the training set size. The first is the harder of the two, so you should start with reducing the training set size. The line

```yaml
limit_train_batches: 1.0
```

in your config file defines the fraction of the training data that is used for training (currently 100%). You can decrease the value and see at which point the remaining samples fit in your GPU memory. Also try decreasing the image size to (256, 256).
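For example, with a value like this (an illustrative number, tune it to your memory budget):

```yaml
trainer:
  limit_train_batches: 0.25  # use only 25% of the training batches each epoch
```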
PaDiM isn't really "trained": at training time it only extracts and stores image features. Once the features of every training image have been collected, the test images' features are compared against the stored training features at test time. So you don't have to "train" PaDiM for more than one epoch (except perhaps if you use random image augmentations).
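To make this concrete, here is a minimal, self-contained sketch of the PaDiM idea (not anomalib's actual implementation: a single backbone layer, and random tensors as stand-ins for real images). One pass collects features, a Gaussian is fitted per spatial location, and test images are scored by Mahalanobis distance; the per-location covariance matrices are also why memory grows quickly with image resolution:

```python
import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

# One "epoch" suffices: nothing is optimised, features are only collected.
extractor = create_feature_extractor(
    resnet18(pretrained=True), return_nodes={"layer1": "feat"}
).eval()

train_imgs = torch.randn(16, 3, 256, 256)  # stand-in for the normal images
test_imgs = torch.randn(4, 3, 256, 256)    # stand-in for the test images

with torch.no_grad():
    feats = extractor(train_imgs)["feat"]              # (N, C, H, W)
    n, c, h, w = feats.shape
    feats = feats.permute(0, 2, 3, 1).reshape(n, h * w, c)

    mean = feats.mean(dim=0)                           # (H*W, C)
    centred = feats - mean
    # One regularised CxC covariance matrix per spatial location: (H*W, C, C).
    # This is the memory hot spot: it grows with the number of feature-map
    # locations, i.e. with the input image resolution.
    cov = torch.einsum("nlc,nld->lcd", centred, centred) / (n - 1)
    cov += 0.01 * torch.eye(c)
    cov_inv = torch.linalg.inv(cov)

    test_feats = extractor(test_imgs)["feat"].permute(0, 2, 3, 1).reshape(-1, h * w, c)
    delta = test_feats - mean
    # Mahalanobis distance per location -> per-pixel anomaly map.
    scores = torch.sqrt(torch.einsum("nlc,lcd,nld->nl", delta, cov_inv, delta))
    anomaly_maps = scores.reshape(-1, h, w)            # (N_test, H, W)
```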
Your accuracy will not rise with more epochs! Try different algorithms (e.g. PatchCore), other feature-extraction backbones (e.g. Wide ResNet-50), or better training data.
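In the config above, that would mean changing the model section along these lines (note: PatchCore has further model-specific keys, so starting from anomalib's default patchcore config.yaml is safer than editing only these two fields):

```yaml
model:
  name: patchcore           # or keep padim and swap only the backbone
  backbone: wide_resnet50_2
```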