# Mirage: Automatically Generating Fast GPU Kernels for PyTorch Programs
Mirage is a tool that automatically generates fast GPU kernels for PyTorch programs through superoptimization techniques. For example, to get fast GPU kernels for attention, users only need to write a few lines of Python code to describe attention's computation. For a given PyTorch program, Mirage automatically searches the space of potential GPU kernels that are functionally equivalent to the input program and discovers highly-optimized kernel candidates. This approach allows Mirage to find new custom kernels that outperform existing expert-designed ones.
## Quick Installation
The quickest way to try Mirage is to install the latest stable release from pip:
```bash
pip install mirage-project
```
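After installation, a quick sanity check is to import the package (a minimal check; the pre-built wheels additionally assume an NVIDIA GPU with a matching CUDA driver at runtime):
```bash
python -c "import mirage; print('Mirage imported successfully')"
```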
We also provide pre-built binary wheels on the [Release Page](https://github.com/mirage-project/mirage/releases/latest). For example, to install Mirage 0.2.2 compiled with CUDA 12.2 for Python 3.10, use the following command:
```bash
pip install https://github.com/mirage-project/mirage/releases/download/v0.2.2/mirage_project-0.2.2+cu122-cp310-cp310-linux_x86_64.whl
```
You can also install Mirage from source code:
```bash
git clone --recursive https://www.github.com/mirage-project/mirage
cd mirage
pip install -e . -v
```
## Quickstart
Mirage can automatically generate fast GPU kernels for arbitrary PyTorch programs. The Mirage-generated kernels can be integrated into a PyTorch program with a few lines of code changes. As an example, we show how to use Mirage to generate kernels that fuse [RMSNorm](https://arxiv.org/pdf/1910.07467) and Linear to accelerate Transformer-based large language model computation. More examples are available in [tutorials](https://mirage-project.readthedocs.io/en/latest/tutorials/index.html).
The following code snippet shows a native PyTorch implementation of a Transformer layer in LLaMA-3-8B.
```python
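# X is the layer input; Wqkv, W13, and W2 are weight tensors defined elsewhere.
# attention() stands in for the attention computation, also defined elsewhere.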
rms_norm_1 = torch.nn.RMSNorm(4096)
rms_norm_2 = torch.nn.RMSNorm(4096)
Y = rms_norm_1(X)
Z = torch.matmul(Y, Wqkv)
O = attention(Z)
U = rms_norm_2(O)
V = torch.matmul(U, W13)
V1, V3 = V.chunk(2, -1) # split omitted in the figure below
output = torch.matmul(silu(V1) * V3, W2) # silu and this matmul omitted in the figure below
```
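For reference, RMSNorm (from the paper linked above) rescales each activation by the root mean square over the hidden dimension, then applies a learned per-channel gain:

$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}} \cdot g_i$$

where $d = 4096$ here, $g$ is the learned weight, and $\epsilon$ is the small stability constant implementations add in practice.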
<p align="center">
<img src="img/llama-3-8b-rms-norm-linear.png?raw=true" alt="Mirage generates kernels that fuse RMSNorm and Linear" height="280"/>
</p>
To accelerate Transformer computation, we can use Mirage to generate GPU kernels that fuse RMSNorm and Linear, as shown in the code snippet below. Generating optimized kernels only requires writing a few lines of code to describe the desired computation. The `get_mirage_kernel` function below returns the best kernel discovered by Mirage, and the discovered kernels can run directly as functions in your PyTorch programs. The fused kernel is 1.5–1.7x faster than running the two operators separately in PyTorch.
```python
import mirage as mi

def get_mirage_kernel(batch_size, output_dim):
    graph = mi.new_kernel_graph()
    X = graph.new_input(dims=(batch_size, 4096), dtype=mi.float16)
    W = graph.new_input(dims=(4096, output_dim), dtype=mi.float16)
    Y = graph.rms_norm(X, normalized_shape=(4096,))
    Z = graph.matmul(Y, W)
    graph.mark_output(Z)
    return graph.superoptimize()
kernel_1 = get_mirage_kernel(batch_size, output_dim=Wqkv.shape[-1])
kernel_2 = get_mirage_kernel(batch_size, output_dim=W13.shape[-1])
Z = kernel_1(inputs=[X, Wqkv])
O = attention(Z)
V = kernel_2(inputs=[O, W13])
V1, V3 = V.chunk(2, -1) # split omitted in the above figure
output = torch.matmul(silu(V1) * V3, W2) # silu and this matmul omitted in the above figure
```
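To verify the reported speedup on your own hardware, you can time the fused kernel against the unfused baseline with CUDA events. A minimal sketch, reusing `X`, `Wqkv`, `rms_norm_1`, and `kernel_1` from the snippets above; the warm-up and iteration counts are arbitrary choices:
```python
import torch

def bench(fn, iters=100, warmup=10):
    """Return mean milliseconds per call, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

unfused_ms = bench(lambda: torch.matmul(rms_norm_1(X), Wqkv))  # two separate ops
fused_ms = bench(lambda: kernel_1(inputs=[X, Wqkv]))           # Mirage kernel
print(f"unfused: {unfused_ms:.3f} ms, fused: {fused_ms:.3f} ms")
```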
## Contribution
Please let us know if you encounter any bugs or have any suggestions by [submitting an issue](https://github.com/mirage-project/mirage/issues).
We welcome all contributions to Mirage from bug fixes to new features and extensions.
## Citation
A paper describing Mirage's techniques is available [on arXiv](https://arxiv.org/abs/2405.05751). Please cite Mirage as:
```bibtex
@misc{wu2024mirage,
  title={A Multi-Level Superoptimizer for Tensor Programs},
  author={Mengdi Wu and Xinhao Cheng and Oded Padon and Zhihao Jia},
  eprint={2405.05751},
  archivePrefix={arXiv},
  year={2024},
}
```
## License
Mirage is licensed under the Apache License 2.0.
", Assign "at most 3 tags" to the expected json: {"id":"10146","tags":[]} "only from the tags list I provide: [{"id":77,"name":"3d"},{"id":89,"name":"agent"},{"id":17,"name":"ai"},{"id":54,"name":"algorithm"},{"id":24,"name":"api"},{"id":44,"name":"authentication"},{"id":3,"name":"aws"},{"id":27,"name":"backend"},{"id":60,"name":"benchmark"},{"id":72,"name":"best-practices"},{"id":39,"name":"bitcoin"},{"id":37,"name":"blockchain"},{"id":1,"name":"blog"},{"id":45,"name":"bundler"},{"id":58,"name":"cache"},{"id":21,"name":"chat"},{"id":49,"name":"cicd"},{"id":4,"name":"cli"},{"id":64,"name":"cloud-native"},{"id":48,"name":"cms"},{"id":61,"name":"compiler"},{"id":68,"name":"containerization"},{"id":92,"name":"crm"},{"id":34,"name":"data"},{"id":47,"name":"database"},{"id":8,"name":"declarative-gui "},{"id":9,"name":"deploy-tool"},{"id":53,"name":"desktop-app"},{"id":6,"name":"dev-exp-lib"},{"id":59,"name":"dev-tool"},{"id":13,"name":"ecommerce"},{"id":26,"name":"editor"},{"id":66,"name":"emulator"},{"id":62,"name":"filesystem"},{"id":80,"name":"finance"},{"id":15,"name":"firmware"},{"id":73,"name":"for-fun"},{"id":2,"name":"framework"},{"id":11,"name":"frontend"},{"id":22,"name":"game"},{"id":81,"name":"game-engine "},{"id":23,"name":"graphql"},{"id":84,"name":"gui"},{"id":91,"name":"http"},{"id":5,"name":"http-client"},{"id":51,"name":"iac"},{"id":30,"name":"ide"},{"id":78,"name":"iot"},{"id":40,"name":"json"},{"id":83,"name":"julian"},{"id":38,"name":"k8s"},{"id":31,"name":"language"},{"id":10,"name":"learning-resource"},{"id":33,"name":"lib"},{"id":41,"name":"linter"},{"id":28,"name":"lms"},{"id":16,"name":"logging"},{"id":76,"name":"low-code"},{"id":90,"name":"message-queue"},{"id":42,"name":"mobile-app"},{"id":18,"name":"monitoring"},{"id":36,"name":"networking"},{"id":7,"name":"node-version"},{"id":55,"name":"nosql"},{"id":57,"name":"observability"},{"id":46,"name":"orm"},{"id":52,"name":"os"},{"id":14,"name":"parser"},{"id":74,"name":"react"},{"id":82,"name":"real-time"},{"id":56,"name":"robot"},{"id":65,"name":"runtime"},{"id":32,"name":"sdk"},{"id":71,"name":"search"},{"id":63,"name":"secrets"},{"id":25,"name":"security"},{"id":85,"name":"server"},{"id":86,"name":"serverless"},{"id":70,"name":"storage"},{"id":75,"name":"system-design"},{"id":79,"name":"terminal"},{"id":29,"name":"testing"},{"id":12,"name":"ui"},{"id":50,"name":"ux"},{"id":88,"name":"video"},{"id":20,"name":"web-app"},{"id":35,"name":"web-server"},{"id":43,"name":"webassembly"},{"id":69,"name":"workflow"},{"id":87,"name":"yaml"}]" returns me the "expected json"