train_dreambooth_flax.py 😭
### Describe the bug
I tried the `flax` and `bf16` branches of CompVis/stable-diffusion-v1-4 with train_dreambooth_flax.py, but training fails. I can generate images with the following code on the same VM:
import random
import numpy as np
import jax
from flax.training.common_utils import shard

# `pipe` and `params` are assumed to come from FlaxStableDiffusionPipeline.from_pretrained(...)
real_seed = random.randint(0, 2147483647)
prng_seed = jax.random.PRNGKey(real_seed)
num_samples = jax.device_count()
prompt_n = num_samples * [prompt]
prompt_ids = pipe.prepare_inputs(prompt_n)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipe(prompt_ids, params, prng_seed, num_inference_steps=num_inference_steps, height=height, width=width, guidance_scale=guidance_scale, jit=True).images
images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
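For context, `pipe` and `params` above are assumed to be loaded roughly like this (the model id, revision, prompt, and sampling parameters below are illustrative, not taken from the original report):

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Illustrative loading of the Flax pipeline; revision/dtype mirror the bf16 branch mentioned above.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
)

prompt = "a photograph of an astronaut riding a horse"  # placeholder prompt
num_inference_steps, height, width, guidance_scale = 50, 512, 512, 7.5
```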
From error log:
File "train_dreambooth_flax.py", line 370, in main
prompt_ids = shard(prompt_ids)
lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)
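Note that 308 = 4 × 77, so the failing call is trying to shard a batch of 4 tokenized class prompts (77 tokens each) across the 8 TPU cores during class-image generation. `shard` splits the leading batch axis over `jax.local_device_count()`, which only works when the batch size is divisible by the device count. A minimal sketch of the same failure (array contents and sizes are just for illustration):

```python
import numpy as np
import jax
from flax.training.common_utils import shard

# shard() reshapes the leading (batch) axis into (local_device_count, -1, ...),
# so the batch size must be divisible by jax.local_device_count().
prompt_ids = np.zeros((4, 77), dtype=np.int32)  # 4 tokenized prompts -> 4 * 77 = 308 elements

# On a v3-8 (8 local devices) the next line raises:
#   ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)
# shard(prompt_ids)

# A batch whose size is a multiple of the device count shards cleanly:
prompt_ids = np.zeros((8, 77), dtype=np.int32)
sharded = shard(prompt_ids)  # shape: (local_device_count, batch // local_device_count, 77)
```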
### Reproduction
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
git clone https://github.com/huggingface/diffusers
cd diffusers/examples/dreambooth
mkdir instance-images class-images save-model
pip install -U -r requirements_flax.txt
sudo apt-get install git-lfs
git lfs install
git clone -b flax https://user:token@huggingface.co/CompVis/stable-diffusion-v1-4
wget http://1.jpg http://2.jpg http://3.jpg -P instance-images
python3 train_dreambooth_flax.py --pretrained_model_name_or_path="/home/camenduru/diffusers/examples/dreambooth/stable-diffusion-v1-4" \
--instance_data_dir="instance-images" \
--class_data_dir="class-images" \
--output_dir="save-model" \
--with_prior_preservation \
--prior_loss_weight=1.0 \
--instance_prompt="parkminyoung" \
--class_prompt="person" \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=5e-6 \
--num_class_images=12 \
--max_train_steps=650
### Logs
INFO:__main__:Number of class images to sample: 12.
Generating class images: 0%| | 0/3 [00:00<?, ?it/s]
Traceback (most recent call last):
File "train_dreambooth_flax.py", line 665, in <module>
main()
File "train_dreambooth_flax.py", line 370, in main
prompt_ids = shard(prompt_ids)
File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 37, in shard
return jax.tree_util.tree_map(
File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 38, in <lambda>
lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)
### System Info
Google Cloud TPU VM | v3-8 | tpu-vm-base
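Not part of the original report, but a quick sanity check that JAX actually sees all eight TPU cores on the v3-8 (the error above implies `local_device_count()` was 8):

```python
import jax

# On a healthy v3-8 TPU VM this reports 8 local TPU devices.
print(jax.local_device_count())
print(jax.devices())
```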
### Top GitHub Comments
650/650 [10:20<00:00, 1.05it/s] 🤯 🚀
@entrpn @duongna21 thanks ♥ We could also add --mixed_precision="bf16" to the README; then a 16 GB device like a v3-8 is enough.