# ZipLoRA-pytorch
This is an implementation of [ZipLoRA: Any Subject in Any Style by Effectively Merging LoRAs](https://ziplora.github.io/) by [mkshing](https://twitter.com/mk1stats).
A summary of the paper by one of its authors is available [here](https://twitter.com/natanielruizg/status/1727718489425616912).
![result](assets/result.png)
## Installation
```
git clone git@github.com:mkshing/ziplora-pytorch.git
cd ziplora-pytorch
pip install -r requirements.txt
```
## Usage
### 1. Train LoRAs for subject/style images
In this step, two LoRAs are trained on SDXL: one for the subject images and one for the style images. Using SDXL matters here because the paper's authors found that pre-trained SDXL exhibits strong style learning when fine-tuned on only one reference style image.
Fortunately, diffusers already provides LoRA training for SDXL [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md), so you can simply follow its instructions.
For example, your training script could look like this. Run it once with the subject settings and once with the style settings.
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
# for subject
export OUTPUT_DIR="lora-sdxl-dog"
export INSTANCE_DIR="dog"
export PROMPT="a sbu dog"
export VALID_PROMPT="a sbu dog in a bucket"
# for style
# export OUTPUT_DIR="lora-sdxl-waterpainting"
# export INSTANCE_DIR="waterpainting"
# export PROMPT="a cat of in szn style"
# export VALID_PROMPT="a man in szn style"
accelerate launch train_dreambooth_lora_sdxl.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="${PROMPT}" \
--rank=64 \
--resolution=1024 \
--train_batch_size=1 \
--learning_rate=5e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--validation_prompt="${VALID_PROMPT}" \
--validation_epochs=50 \
--seed="0" \
--mixed_precision="fp16" \
--enable_xformers_memory_efficient_attention \
--gradient_checkpointing \
--use_8bit_adam \
--push_to_hub
```
* In the above script, hyperparameters such as `--max_train_steps` and `--rank` follow the paper. You can, of course, tweak them for your own images.
* You can find style images in [aim-uofa/StyleDrop-PyTorch](https://github.com/aim-uofa/StyleDrop-PyTorch/tree/main/data).
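Because the same script is run twice with different settings, it can help to generate both command lines from one place. The following is a minimal sketch in plain Python, not part of this repository: `build_command`, `SHARED_ARGS`, `RUNS`, and `FLAGS` are hypothetical names, and the values are copied from the script above. The `--report_to` and `--push_to_hub` options are left out for brevity.

```python
import shlex

# Hyperparameters shared by both runs (values copied from the script above).
SHARED_ARGS = {
    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
    "rank": 64,
    "resolution": 1024,
    "train_batch_size": 1,
    "learning_rate": "5e-5",
    "lr_scheduler": "constant",
    "lr_warmup_steps": 0,
    "max_train_steps": 1000,
    "validation_epochs": 50,
    "seed": 0,
    "mixed_precision": "fp16",
}

# Per-run settings: one entry for the subject LoRA, one for the style LoRA.
RUNS = {
    "subject": {
        "output_dir": "lora-sdxl-dog",
        "instance_data_dir": "dog",
        "instance_prompt": "a sbu dog",
        "validation_prompt": "a sbu dog in a bucket",
    },
    "style": {
        "output_dir": "lora-sdxl-waterpainting",
        "instance_data_dir": "waterpainting",
        "instance_prompt": "a cat of in szn style",
        "validation_prompt": "a man in szn style",
    },
}

# Boolean switches used in the script above.
FLAGS = [
    "enable_xformers_memory_efficient_attention",
    "gradient_checkpointing",
    "use_8bit_adam",
]


def build_command(run_name):
    """Return the argument list for one training run (hypothetical helper)."""
    args = {**SHARED_ARGS, **RUNS[run_name]}
    cmd = ["accelerate", "launch", "train_dreambooth_lora_sdxl.py"]
    cmd += [f"--{key}={value}" for key, value in args.items()]
    cmd += [f"--{flag}" for flag in FLAGS]
    return cmd


if __name__ == "__main__":
    for name in RUNS:
        # shlex.join quotes arguments that contain spaces, such as the prompts.
        print(shlex.join(build_command(name)))
```

The printed lines can be pasted into a shell, or the list returned by `build_command` can be passed to `subprocess.run` directly.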
### 2. Train ZipLoRA
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
# for subject
export LORA_PATH="mkshing/lora-sdxl-dog"
export INSTANCE_DIR="dog"
export PROMPT="a sbu dog"
# for style
export LORA_PATH2="mkshing/lora-sdxl-waterpainting"
export INSTANCE_DIR2="waterpainting"
export PROMPT2="a cat of in szn style"
# general
export OUTPUT_DIR="ziplora-sdxl-dog-waterpainting"
export VALID_PROMPT="a sbu dog in szn style"
accelerate launch train_dreambooth_ziplora_sdxl.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--lora_name_or_path=$LORA_PATH \
--instance_prompt="${PROMPT}" \
--instance_data_dir=$INSTANCE_DIR \
--lora_name_or_path_2=$LORA_PATH2 \
--instance_prompt_2="${PROMPT2}" \
--instance_data_dir_2=$INSTANCE_DIR2 \
--resolution=1024 \
--train_batch_size=1 \
--learning_rate=5e-5 \
--similarity_lambda=0.01 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=100 \
--validation_prompt="${VALID_PROMPT}" \
--validation_epochs=10 \
--seed="0" \
--mixed_precision="fp16" \
--report_to="wandb" \
--gradient_checkpointing \
--use_8bit_adam \
--enable_xformers_memory_efficient_attention
```
* If you're facing VRAM limitations during training, use the `--quick_release` flag to help free up VRAM.
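To give some intuition for what this step optimizes, here is a toy sketch of the merge as described in the paper, not the code used by `train_dreambooth_ziplora_sdxl.py`. It assumes that each LoRA weight update is scaled column by column with a learnable merger vector before the two updates are added, and that `--similarity_lambda` weights a penalty on the cosine similarity between the two merger vectors. The function names are hypothetical, plain Python lists stand in for tensors, and the reconstruction terms that keep the merged model close to each original LoRA are omitted.

```python
import math


def merge_columns(delta_1, delta_2, merger_1, merger_2):
    """Scale column j of each update by its merger coefficient, then add them."""
    rows, cols = len(delta_1), len(delta_1[0])
    return [
        [merger_1[j] * delta_1[i][j] + merger_2[j] * delta_2[i][j] for j in range(cols)]
        for i in range(rows)
    ]


def cosine_similarity(u, v):
    """Cosine similarity between two non-zero vectors given as lists."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def similarity_penalty(merger_1, merger_2, similarity_lambda=0.01):
    """Penalty that is small when the two merger vectors favor different columns."""
    return similarity_lambda * abs(cosine_similarity(merger_1, merger_2))


# Toy 2x3 weight updates for a single layer (made-up numbers).
delta_subject = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
delta_style = [[0.0, 3.0, 1.0], [1.0, 0.0, 0.0]]

# Disjoint mergers: columns 0 and 2 come from the subject, column 1 from the style.
m_subject = [1.0, 0.0, 1.0]
m_style = [0.0, 1.0, 0.0]

merged = merge_columns(delta_subject, delta_style, m_subject, m_style)
print(merged)                                  # [[1.0, 3.0, 2.0], [0.0, 0.0, 0.0]]
print(similarity_penalty(m_subject, m_style))  # 0.0
```

With disjoint merger vectors the penalty is zero, while two identical merger vectors give the maximum penalty of `similarity_lambda`. Under these assumptions, a larger `--similarity_lambda` pushes the two LoRAs to use different columns of each layer.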
### 3. Inference
```python
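import torch
from diffusers import StableDiffusionXLPipeline
from ziplora_pytorch.utils import insert_ziplora_to_unet

# Example values; replace them with your own model, ZipLoRA weights, and prompt.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
ziplora_name_or_path = "ziplora-sdxl-dog-waterpainting"  # e.g. OUTPUT_DIR from step 2
prompt = "a sbu dog in szn style"

# Load the base SDXL pipeline and insert the trained ZipLoRA weights into its UNet.
pipeline = StableDiffusionXLPipeline.from_pretrained(pretrained_model_name_or_path)
pipeline.unet = insert_ziplora_to_unet(pipeline.unet, ziplora_name_or_path)
pipeline.to(device="cuda", dtype=torch.float16)
image = pipeline(prompt=prompt).images[0]
image.save("out.png")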
```
You can also try your ZipLoRA interactively through a Gradio demo.
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export ZIPLORA_PATH="..."
python inference.py --pretrained_model_name_or_path=$MODEL_NAME --ziplora_name_or_path=$ZIPLORA_PATH
```
## TODO
- [x] Super quick instructions for training each LoRA
- [x] ZipLoRA (training)
- [x] ZipLoRA (inference)
- [ ] Pre-optimization lora weights