# Stabilizing Transformer Training by Preventing Attention Entropy Collapse
This software project accompanies the research paper, [Stabilizing Transformer Training by Preventing Attention Entropy Collapse](https://proceedings.mlr.press/v202/zhai23a/zhai23a.pdf), published at ICML 2023.
## Introduction
Transformers are difficult to train. In this work, we study the training stability of Transformers through a novel lens named `Attention Entropy Collapse`. Attention entropy is defined as the quantity
$$\text{Ent}(A_i) = -\sum_{j=1}^T A_{i,j}\log(A_{i,j})$$
for an attention matrix $A$, with $A_{i,j}$ corresponding to the $i$-th query and the $j$-th key/value. We observe that training instability often occurs together with sharp decreases in the average attention entropy, and we call this phenomenon entropy collapse. This is illustrated in the figure below.
<p align="center">
<img src="demo.png" alt="drawing" width="500"/>
</p>
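To make the entropy definition concrete, here is a minimal pure-Python sketch that computes the average attention entropy over the rows of a small attention matrix. The helper names (`softmax`, `attention_entropy`, `mean_attention_entropy`) and the toy logits are illustrative assumptions, not part of the reference implementations in this repository.

```python
import math


def softmax(scores):
    """Numerically stable softmax over one row of attention logits."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_entropy(row):
    """Ent(A_i) = -sum_j A_ij * log(A_ij) for a single attention row."""
    # Terms with zero probability contribute nothing (0 * log 0 := 0).
    return -sum(p * math.log(p) for p in row if p > 0.0)


def mean_attention_entropy(logits):
    """Average entropy over all query rows of a T x T logit matrix."""
    rows = [softmax(r) for r in logits]
    return sum(attention_entropy(r) for r in rows) / len(rows)


T = 4
# Uniform attention: every key gets the same weight, so entropy is log(T).
uniform_logits = [[0.0] * T for _ in range(T)]
# Sharply peaked attention: each query attends almost only to one key.
peaked_logits = [[50.0 if i == j else 0.0 for j in range(T)] for i in range(T)]

print(mean_attention_entropy(uniform_logits))
print(mean_attention_entropy(peaked_logits))
```

Uniform attention gives the maximum entropy $\log T$, while sharply peaked attention drives the entropy toward zero, which is the regime described above as entropy collapse.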
We provide both theoretical and empirical analyses of the entropy collapse phenomenon, and propose a simple fix named $\sigma$ Reparam, where we reparameterize all the weights in a Transformer with
$$\widehat{W}=\frac{\gamma}{\sigma(W)}W$$
where $\sigma(W)$ is the spectral norm of $W$ and $\gamma$ is a learnable scalar.
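As a rough illustration of this reparameterization, the sketch below takes $\sigma(W)$ to be the spectral norm (largest singular value) of $W$, estimates it with power iteration, and rescales the weight by $\gamma/\sigma(W)$. Power iteration, the function names, the iteration count, and the use of a plain float for $\gamma$ are assumptions made for this example; the [vision](vision) and [speech](speech) folders remain the authoritative implementations.

```python
import math
import random


def matvec(W, v):
    """Compute W @ v for a matrix stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def matvec_t(W, u):
    """Compute W^T @ u."""
    return [sum(W[i][j] * u[i] for i in range(len(W))) for j in range(len(W[0]))]


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def spectral_norm(W, n_iters=100, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = random.Random(seed)
    v = normalize([rng.gauss(0.0, 1.0) for _ in range(len(W[0]))])
    for _ in range(n_iters):
        u = normalize(matvec(W, v))
        v = normalize(matvec_t(W, u))
    # With v close to the top right singular vector, ||W v|| approximates sigma(W).
    return math.sqrt(sum(x * x for x in matvec(W, v)))


def sigma_reparam(W, gamma=1.0, n_iters=100):
    """Return W_hat = (gamma / sigma(W)) * W.

    In training, gamma would be a learnable scalar; here it is a fixed float.
    """
    scale = gamma / spectral_norm(W, n_iters)
    return [[scale * w for w in row] for row in W]


W = [[3.0, 0.0], [0.0, 1.0]]  # singular values 3 and 1
W_hat = sigma_reparam(W, gamma=1.0)
print(spectral_norm(W), spectral_norm(W_hat))
```

After the rescaling, the spectral norm of $\widehat{W}$ equals $\gamma$, so the scale of each weight matrix is controlled by a single scalar rather than by the raw entries of $W$.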
## Getting Started
We provide two reference implementations: one in PyTorch, applied to the Vision Transformer (ViT) setting, and another in JAX, applied to automatic speech recognition (ASR). Please refer to the [vision](vision) and [speech](speech) folders for details. The same PyTorch implementation was used for the language modeling (LM) and machine translation (MT) experiments.
## BibTeX
```
@inproceedings{zhai2023stabilizing,
title={Stabilizing Transformer Training by Preventing Attention Entropy Collapse},
author={Zhai, Shuangfei and Likhomanenko, Tatiana and Littwin, Etai and Busbridge, Dan and Ramapuram, Jason and Zhang, Yizhe and Gu, Jiatao and Susskind, Joshua M},
booktitle={International Conference on Machine Learning},
pages={40770--40803},
year={2023},
organization={PMLR}
}
```