# [ControlNet++: Improving Conditional Controls with Efficient Consistency Feedback (ECCV 2024)](https://liming-ai.github.io/ControlNet_Plus_Plus/)
Official PyTorch implementation of the ECCV 2024 paper *ControlNet++: Improving Conditional Controls with Efficient Consistency Feedback*.
<div align="center">
[![arXiv](https://img.shields.io/badge/arXiv-2404.07987-b31b1b.svg)](https://arxiv.org/abs/2404.07987)
[![huggingface demo](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-ControlNet++-yellow)](https://huggingface.co/spaces/limingcv/ControlNet-Plus-Plus)
</div>
## 🕹️ Environments
```bash
git clone https://github.com/liming-ai/ControlNet_Plus_Plus.git
cd ControlNet_Plus_Plus
pip3 install -r requirements.txt
pip3 install -U openmim
mim install mmengine
mim install "mmcv==2.1.0"
pip3 install "mmsegmentation>=1.0.0"
pip3 install mmdet
pip3 install clean-fid
pip3 install torchmetrics
```
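After installation, you can optionally verify that the OpenMMLab stack and PyTorch import correctly. This is only a quick sanity check, not a step from the original instructions:
```python
# Optional sanity check: confirm the OpenMMLab packages import and CUDA is visible.
import torch
import mmcv, mmengine, mmseg, mmdet

print("mmcv:", mmcv.__version__)          # should print 2.1.0
print("mmengine:", mmengine.__version__)
print("mmseg:", mmseg.__version__, "| mmdet:", mmdet.__version__)
print("CUDA available:", torch.cuda.is_available())
```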
## 🕹️ Data Preparation
**All the organized data has been put on Huggingface and will be automatically downloaded during training or evaluation.** You can preview it in advance via the following links to check the data samples and the disk space they occupy (see the preview sketch after the table).
| Task | Training Data 🤗 | Evaluation Data 🤗 |
|:----------:|:------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------|
| LineArt, Hed, Canny | [Data](https://huggingface.co/datasets/limingcv/MultiGen-20M_train), 1.14 TB | [Data](https://huggingface.co/datasets/limingcv/MultiGen-20M_canny_eval), 2.25GB |
| Depth | [Data](https://huggingface.co/datasets/limingcv/MultiGen-20M_depth), 1.22 TB | [Data](https://huggingface.co/datasets/limingcv/MultiGen-20M_depth_eval), 2.17GB |
| Segmentation ADE20K | [Data](https://huggingface.co/datasets/limingcv/Captioned_ADE20K), 7.04 GB | Same Path as Training Data |
| Segmentation COCOStuff | [Data](https://huggingface.co/datasets/limingcv/Captioned_COCOStuff), 61.9 GB | Same Path as Training Data |
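If you only want to inspect a few samples before committing the disk space, a minimal sketch using the 🤗 `datasets` library (which the automatic downloads rely on) looks like the following; the split and field names are whatever the dataset cards expose, so treat them as assumptions:
```python
# Minimal preview sketch (not part of the repo): stream a couple of samples from the
# Canny evaluation set without downloading the full dataset to disk.
from datasets import load_dataset

ds = load_dataset("limingcv/MultiGen-20M_canny_eval", streaming=True)
split_name = next(iter(ds))             # first split exposed by the dataset card
sample = next(iter(ds[split_name]))     # first example of that split
print(split_name, list(sample.keys()))  # inspect the available fields
```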
## 🕹️ Training
By default, our training uses 8 A100-80G GPUs. If your computational resources are insufficient, you can reduce the batch size and increase the gradient accumulation steps by the same factor; we have not observed any performance degradation from doing so. Reducing the training resolution, however, does degrade performance.
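As a rough illustration (the numbers below are illustrative, not the repo's defaults), the effective batch size is the product of the GPU count, the per-GPU batch size, and the gradient accumulation steps, so scaling the last two inversely keeps it unchanged:
```python
# Illustrative only: keeping the effective batch size fixed when per-GPU memory is limited.
num_gpus = 8
per_gpu_batch_size = 4            # halve this if you hit out-of-memory errors...
gradient_accumulation_steps = 2   # ...and double this to compensate

effective_batch_size = num_gpus * per_gpu_batch_size * gradient_accumulation_steps
print(effective_batch_size)       # 64 in this illustrative configuration
```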
### For segmentation task
[ControlNet V1.1 Seg](https://github.com/lllyasviel/ControlNet-v1-1-nightly/blob/main/README.md#controlnet-11-segmentation) is trained on both ADE20K and COCOStuff, and these two datasets use different segmentation masks (label sets). To this end, we first perform standard fine-tuning on each dataset separately, and then perform reward fine-tuning.
```bash
# Please refer to the reward script for details
bash train/reward_ade20k.sh
bash train/reward_cocostuff.sh
```
### For other tasks
We can directly perform reward fine-tuning.
```bash
bash train/reward_canny.sh
bash train/reward_depth.sh
bash train/reward_hed.sh
bash train/reward_linedrawing.sh
```
### Core Code
Please refer to the [core code here](https://github.com/liming-ai/ControlNet_Plus_Plus/blob/9167f0d85ccc5ad1eb9a83f3e7fa8d3422d5d9d5/train/reward_control.py#L1429). In summary:
#### Step 1: Predict the single-step denoised RGB image with the noise scheduler:
```python
# Predict the single-step denoised latents
pred_original_sample = [
    noise_scheduler.step(noise, t, noisy_latent).pred_original_sample.to(weight_dtype)
    for (noise, t, noisy_latent) in zip(model_pred, timesteps, noisy_latents)
]
pred_original_sample = torch.stack(pred_original_sample)
# Map the denoised latents into RGB images
pred_original_sample = 1 / vae.config.scaling_factor * pred_original_sample
image = vae.decode(pred_original_sample.to(weight_dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
```
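For reference, assuming the scheduler uses the standard ε-prediction parameterization of Stable Diffusion, `pred_original_sample` is the closed-form single-step estimate

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}},$$

so no iterative sampling loop is needed before decoding the latents and feeding the image to the reward model.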
#### Step 2: Normalize the single-step denoised images according to different reward models
```python
# The normalization depends on the reward model.
if args.task_name == 'depth':
    # The depth reward model expects 384x384 inputs normalized to [-1, 1]
    image = torchvision.transforms.functional.resize(image, (384, 384))
    image = normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
elif args.task_name in ['canny', 'lineart', 'hed']:
    # Edge-based reward models consume the [0, 1] RGB images directly
    pass
else:
    # Remaining reward models (e.g., segmentation) use ImageNet mean/std normalization
    image = normalize(image, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
```
#### Step 3: Apply both diffusion training loss and reward loss:
```python
# reward model inference
if args.task_name == 'canny':
    outputs = reward_model(image.to(accelerator.device), low_threshold, high_threshold)
else:
    outputs = reward_model(image.to(accelerator.device))
# Determine which samples in the current batch need to calculate reward loss
timestep_mask = (args.min_timestep_rewarding <= timesteps.reshape(-1, 1)) & (timesteps.reshape(-1, 1) <= args.max_timestep_rewarding)
# Calculate reward loss
reward_loss = get_reward_loss(outputs, labels, args.task_name, reduction='none')
# Calculate final loss
reward_loss = reward_loss.reshape_as(timestep_mask)
reward_loss = (timestep_mask * reward_loss).sum() / (timestep_mask.sum() + 1e-10)
loss = pretrain_loss + reward_loss * args.grad_scale
```
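Putting the pieces together, the loss computed above is

$$\mathcal{L} = \mathcal{L}_{\text{diffusion}} + \lambda \cdot \frac{\sum_i m_i\, \mathcal{L}_{\text{reward}}^{(i)}}{\sum_i m_i + \varepsilon},$$

where $m_i \in \{0, 1\}$ indicates whether the sampled timestep of example $i$ lies in `[min_timestep_rewarding, max_timestep_rewarding]`, and $\lambda$ corresponds to `args.grad_scale`. Intuitively, the reward (consistency) loss is only applied at timesteps where the single-step denoised image is reliable enough to be judged by the reward model.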
## 🕹️ Evaluation
### Checkpoints Preparation
Please download the model weights and put them into the corresponding subfolders of `checkpoints`; a sketch of the expected layout follows the table:
| model |HF weights🤗 |
|:----------:|:------------------------------------------------------------------------------------|
| LineArt | [model](https://huggingface.co/limingcv/reward_controlnet/tree/main/checkpoints/lineart) |
| Depth | [model](https://huggingface.co/limingcv/reward_controlnet/tree/main/checkpoints/depth) |
| Hed (SoftEdge) | [model](https://huggingface.co/limingcv/reward_controlnet/tree/main/checkpoints/hed) |
| Canny | [model](https://huggingface.co/limingcv/reward_controlnet/tree/main/checkpoints/canny) |
| Segmentation (ADE20K) | [UperNet-R50](https://huggingface.co/limingcv/reward_controlnet/tree/main/checkpoints/ade20k_reward-model-UperNet-R50/checkpoint-5000/controlnet), [FCN-R101](https://huggingface.co/limingcv/reward_controlnet/tree/main/checkpoints/ade20k_reward-model-FCN-R101-d8/checkpoint-5000/controlnet) |
| Segmentation (COCOStuff) | [model](https://huggingface.co/limingcv/reward_controlnet/tree/main/checkpoints/cocostuff/reward_5k) |
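For reference, a layout that mirrors the Hugging Face paths above would look roughly like the sketch below. This is an assumption, so double-check it against the paths expected by the `eval/*.sh` scripts:
```
checkpoints/
├── lineart/
├── depth/
├── hed/
├── canny/
├── ade20k_reward-model-UperNet-R50/checkpoint-5000/controlnet/
├── ade20k_reward-model-FCN-R101-d8/checkpoint-5000/controlnet/
└── cocostuff/reward_5k/
```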
### Evaluate Controllability
Please make sure the folder structure is consistent with the test scripts; then you can evaluate each model with:
```bash
bash eval/eval_ade20k.sh
bash eval/eval_cocostuff.sh
bash eval/eval_canny.sh
bash eval/eval_depth.sh
bash eval/eval_hed.sh
bash eval/eval_linedrawing.sh
```
*The segmentation mIoU results of ControlNet and ControlNet++ in the arXiv v1 version of the paper were tested using images and labels saved in `.jpg` format, which introduced errors. We retested and now report results using images and labels saved in `.png` format; please refer to our latest arXiv version and the ECCV camera-ready release. Other comparison methods (Gligen/T2I-Adapter/UniControl/UniControlNet) and other evaluation metrics (FID/CLIP-Score) were not affected by this issue.*
### Evaluate CLIP-Score and FID
Please refer to the code for evaluating [CLIP-Score](eval/eval_clip.py) and [FID](eval/eval_fid.py).
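Those scripts are the reference implementation. As a rough, self-contained sketch of the same two metrics using the `clean-fid` and `torchmetrics` packages installed above (the paths and CLIP backbone below are placeholders, not the paper's exact protocol):
```python
# Hedged sketch of FID / CLIP-Score evaluation; see eval/eval_fid.py and eval/eval_clip.py
# for the exact protocol used in the paper.
import torch
from cleanfid import fid
from torchmetrics.multimodal.clip_score import CLIPScore

# FID between a folder of generated images and a folder of reference images
fid_value = fid.compute_fid("outputs/generated", "data/reference")  # placeholder paths
print("FID:", fid_value)

# CLIP-Score between generated images (uint8, NCHW, values in [0, 255]) and their prompts
clip_metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
images = torch.randint(0, 255, (4, 3, 512, 512), dtype=torch.uint8)  # stand-in batch
prompts = ["a photo of a cat"] * 4
print("CLIP-Score:", clip_metric(images, prompts).item())
```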
## 🕹️ Inference
Please refer to the [Inference Branch](https://github.com/liming-ai/ControlNet_Plus_Plus/tree/inference) or try our [online Huggingface demo](https://huggingface.co/spaces/limingcv/ControlNet-Plus-Plus).
## License
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE.txt) file for details.
## 🙏 Acknowledgements
We sincerely thank the [Huggingface](https://huggingface.co), [ControlNet](https://github.com/lllyasviel/ControlNet), [OpenMMLab](https://github.com/open-mmlab) and [ImageReward](https://github.com/THUDM/ImageReward) communities for their open source code and contributions. Our project would not be possible without these amazing works.
## Citation
If our work assists your research, feel free to give us a star ⭐ or cite us using:
```
@inproceedings{controlnet_plus_plus,
    author    = {Ming Li and Taojiannan Yang and Huafeng Kuang and Jie Wu and Zhaoning Wang and Xuefeng Xiao and Chen Chen},
    title     = {ControlNet++: Improving Conditional Controls with Efficient Consistency Feedback},
    booktitle = {European Conference on Computer Vision (ECCV)},
    year      = {2024},
}
```
", Assign "at most 3 tags" to the expected json: {"id":"9531","tags":[]} "only from the tags list I provide: [{"id":77,"name":"3d"},{"id":89,"name":"agent"},{"id":17,"name":"ai"},{"id":54,"name":"algorithm"},{"id":24,"name":"api"},{"id":44,"name":"authentication"},{"id":3,"name":"aws"},{"id":27,"name":"backend"},{"id":60,"name":"benchmark"},{"id":72,"name":"best-practices"},{"id":39,"name":"bitcoin"},{"id":37,"name":"blockchain"},{"id":1,"name":"blog"},{"id":45,"name":"bundler"},{"id":58,"name":"cache"},{"id":21,"name":"chat"},{"id":49,"name":"cicd"},{"id":4,"name":"cli"},{"id":64,"name":"cloud-native"},{"id":48,"name":"cms"},{"id":61,"name":"compiler"},{"id":68,"name":"containerization"},{"id":92,"name":"crm"},{"id":34,"name":"data"},{"id":47,"name":"database"},{"id":8,"name":"declarative-gui "},{"id":9,"name":"deploy-tool"},{"id":53,"name":"desktop-app"},{"id":6,"name":"dev-exp-lib"},{"id":59,"name":"dev-tool"},{"id":13,"name":"ecommerce"},{"id":26,"name":"editor"},{"id":66,"name":"emulator"},{"id":62,"name":"filesystem"},{"id":80,"name":"finance"},{"id":15,"name":"firmware"},{"id":73,"name":"for-fun"},{"id":2,"name":"framework"},{"id":11,"name":"frontend"},{"id":22,"name":"game"},{"id":81,"name":"game-engine "},{"id":23,"name":"graphql"},{"id":84,"name":"gui"},{"id":91,"name":"http"},{"id":5,"name":"http-client"},{"id":51,"name":"iac"},{"id":30,"name":"ide"},{"id":78,"name":"iot"},{"id":40,"name":"json"},{"id":83,"name":"julian"},{"id":38,"name":"k8s"},{"id":31,"name":"language"},{"id":10,"name":"learning-resource"},{"id":33,"name":"lib"},{"id":41,"name":"linter"},{"id":28,"name":"lms"},{"id":16,"name":"logging"},{"id":76,"name":"low-code"},{"id":90,"name":"message-queue"},{"id":42,"name":"mobile-app"},{"id":18,"name":"monitoring"},{"id":36,"name":"networking"},{"id":7,"name":"node-version"},{"id":55,"name":"nosql"},{"id":57,"name":"observability"},{"id":46,"name":"orm"},{"id":52,"name":"os"},{"id":14,"name":"parser"},{"id":74,"name":"react"},{"id":82,"name":"real-time"},{"id":56,"name":"robot"},{"id":65,"name":"runtime"},{"id":32,"name":"sdk"},{"id":71,"name":"search"},{"id":63,"name":"secrets"},{"id":25,"name":"security"},{"id":85,"name":"server"},{"id":86,"name":"serverless"},{"id":70,"name":"storage"},{"id":75,"name":"system-design"},{"id":79,"name":"terminal"},{"id":29,"name":"testing"},{"id":12,"name":"ui"},{"id":50,"name":"ux"},{"id":88,"name":"video"},{"id":20,"name":"web-app"},{"id":35,"name":"web-server"},{"id":43,"name":"webassembly"},{"id":69,"name":"workflow"},{"id":87,"name":"yaml"}]" returns me the "expected json"