train_dreambooth_flax.py 😭

Describe the bug

I tried the flax and bf16 branches of CompVis/stable-diffusion-v1-4 with train_dreambooth_flax.py, but training fails, even though I can generate images on the same VM with this code:

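    import random

    import jax
    import numpy as np
    from flax.training.common_utils import shard

    # Assumed defined earlier: pipe and params (a FlaxStableDiffusionPipeline and its
    # replicated params), plus prompt, num_inference_steps, height, width, guidance_scale.
    real_seed = random.randint(0, 2147483647)
    prng_seed = jax.random.PRNGKey(real_seed)
    num_samples = jax.device_count()
    prompt_n = num_samples * [prompt]                              # one prompt per device
    prompt_ids = pipe.prepare_inputs(prompt_n)
    prng_seed = jax.random.split(prng_seed, jax.device_count())  # one PRNG key per device
    prompt_ids = shard(prompt_ids)                                 # (num_devices, 1, 77)
    images = pipe(prompt_ids, params, prng_seed, num_inference_steps=num_inference_steps, height=height, width=width, guidance_scale=guidance_scale, jit=True).images
    images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))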
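This snippet works because the number of prompts always equals jax.device_count(), so shard can split them evenly across devices. The minimal sketch below emulates the reshape that flax.training.common_utils.shard applies to each array (the same expression appears in the traceback below). It assumes 8 devices, as on a v3-8, and uses zero-filled token IDs in place of real tokenizer output; shard_like is a hypothetical stand-in, not the Flax function itself.

```python
import jax
import numpy as np


def shard_like(x, local_device_count):
    # Hypothetical helper: the same reshape flax.training.common_utils.shard applies,
    # i.e. x.reshape((local_device_count, -1) + x.shape[1:]).
    return x.reshape((local_device_count, -1) + x.shape[1:])


n_devices = 8      # a v3-8 TPU VM exposes 8 devices
max_length = 77    # CLIP tokenizer length, as in the error message

# One prompt per device, exactly like num_samples = jax.device_count() above.
prompt_ids = np.zeros((n_devices, max_length), dtype=np.int32)
sharded = shard_like(prompt_ids, n_devices)

# One PRNG key per device, like jax.random.split(prng_seed, jax.device_count()).
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)

print(sharded.shape)  # (8, 1, 77)
print(keys.shape)     # (8, 2)
```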

From the error log:

  File "train_dreambooth_flax.py", line 370, in main
    prompt_ids = shard(prompt_ids)
  ...
    lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)

https://github.com/huggingface/diffusers/blob/7fb4b882b9618005035fdd70d78b312bbcd2a5ef/examples/dreambooth/train_dreambooth_flax.py#L370

Reproduction

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
git clone https://github.com/huggingface/diffusers
cd diffusers/examples/dreambooth
mkdir instance-images class-images save-model
pip install -U -r requirements_flax.txt
sudo apt-get install git-lfs
git lfs install
git clone -b flax https://user:token@huggingface.co/CompVis/stable-diffusion-v1-4
wget http://1.jpg http://2.jpg http://3.jpg -P instance-images
python3 train_dreambooth_flax.py --pretrained_model_name_or_path="/home/camenduru/diffusers/examples/dreambooth/stable-diffusion-v1-4"  \
--instance_data_dir="instance-images"  \
--class_data_dir="class-images"  \
--output_dir="save-model"  \
--with_prior_preservation  \
--prior_loss_weight=1.0  \
--instance_prompt="parkminyoung"  \
--class_prompt="person"  \
--resolution=512  \
--train_batch_size=1  \
--learning_rate=5e-6  \
--num_class_images=12  \
--max_train_steps=650 

Logs

INFO:__main__:Number of class images to sample: 12.
Generating class images:   0%|          | 0/3 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_dreambooth_flax.py", line 665, in <module>
    main()
  File "train_dreambooth_flax.py", line 370, in main
    prompt_ids = shard(prompt_ids)
  File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 37, in shard
    return jax.tree_util.tree_map(
  File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 38, in <lambda>
    lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)
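The numbers in the log are consistent with a divisibility problem: 12 class images sampled in 3 batches (the 0/3 progress bar) means 4 prompts per batch, and 4 × 77 tokens = 308 elements. shard needs the leading dimension to divide evenly by the local device count, which is 8 on a v3-8, and 4 does not. The NumPy sketch below reproduces the failure and shows one possible workaround: pad the batch to a multiple of the device count, then discard the extra generated images. shard_like and pad_to_multiple are hypothetical helpers, and this is not the fix adopted in diffusers; a sampling batch size that is a multiple of the device count would also avoid the error. Note that the successful run in the comments below omits --with_prior_preservation, so it presumably skips class-image generation entirely.

```python
import numpy as np


def shard_like(x, local_device_count):
    # Hypothetical helper: the reshape flax.training.common_utils.shard performs.
    return x.reshape((local_device_count, -1) + x.shape[1:])


def pad_to_multiple(x, multiple):
    # Hypothetical workaround: repeat the last row until the batch divides evenly.
    # Returns the padded batch and how many rows were added (drop those outputs later).
    remainder = (-x.shape[0]) % multiple
    if remainder == 0:
        return x, 0
    padding = np.repeat(x[-1:], remainder, axis=0)
    return np.concatenate([x, padding], axis=0), remainder


n_devices = 8
num_class_images, num_batches = 12, 3      # from the log: 12 images, "0/3" batches
batch = num_class_images // num_batches     # 4 prompts per batch
prompt_ids = np.zeros((batch, 77), dtype=np.int32)
print(prompt_ids.size)                      # 308, the size in the error

try:
    shard_like(prompt_ids, n_devices)       # 4 rows cannot be split across 8 devices
except ValueError as err:
    print("shard fails:", err)

padded, n_pad = pad_to_multiple(prompt_ids, n_devices)
print(shard_like(padded, n_devices).shape, n_pad)  # (8, 1, 77) 4
```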


System Info

gcloud | v3-8 | tpu-vm-base

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

1 reaction
camenduru commented, Nov 3, 2022

650/650 [10:20<00:00, 1.05it/s] 🤯 🚀

python3 train_dreambooth_flax.py \
  --pretrained_model_name_or_path="/home/camenduru/diffusers/examples/dreambooth/stable-diffusion-v1-5"  \
  --instance_data_dir="instance-images" \
  --output_dir="save-model" \
  --instance_prompt="parkminyoung" \
  --resolution=512 \
  --mixed_precision="bf16" \
  --train_batch_size=1 \
  --learning_rate=1e-6 \
  --max_train_steps=650
INFO:__main__:***** Running training *****
INFO:__main__:  Num examples = 8
INFO:__main__:  Num Epochs = 650
INFO:__main__:  Instantaneous batch size per device = 1
INFO:__main__:  Total train batch size (w. parallel & distributed) = 8
INFO:__main__:  Total optimization steps = 650
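
The logged totals follow from simple arithmetic: 1 image per device × 8 devices gives a total batch of 8, so the 8 examples fit in one step per epoch, and 650 steps take 650 epochs. A minimal sketch of that calculation, assuming the script derives the epoch count this way (it is not copied from train_dreambooth_flax.py):

```python
import math

num_examples = 8          # "Num examples = 8"
per_device_batch = 1      # --train_batch_size=1
n_devices = 8             # v3-8
max_train_steps = 650     # --max_train_steps=650

total_batch = per_device_batch * n_devices                   # "Total train batch size" = 8
steps_per_epoch = math.ceil(num_examples / total_batch)      # 1 step covers all examples
num_epochs = math.ceil(max_train_steps / steps_per_epoch)    # "Num Epochs" = 650

print(total_batch, steps_per_epoch, num_epochs)  # 8 1 650
```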

1 reaction
camenduru commented, Nov 3, 2022

@entrpn @duongna21 thanks ♥ We could also add --mixed_precision="bf16" to the README; with it, a device with 16 GB of memory, such as a v3-8, is enough.
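Part of why bf16 lowers memory use is storage size: a bfloat16 value takes 2 bytes versus 4 for float32. The JAX sketch below casts a small, hypothetical parameter pytree (not the real Stable Diffusion weights) and compares byte counts. It does not show which tensors train_dreambooth_flax.py actually casts under --mixed_precision="bf16".

```python
import jax
import jax.numpy as jnp


def tree_nbytes(tree):
    # Total bytes across all array leaves of a parameter pytree.
    return sum(leaf.nbytes for leaf in jax.tree_util.tree_leaves(tree))


# Hypothetical stand-in for model parameters (shapes chosen only for illustration).
params_fp32 = {
    "conv_in": {"kernel": jnp.ones((3, 3, 4, 320), jnp.float32)},
    "dense": {"kernel": jnp.ones((1280, 1280), jnp.float32)},
}

# Cast every leaf to bfloat16, as a bf16 setup would do for the weights it stores.
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params_fp32)

print(tree_nbytes(params_fp32), tree_nbytes(params_bf16))  # 6599680 3299840
```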
