# BrushNet

This repository contains the official implementation of the ECCV 2024 paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion".

Keywords: Image Inpainting, Diffusion Models, Image Generation

> [Xuan Ju](https://github.com/juxuan27)<sup>12</sup>, [Xian Liu](https://alvinliu0.github.io/)<sup>12</sup>, [Xintao Wang](https://xinntao.github.io/)<sup>1*</sup>, [Yuxuan Bian](https://scholar.google.com.hk/citations?user=HzemVzoAAAAJ&hl=zh-CN&oi=ao)<sup>2</sup>, [Ying Shan](https://www.linkedin.com/in/YingShanProfile/)<sup>1</sup>, [Qiang Xu](https://cure-lab.github.io/)<sup>2*</sup><br>
> <sup>1</sup>ARC Lab, Tencent PCG <sup>2</sup>The Chinese University of Hong Kong <sup>*</sup>Corresponding Author

<p align="center">
  <a href="https://tencentarc.github.io/BrushNet/">🌐Project Page</a> |
  <a href="https://arxiv.org/abs/2403.06976">πŸ“œArxiv</a> |
  <a href="https://forms.gle/9TgMZ8tm49UYsZ9s5">πŸ—„οΈData</a> |
  <a href="https://drive.google.com/file/d/1IkEBWcd2Fui2WHcckap4QFPcCI0gkHBh/view">πŸ“ΉVideo</a> |
  <a href="https://huggingface.co/spaces/TencentARC/BrushNet">πŸ€—Hugging Face Demo</a>
</p>

**πŸ“– Table of Contents**

- [BrushNet](#brushnet)
  - [πŸ”₯ Update Log](#-update-log)
  - [TODO](#todo)
  - [πŸ› οΈ Method Overview](#️-method-overview)
  - [πŸš€ Getting Started](#-getting-started)
    - [Environment Requirement 🌍](#environment-requirement-)
    - [Data Download ⬇️](#data-download-️)
  - [πŸƒπŸΌ Running Scripts](#-running-scripts)
    - [Training 🀯](#training-)
    - [Inference πŸ“œ](#inference-)
    - [Evaluation πŸ“](#evaluation-)
  - [🀝🏼 Cite Us](#-cite-us)
  - [πŸ’– Acknowledgement](#-acknowledgement)

## πŸ”₯ Update Log

- [2024/12/17] πŸ“’ πŸ“’ [BrushEdit](https://github.com/TencentARC/BrushEdit) is released: an efficient, white-box, free-form image editing tool powered by LLM agents and an all-in-one inpainting model.
- [2024/12/17] πŸ“’ πŸ“’ [BrushNetX](https://huggingface.co/TencentARC/BrushEdit/tree/main/brushnetX) (a stronger BrushNet) model is released.

## TODO

- [x] Release training and inference code
- [x] Release checkpoint (sdv1.5)
- [x] Release checkpoint (sdxl). Sadly, we only have V100 GPUs for training this checkpoint, which allows only a batch size of 1 at a slow speed. The current ckpt was trained for a small number of steps and thus does not perform well. Fortunately, [yuanhang](https://github.com/yuanhangio) has volunteered to help train a better version. Please stay tuned! Thanks to [yuanhang](https://github.com/yuanhangio) for his effort!
- [x] Release evaluation code
- [x] Release gradio demo
- [x] Release comfyui demo. Thanks to [nullquant](https://github.com/nullquant) ([ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)) and [kijai](https://github.com/kijai) ([ComfyUI-BrushNet-Wrapper](https://github.com/kijai/ComfyUI-BrushNet-Wrapper)) for helping!
- [x] Release [training data](https://huggingface.co/datasets/random123123/BrushData). Thanks to [random123123](https://huggingface.co/random123123) for helping!
- [x] We used BrushNet to participate in the CVPR 2024 GenAI Media Generation Challenge Workshop and won the top prize! The solution is provided in [InstructionGuidedEditing](InstructionGuidedEditing).
- [x] Release a new version of checkpoint (sdxl).

## πŸ› οΈ Method Overview

BrushNet is a diffusion-based text-guided image inpainting model that can be plugged into any pre-trained diffusion model. Our architectural design incorporates two key insights: (1) separating the masked image features from the noisy latent reduces the model's learning load, and (2) dense per-pixel control over the entire pre-trained model enhances its suitability for image inpainting tasks. More analysis can be found in the main paper.

![](examples/brushnet/src/model.png)
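For intuition, here is a conceptual PyTorch sketch of the decomposed dual-branch idea. It is an illustration only, not the repo's actual `BrushNetModel`: the module name, channel sizes, and toy convolution blocks below are invented for this example. The point it shows is that a frozen, pre-trained UNet keeps handling the noisy latent, while an additional branch receives the masked image latent and the mask and injects dense, per-layer residual features into the UNet through zero-initialized projections.

```
import torch
import torch.nn as nn

class ToyBrushNetBranch(nn.Module):
    """Conceptual stand-in for the additional inpainting branch (not the real model)."""

    def __init__(self, latent_channels=4, hidden=32, num_levels=3):
        super().__init__()
        # Branch input: noisy latent (4) + masked image latent (4) + mask (1).
        self.stem = nn.Conv2d(latent_channels * 2 + 1, hidden, 3, padding=1)
        self.blocks = nn.ModuleList(
            [nn.Conv2d(hidden, hidden, 3, padding=1) for _ in range(num_levels)]
        )
        # Zero-initialized 1x1 convs so the branch starts as a no-op and does not
        # disturb the frozen pre-trained UNet at the beginning of training.
        self.zero_convs = nn.ModuleList(
            [nn.Conv2d(hidden, hidden, 1) for _ in range(num_levels)]
        )
        for z in self.zero_convs:
            nn.init.zeros_(z.weight)
            nn.init.zeros_(z.bias)

    def forward(self, noisy_latent, masked_image_latent, mask):
        h = self.stem(torch.cat([noisy_latent, masked_image_latent, mask], dim=1))
        residuals = []
        for block, zero in zip(self.blocks, self.zero_convs):
            h = torch.relu(block(h))
            residuals.append(zero(h))  # per-layer features added to the frozen UNet
        return residuals

# The frozen diffusion UNet would consume these residuals level by level,
# e.g. unet_feature_i = unet_feature_i + residuals[i].
branch = ToyBrushNetBranch()
res = branch(torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64), torch.ones(1, 1, 64, 64))
print([r.shape for r in res])
```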
## πŸš€ Getting Started

### Environment Requirement 🌍

BrushNet has been implemented and tested with PyTorch 1.12.1 and Python 3.9.

Clone the repo:

```
git clone https://github.com/TencentARC/BrushNet.git
```

We recommend first using `conda` to create a virtual environment and installing `pytorch` following the [official instructions](https://pytorch.org/). For example:

```
conda create -n diffusers python=3.9 -y
conda activate diffusers
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
```

Then, you can install diffusers (implemented in this repo) with:

```
pip install -e .
```

After that, you can install the required packages through:

```
cd examples/brushnet/
pip install -r requirements.txt
```

### Data Download ⬇️

**Dataset**

You can download BrushData and BrushBench [here](https://forms.gle/9TgMZ8tm49UYsZ9s5) (as well as the EditBench we re-processed), which are used for training and testing BrushNet. By downloading the data, you agree to the terms and conditions of the license.

The data structure should be like:

```
|-- data
    |-- BrushData
        |-- 00200.tar
        |-- 00201.tar
        |-- ...
    |-- BrushBench
        |-- images
        |-- mapping_file.json
    |-- EditBench
        |-- images
        |-- mapping_file.json
```

Note: *We only provide part of BrushData on Google Drive due to the space limit. [random123123](https://huggingface.co/random123123) has uploaded the full dataset on Hugging Face [here](https://huggingface.co/datasets/random123123/BrushData). Thanks for his help!*

**Checkpoints**

Checkpoints of BrushNet can be downloaded from [here](https://drive.google.com/drive/folders/1fqmS1CEOvXCxNWFrsSYd_jHYXxrydh1n?usp=drive_link). The ckpt folder contains

- BrushNet pretrained checkpoints for Stable Diffusion v1.5 (`segmentation_mask_brushnet_ckpt` and `random_mask_brushnet_ckpt`)
- a pretrained Stable Diffusion v1.5 checkpoint (e.g., realisticVisionV60B1_v51VAE from [Civitai](https://civitai.com/)). You can use `scripts/convert_original_stable_diffusion_to_diffusers.py` to process other models downloaded from Civitai.
- BrushNet pretrained checkpoints for Stable Diffusion XL (`segmentation_mask_brushnet_ckpt_sdxl_v1` and `random_mask_brushnet_ckpt_sdxl_v0`). A better version will be released shortly by [yuanhang](https://github.com/yuanhangio). Please stay tuned!
- a pretrained Stable Diffusion XL checkpoint (e.g., juggernautXL_juggernautX from [Civitai](https://civitai.com/)). You can use `StableDiffusionXLPipeline.from_single_file("path of safetensors").save_pretrained("path to save",safe_serialization=False)` to process other models downloaded from Civitai (see the Python sketch after the directory tree below).

The data structure should be like:

```
|-- data
    |-- BrushData
    |-- BrushBench
    |-- EditBench
    |-- ckpt
        |-- realisticVisionV60B1_v51VAE
            |-- model_index.json
            |-- vae
            |-- ...
        |-- segmentation_mask_brushnet_ckpt
        |-- segmentation_mask_brushnet_ckpt_sdxl_v0
        |-- random_mask_brushnet_ckpt
        |-- random_mask_brushnet_ckpt_sdxl_v0
        |-- ...
```
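As mentioned in the checkpoint list above, an SDXL model downloaded from Civitai as a single `.safetensors` file can be converted to the diffusers layout with `from_single_file` and `save_pretrained`, which are standard diffusers APIs. A minimal sketch, where the input and output paths are placeholders you should replace with your own:

```
from diffusers import StableDiffusionXLPipeline

# Placeholder paths: point these at the downloaded .safetensors file and at
# the directory where the diffusers-format checkpoint should be written.
pipe = StableDiffusionXLPipeline.from_single_file("path/to/juggernautXL.safetensors")
pipe.save_pretrained("data/ckpt/juggernautXL_juggernautX", safe_serialization=False)
```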
The checkpoints in `segmentation_mask_brushnet_ckpt` and `segmentation_mask_brushnet_ckpt_sdxl_v0` were trained on BrushData with a segmentation prior (masks share the shapes of objects). `random_mask_brushnet_ckpt` and `random_mask_brushnet_ckpt_sdxl` provide more general checkpoints for random mask shapes.

## πŸƒπŸΌ Running Scripts

### Training 🀯

You can train with segmentation masks using the script:

```
# sd v1.5
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_segmentationmask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 --checkpointing_steps 10000

# sdxl
accelerate launch examples/brushnet/train_brushnet_sdxl.py \
--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
--output_dir runs/logs/brushnetsdxl_segmentationmask \
--train_data_dir data/BrushData \
--resolution 1024 \
--learning_rate 1e-5 \
--train_batch_size 1 \
--gradient_accumulation_steps 4 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--checkpointing_steps 10000
```

To use a custom dataset, process your own data into the format of BrushData and revise `--train_data_dir`.

You can train with random masks using the script (by adding `--random_mask`):

```
# sd v1.5
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_randommask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--random_mask

# sdxl
accelerate launch examples/brushnet/train_brushnet_sdxl.py \
--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
--output_dir runs/logs/brushnetsdxl_randommask \
--train_data_dir data/BrushData \
--resolution 1024 \
--learning_rate 1e-5 \
--train_batch_size 1 \
--gradient_accumulation_steps 4 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--checkpointing_steps 10000 \
--random_mask
```

### Inference πŸ“œ

You can run inference with the script:

```
# sd v1.5
python examples/brushnet/test_brushnet.py

# sdxl
python examples/brushnet/test_brushnet_sdxl.py
```

Since BrushNet is trained on LAION, it can only guarantee performance on general scenarios. We recommend training on your own data (e.g., product exhibition, virtual try-on) if you have high-quality industrial application requirements. We would also appreciate it if you contribute your trained models! A minimal Python sketch of the underlying pipeline call is shown at the end of this subsection.

You can also run inference through the gradio demo:

```
# sd v1.5
python examples/brushnet/app_brushnet.py
```
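For reference, here is a minimal sketch of what `examples/brushnet/test_brushnet.py` does. The class names `BrushNetModel` and `StableDiffusionBrushNetPipeline` come from the diffusers implementation in this repo, but the image paths, prompt, and call arguments below are assumptions; check the script for the authoritative version and for the exact mask preprocessing.

```
import torch
from PIL import Image
from diffusers import BrushNetModel, StableDiffusionBrushNetPipeline

# Placeholder paths: adjust to your local checkpoint layout.
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
)
pipe.to("cuda")

# The test script blanks out the masked region of the input image before passing
# it to the pipeline; here we simply load a pre-masked image and its mask.
init_image = Image.open("path/to/masked_image.png").convert("RGB")
mask_image = Image.open("path/to/mask.png").convert("RGB")

result = pipe(
    "a beautiful scene",  # text prompt (example)
    init_image,
    mask_image,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(1234),
).images[0]
result.save("output.png")
```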
### Evaluation πŸ“

You can evaluate using the script:

```
python examples/brushnet/evaluate_brushnet.py \
--brushnet_ckpt_path data/ckpt/segmentation_mask_brushnet_ckpt \
--image_save_path runs/evaluation_result/BrushBench/brushnet_segmask/inside \
--mapping_file data/BrushBench/mapping_file.json \
--base_dir data/BrushBench \
--mask_key inpainting_mask
```

`--mask_key` indicates which kind of mask to use: `inpainting_mask` for inside inpainting and `outpainting_mask` for outside inpainting. The evaluation results (images and metrics) will be saved in `--image_save_path`.

*Note that you need to disable the NSFW detector in `src/diffusers/pipelines/brushnet/pipeline_brushnet.py#1261` to get correct evaluation results. Moreover, we find that different machines may generate different images, so we provide the results generated on our machine [here](https://drive.google.com/drive/folders/1dK3oIB2UvswlTtnIS1iHfx4s57MevWdZ?usp=sharing).*

## 🀝🏼 Cite Us

```
@misc{ju2024brushnet,
  title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion},
  author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
  year={2024},
  eprint={2403.06976},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

## πŸ’– Acknowledgement

<span id="acknowledgement"></span>

Our code is modified based on [diffusers](https://github.com/huggingface/diffusers). Thanks to all the contributors!