# CRATE (Coding RAte reduction TransformEr)

This repository is the official PyTorch implementation of the papers:

- **White-Box Transformers via Sparse Rate Reduction** [**NeurIPS-2023**, [paper link](https://openreview.net/forum?id=THfl8hdVxH#)]. By [Yaodong Yu](https://yaodongyu.github.io) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com) (TTIC), [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Tianzhe Chu](https://tianzhechu.com/) (UC Berkeley), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Shengbang Tong](https://tsb0601.github.io/petertongsb/) (UC Berkeley), [Benjamin D. Haeffele](https://www.cis.jhu.edu/~haeffele/#about) (Johns Hopkins University), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley).
- **Emergence of Segmentation with Minimalistic White-Box Transformers** [**CPAL-2024**, [paper link](https://arxiv.org/abs/2308.16271)]. By [Yaodong Yu](https://yaodongyu.github.io)* (UC Berkeley), [Tianzhe Chu](https://tianzhechu.com/)* (UC Berkeley & ShanghaiTech U), [Shengbang Tong](https://tsb0601.github.io/petertongsb/) (UC Berkeley & NYU), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com) (TTIC), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley & HKU). 2023. (* equal contribution)
- **Masked Completion via Structured Diffusion with White-Box Transformers** [**ICLR-2024**, [paper link](https://arxiv.org/abs/2404.02446)]. By [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com) (TTIC), [Yaodong Yu](https://yaodongyu.github.io) (UC Berkeley), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley).

We have also released a longer, journal-length overview paper of this line of research, which contains a superset of the results presented above, as well as additional results in NLP and vision SSL:

- **White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?** [[paper link](https://arxiv.org/abs/2311.13110)]. By [Yaodong Yu](https://yaodongyu.github.io) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com) (TTIC), [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Tianzhe Chu](https://tianzhechu.com/) (UC Berkeley), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Shengbang Tong](https://tsb0601.github.io/petertongsb/) (UC Berkeley), [Hao Bai](https://www.jackgethome.com/) (UIUC), [Yuexiang Zhai](https://yx-s-z.github.io/) (UC Berkeley), [Benjamin D. Haeffele](https://www.cis.jhu.edu/~haeffele/#about) (Johns Hopkins University), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley).

# Table of Contents

* [CRATE (Coding RAte reduction TransformEr)](#crate-coding-rate-reduction-transformer)
* [Theoretical Background: What is CRATE?](#theoretical-background-what-is-crate)
  * [1. CRATE Architecture overview](#1-crate-architecture-overview)
  * [2. One layer/block of CRATE](#2-one-layerblock-of-crate)
  * [3. Per-layer optimization in CRATE](#3-per-layer-optimization-in-crate)
  * [4. Segmentation visualization of CRATE](#4-segmentation-visualization-of-crate)
* [Autoencoding](#autoencoding)
* [Implementation and Experiments](#implementation-and-experiments)
  * [Constructing a CRATE model](#constructing-a-crate-model)
  * [Pre-trained Checkpoints (ImageNet-1K)](#pre-trained-checkpoints-imagenet-1k)
  * [Training CRATE on ImageNet](#training-crate-on-imagenet)
  * [Finetuning pretrained / training randomly initialized CRATE on CIFAR10](#finetuning-pretrained--training-randomly-initialized-crate-on-cifar10)
  * [Demo: Emergent segmentation in CRATE](#demo-emergent-segmentation-in-crate)
  * [Constructing a CRATE autoencoding model](#constructing-a-crate-autoencoding-model)
  * [Pre-trained Checkpoints (ImageNet-1K)](#pre-trained-checkpoints-imagenet-1k-1)
  * [Training/Fine-Tuning CRATE-MAE](#trainingfine-tuning-crate-mae)
  * [Demo: Emergent segmentation in CRATE-MAE](#demo-emergent-segmentation-in-crate-mae)
* [Reference](#reference)

## Theoretical Background: What is CRATE?

CRATE (Coding RAte reduction TransformEr) is a white-box (mathematically interpretable) transformer architecture, in which each layer performs a single step of an alternating minimization algorithm to optimize the **sparse rate reduction objective**

<p align="center">
    <img src="figs/fig_objective.png" width="400">
</p>

where $R$ and $R^{c}$ are different _coding rates_ for the input representations w.r.t. different codebooks, and the $\ell^{0}$-norm promotes sparsity of the final token representations $\boldsymbol{Z} = f(\boldsymbol{X})$. The function $f$ is defined as

$$f = f^{L} \circ f^{L-1} \circ \cdots \circ f^{1} \circ f^{\mathrm{pre}},$$

where $f^{\mathrm{pre}}$ is the pre-processing mapping and $f^{\ell}$ is the $\ell$-th layer forward mapping, which transforms the token distribution so as to optimize the sparse rate reduction objective incrementally. More specifically, $f^{\ell}$ transforms the $\ell$-th layer token representations $\boldsymbol{Z}^{\ell}$ into $\boldsymbol{Z}^{\ell+1}$ via the $\texttt{MSSA}$ (Multi-Head Subspace Self-Attention) block and the $\texttt{ISTA}$ (Iterative Shrinkage-Thresholding Algorithm) block:

$$\boldsymbol{Z}^{\ell+1} = f^{\ell}(\boldsymbol{Z}^{\ell}) = \texttt{ISTA}(\boldsymbol{Z}^{\ell} + \texttt{MSSA}(\boldsymbol{Z}^{\ell})).$$
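To make the layer equation concrete, here is a minimal PyTorch sketch of one CRATE block. It is an illustration under simplifying assumptions, not the repository's implementation (see `model/crate.py` for the real blocks): layer normalization is omitted, and the ISTA step size `eta` and sparsity weight `lam` are illustrative placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MSSA(nn.Module):
    """Multi-Head Subspace Self-Attention (simplified sketch): unlike standard
    attention, queries, keys, and values share one projection U per head."""
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        self.heads, self.dim_head = heads, dim_head
        self.U = nn.Linear(dim, heads * dim_head, bias=False)    # shared q/k/v projection
        self.out = nn.Linear(heads * dim_head, dim, bias=False)  # aggregate heads

    def forward(self, z):  # z: (batch, tokens, dim)
        b, n, _ = z.shape
        u = self.U(z).view(b, n, self.heads, self.dim_head).transpose(1, 2)
        attn = F.softmax(u @ u.transpose(-2, -1) / self.dim_head ** 0.5, dim=-1)
        return self.out((attn @ u).transpose(1, 2).reshape(b, n, -1))

class ISTA(nn.Module):
    """One proximal-gradient (shrinkage-thresholding) step that sparsifies
    tokens against a learnable square dictionary D."""
    def __init__(self, dim, eta=0.1, lam=0.1):  # step size / sparsity weight (illustrative)
        super().__init__()
        self.D = nn.Parameter(torch.randn(dim, dim) / dim ** 0.5)
        self.eta, self.lam = eta, lam

    def forward(self, z):  # z: (batch, tokens, dim), tokens as rows
        # Gradient step on ||z - a D||^2 starting from a = z, followed by a
        # nonnegative soft-threshold (ReLU) that promotes sparsity.
        a = z + self.eta * (z - z @ self.D) @ self.D.T
        return F.relu(a - self.eta * self.lam)

class CRATELayer(nn.Module):
    """One CRATE block: Z^{l+1} = ISTA(Z^l + MSSA(Z^l))."""
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        self.mssa = MSSA(dim, heads, dim_head)
        self.ista = ISTA(dim)

    def forward(self, z):
        return self.ista(z + self.mssa(z))

layer = CRATELayer(dim=384, heads=6, dim_head=64)
tokens = torch.randn(2, 197, 384)  # e.g., 196 patch tokens + 1 class token
print(layer(tokens).shape)         # torch.Size([2, 197, 384])
```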
### 1. CRATE Architecture overview

The following figure presents an overview of the pipeline for our proposed **CRATE** architecture:

<p align="center">
    <img src="figs/fig_pipeline.png" width="900">
</p>

### 2. One layer/block of CRATE

The following figure shows the overall architecture of one layer of **CRATE** as the composition of the $\texttt{MSSA}$ and $\texttt{ISTA}$ blocks:

<p align="center">
    <img src="figs/fig_arch.png" width="900">
</p>

### 3. Per-layer optimization in CRATE

In the following figure, we measure the compression term $R^{c}(\boldsymbol{Z}^{\ell+1/2})$ and the sparsity term $\|\boldsymbol{Z}^{\ell+1}\|_0$ defined in the **sparse rate reduction objective**. We find that each layer of **CRATE** indeed optimizes its targeted objective, showing that our white-box theoretical design is predictive of practice.

<p align="center">
    <img src="figs/fig_layerwise.png" width="900">
</p>
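The per-layer measurements above can be approximated in a few lines of PyTorch. The sketch below is a hedged illustration, not the paper's evaluation code: it uses the generic coding rate $R(\boldsymbol{Z}) = \frac{1}{2}\log\det(\boldsymbol{I} + \frac{d}{n\epsilon^{2}}\boldsymbol{Z}^{\top}\boldsymbol{Z})$ from the rate-reduction literature as a stand-in for the subspace-conditional compression term $R^{c}$, counts near-zero entries for the sparsity term, and takes the list of transformer blocks as an explicit argument, since the module layout in `model/crate.py` may differ from any particular attribute path.

```python
import torch

def coding_rate(Z, eps=0.5):
    """R(Z) = 1/2 logdet(I + d/(n eps^2) Z^T Z) for a token matrix Z (n x d)."""
    n, d = Z.shape
    I = torch.eye(d, device=Z.device)
    return 0.5 * torch.logdet(I + (d / (n * eps ** 2)) * Z.T @ Z)

def sparsity(Z, tol=1e-6):
    """Fraction of entries of Z that are numerically nonzero."""
    return (Z.abs() > tol).float().mean()

@torch.no_grad()
def layerwise_stats(model, images, blocks):
    """Run one batch and collect (coding rate, sparsity) per block output.
    Assumes each block's forward output is the token tensor (batch, tokens, dim)."""
    outputs = []
    hooks = [blk.register_forward_hook(lambda m, i, o: outputs.append(o))
             for blk in blocks]
    model(images)
    for h in hooks:
        h.remove()
    return [(torch.stack([coding_rate(z) for z in Z]).mean().item(),
             sparsity(Z).item()) for Z in outputs]
```

For example, if the blocks live under a (hypothetical) `model.transformer.layers` attribute, `layerwise_stats(model, images, model.transformer.layers)` returns one (compression, sparsity) pair per layer, which can be plotted against layer index as in the figure above.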
### 4. Segmentation visualization of CRATE

In the following figure, we visualize the self-attention maps of a supervised **CRATE** model with 8x8 patches (similar to the ones shown in [DINO](https://github.com/facebookresearch/dino) :t-rex:).

<p align="center">
    <img src="figs/fig_seg.png" width="900">
</p>

We also discover a surprising empirical phenomenon: each attention head in **CRATE** retains its own semantics.

<p align="center">
    <img src="figs/fig_seg_headwise.png" width="900">
</p>

## Autoencoding

We can also use our theory to build a principled autoencoder, which has the following architecture:

<p align="center">
    <img src="figs/fig_arch_autoencoder.png" width="900">
</p>

It has many of the same empirical properties as the base **CRATE** model, such as segmented attention maps and amenability to layer-wise analysis. We train it on the masked autoencoding task (calling this model **CRATE-MAE**), and it achieves linear-probing and reconstruction performance comparable to the base ViT-MAE.

<p align="center">
    <img src="figs/fig_masked_reconstruction.png" width="900">
</p>

# Implementation and Experiments

## Constructing a CRATE model

A CRATE model can be defined using the following code (the parameters below are those of CRATE-Tiny):

```python
from model.crate import CRATE

dim = 384
n_heads = 6
depth = 12
model = CRATE(
    image_size=224,
    patch_size=16,
    num_classes=1000,
    dim=dim,
    depth=depth,
    heads=n_heads,
    dim_head=dim // n_heads,
)
```

### Pre-trained Checkpoints (ImageNet-1K)

| model | `dim` | `n_heads` | `depth` | pre-trained checkpoint |
| -------- | -------- | -------- | -------- | -------- |
| **CRATE-T**(iny) | 384 | 6 | 12 | TODO |
| **CRATE-S**(mall) | 576 | 12 | 12 | [download link](https://drive.google.com/file/d/1hYgDJl4EKHYfKprwhEjmWmWHuxnK6_h8/view?usp=share_link) |
| **CRATE-B**(ase) | 768 | 12 | 12 | TODO |
| **CRATE-L**(arge) | 1024 | 16 | 24 | TODO |

## Training CRATE on ImageNet

To train a CRATE model on ImageNet-1K, run a command like the following (this example trains CRATE-Tiny):

```bash
python main.py --arch CRATE_tiny --batch-size 512 --epochs 200 --optimizer Lion --lr 0.0002 --weight-decay 0.05 --print-freq 25 --data DATA_DIR
```

Replace `DATA_DIR` with `[imagenet-folder with train and val folders]`.

## Finetuning pretrained / training randomly initialized CRATE on CIFAR10

```bash
python finetune.py --bs 256 --net CRATE_tiny --opt adamW --lr 5e-5 --n_epochs 200 --randomaug 1 --data cifar10 --ckpt_dir CKPT_DIR --data_dir DATA_DIR
```

Replace `CKPT_DIR` with the path to the pretrained CRATE weights, and `DATA_DIR` with the path to the `CIFAR10` dataset. If `CKPT_DIR` is `None`, this script trains CRATE from random initialization on CIFAR10.

## Demo: Emergent segmentation in CRATE

CRATE models exhibit emergent segmentation in their self-attention maps through supervised training alone. We provide a Colab Jupyter notebook to visualize the segmentations that emerge from a supervised **CRATE** model; its visualizations match the segmentation figures above.

Link: [crate-emergence.ipynb](https://colab.research.google.com/drive/1rYn_NlepyW7Fu5LDliyBDmFZylHco7ss?usp=sharing) (in Colab)

<p align="center">
    <img src="figs/fig_seg_headwise.png" width="900">
</p>

## Constructing a CRATE autoencoding model

A CRATE autoencoding model (specifically **CRATE-MAE-Base**) can be defined using the following code:

```python
from model.crate_ae.crate_ae import mae_crate_base

model = mae_crate_base()
```

The other sizes in the paper are importable in the same way. Modify the `model/crate_ae/crate_ae.py` file to define and use your own configuration.
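To sanity-check a constructed CRATE-MAE model, one can run a masked forward pass. The sketch below assumes the model follows the interface of Meta's MAE reference code, which `model/crate_ae/crate_ae.py` is designed to be drop-in compatible with (see the training section below): a `forward(imgs, mask_ratio)` returning `(loss, pred, mask)` and an `unpatchify` helper. If the actual signatures differ, adapt accordingly.

```python
# Hedged smoke test for CRATE-MAE; the forward signature and `unpatchify`
# are assumptions carried over from Meta's MAE reference code
# (https://github.com/facebookresearch/mae).
import torch
from model.crate_ae.crate_ae import mae_crate_base

model = mae_crate_base()
model.eval()

imgs = torch.randn(2, 3, 224, 224)  # a dummy batch of ImageNet-sized images

with torch.no_grad():
    # MAE-style forward: mask 75% of patches, reconstruct them, and return
    # the reconstruction loss, per-patch predictions, and the binary mask.
    loss, pred, mask = model(imgs, mask_ratio=0.75)

print(loss.item())             # scalar reconstruction loss
print(pred.shape, mask.shape)  # (batch, num_patches, patch_dim), (batch, num_patches)

# Reassemble the predicted patches into images for visualization.
recon = model.unpatchify(pred)
print(recon.shape)             # (2, 3, 224, 224)
```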
### Pre-trained Checkpoints (ImageNet-1K)

| model | `dim` | `n_heads` | `depth` | pre-trained checkpoint |
| -------- | -------- | -------- | -------- | -------- |
| **CRATE-MAE-S**(mall) | 576 | 12 | 12 | TODO |
| **CRATE-MAE-B**(ase) | 768 | 12 | 12 | [download link](https://drive.google.com/file/d/11i5BMwymqOsunq44WD3omN5mS6ZREQPO/view?usp=sharing) |

## Training/Fine-Tuning CRATE-MAE

To train or fine-tune a CRATE-MAE model on ImageNet-1K, please refer to the [MAE training codebase](https://github.com/facebookresearch/mae) from Meta FAIR. The `models_mae.py` file in that codebase can be replaced with the contents of `model/crate_ae/crate_ae.py`, and the rest of the code should run with minimal alterations.

## Demo: Emergent segmentation in CRATE-MAE

CRATE-MAE models also exhibit emergent segmentation in their self-attention maps. We provide a Colab Jupyter notebook to visualize the segmentations that emerge from a **CRATE-MAE** model; its visualizations match the segmentation figures above.

Link: [crate-mae.ipynb](https://colab.research.google.com/drive/1xcD-xcxprfgZuvwsRKuDroH7xMjr0Ad3?usp=sharing) (in Colab)

# Reference

For technical details and full experimental results, please check the [CRATE paper](https://arxiv.org/abs/2306.01129), the [CRATE segmentation paper](https://arxiv.org/abs/2308.16271), the [CRATE autoencoding paper](https://openreview.net/forum?id=PvyOYleymy), or [the long-form overview paper](https://arxiv.org/abs/2311.13110). Please consider citing our work if you find it helpful to yours:

```
@article{yu2024white,
  title={White-Box Transformers via Sparse Rate Reduction},
  author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Haeffele, Benjamin and Ma, Yi},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}
```

```
@inproceedings{yu2024emergence,
  title={Emergence of Segmentation with Minimalistic White-Box Transformers},
  author={Yu, Yaodong and Chu, Tianzhe and Tong, Shengbang and Wu, Ziyang and Pai, Druv and Buchanan, Sam and Ma, Yi},
  booktitle={Conference on Parsimony and Learning},
  pages={72--93},
  year={2024},
  organization={PMLR}
}
```

```
@inproceedings{pai2024masked,
  title={Masked Completion via Structured Diffusion with White-Box Transformers},
  author={Pai, Druv and Buchanan, Sam and Wu, Ziyang and Yu, Yaodong and Ma, Yi},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}
```

```
@article{yu2023white,
  title={White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?},
  author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Bai, Hao and Zhai, Yuexiang and Haeffele, Benjamin D and Ma, Yi},
  journal={arXiv preprint arXiv:2311.13110},
  year={2023}
}
```
[{"id":77,"name":"3d"},{"id":89,"name":"agent"},{"id":17,"name":"ai"},{"id":54,"name":"algorithm"},{"id":24,"name":"api"},{"id":44,"name":"authentication"},{"id":3,"name":"aws"},{"id":27,"name":"backend"},{"id":60,"name":"benchmark"},{"id":72,"name":"best-practices"},{"id":39,"name":"bitcoin"},{"id":37,"name":"blockchain"},{"id":1,"name":"blog"},{"id":45,"name":"bundler"},{"id":58,"name":"cache"},{"id":21,"name":"chat"},{"id":49,"name":"cicd"},{"id":4,"name":"cli"},{"id":64,"name":"cloud-native"},{"id":48,"name":"cms"},{"id":61,"name":"compiler"},{"id":68,"name":"containerization"},{"id":92,"name":"crm"},{"id":34,"name":"data"},{"id":47,"name":"database"},{"id":8,"name":"declarative-gui "},{"id":9,"name":"deploy-tool"},{"id":53,"name":"desktop-app"},{"id":6,"name":"dev-exp-lib"},{"id":59,"name":"dev-tool"},{"id":13,"name":"ecommerce"},{"id":26,"name":"editor"},{"id":66,"name":"emulator"},{"id":62,"name":"filesystem"},{"id":80,"name":"finance"},{"id":15,"name":"firmware"},{"id":73,"name":"for-fun"},{"id":2,"name":"framework"},{"id":11,"name":"frontend"},{"id":22,"name":"game"},{"id":81,"name":"game-engine "},{"id":23,"name":"graphql"},{"id":84,"name":"gui"},{"id":91,"name":"http"},{"id":5,"name":"http-client"},{"id":51,"name":"iac"},{"id":30,"name":"ide"},{"id":78,"name":"iot"},{"id":40,"name":"json"},{"id":83,"name":"julian"},{"id":38,"name":"k8s"},{"id":31,"name":"language"},{"id":10,"name":"learning-resource"},{"id":33,"name":"lib"},{"id":41,"name":"linter"},{"id":28,"name":"lms"},{"id":16,"name":"logging"},{"id":76,"name":"low-code"},{"id":90,"name":"message-queue"},{"id":42,"name":"mobile-app"},{"id":18,"name":"monitoring"},{"id":36,"name":"networking"},{"id":7,"name":"node-version"},{"id":55,"name":"nosql"},{"id":57,"name":"observability"},{"id":46,"name":"orm"},{"id":52,"name":"os"},{"id":14,"name":"parser"},{"id":74,"name":"react"},{"id":82,"name":"real-time"},{"id":56,"name":"robot"},{"id":65,"name":"runtime"},{"id":32,"name":"sdk"},{"id":71,"name":"search"},{"id":63,"name":"secrets"},{"id":25,"name":"security"},{"id":85,"name":"server"},{"id":86,"name":"serverless"},{"id":70,"name":"storage"},{"id":75,"name":"system-design"},{"id":79,"name":"terminal"},{"id":29,"name":"testing"},{"id":12,"name":"ui"},{"id":50,"name":"ux"},{"id":88,"name":"video"},{"id":20,"name":"web-app"},{"id":35,"name":"web-server"},{"id":43,"name":"webassembly"},{"id":69,"name":"workflow"},{"id":87,"name":"yaml"}]" returns me the "expected json"