# Muon: An optimizer for the hidden layers of neural networks
Muon optimizer: >30% better sample efficiency with <3% wallclock overhead.
This repo contains an implementation of the `Muon` optimizer originally described in [this thread](https://x.com/kellerjordan0/status/1842300916864844014) and [this writeup](https://kellerjordan.github.io/posts/muon/).
## Installation
```
pip install git+https://github.com/KellerJordan/Muon
```
## Usage
Muon is intended to optimize only the internal ≥2D parameters of a network.
Embeddings, classifier heads, and scalar or vector parameters should be optimized using AdamW.
```python
import torch
from muon import Muon

# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)

# Find ≥2D parameters in the body of the network -- these should be optimized by Muon
muon_params = [p for p in model.body.parameters() if p.ndim >= 2]
# Find everything else -- these should be optimized by AdamW
adamw_params = ([p for p in model.body.parameters() if p.ndim < 2]
              + [*model.head.parameters(), *model.embed.parameters()])
# Create the optimizers
optimizers = [Muon(muon_params, lr=0.02, momentum=0.95, rank=0, world_size=1),
              torch.optim.AdamW(adamw_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)]
...
# in the training step
for opt in optimizers:
    opt.step()
```
You'll have to replace `model.body`, `model.head`, and `model.embed` with whatever subset is appropriate for your model.
E.g., for a ConvNet, `muon_params` should be all the convolutional filters, and `adamw_params` should be everything else.
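If your model does not expose separate `body`, `head`, and `embed` modules, one option is to split parameters by name and dimensionality. The following is a minimal sketch, not part of this repo: `split_for_muon` and `adamw_keywords` are hypothetical names, and the substrings `"embed"` and `"head"` are assumptions about how your modules are named. The helper only inspects `.ndim`, so the demo uses stand-in objects and runs without PyTorch.

```python
from types import SimpleNamespace


def split_for_muon(named_params, adamw_keywords=("embed", "head")):
    """Split (name, param) pairs into (muon_params, adamw_params).

    A parameter goes to Muon only if it has >= 2 dimensions and its name
    contains none of `adamw_keywords`; everything else goes to AdamW.
    """
    muon_params, adamw_params = [], []
    for name, p in named_params:
        is_excluded = any(key in name for key in adamw_keywords)
        if p.ndim >= 2 and not is_excluded:
            muon_params.append(p)
        else:
            adamw_params.append(p)
    return muon_params, adamw_params


# Stand-in parameters: only `.ndim` is inspected by the helper.
fake_named_params = [
    ("embed.weight", SimpleNamespace(ndim=2)),
    ("body.0.weight", SimpleNamespace(ndim=2)),
    ("body.0.bias", SimpleNamespace(ndim=1)),
    ("head.weight", SimpleNamespace(ndim=2)),
]
muon_params, adamw_params = split_for_muon(fake_named_params)
print(len(muon_params), len(adamw_params))  # prints: 1 3
```

With a real PyTorch model you would pass `model.named_parameters()` instead of the stand-in list, and adjust `adamw_keywords` to match your own module names.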
## Example usage
* [Example use in the NanoGPT speedrun](https://github.com/KellerJordan/modded-nanogpt/blob/d700b8724cbda3e7b1e5bcadbc0957f6ad1738fd/train_gpt.py#L519)
* [Example use in the CIFAR-10 speedrun](https://github.com/KellerJordan/cifar10-airbench/blob/28bff5f5b31e95aa45b5b20e1f48baf1ed98d5f6/airbench94_muon.py#L362)
## Hyperparameter tuning
Typically, the default values of momentum (0.95), nesterov (True), and ns_steps (5) work well. The only hyperparameter which must be tuned is the learning rate.
The learning rate should have constant muP scaling; that is, as you scale up the model size, you shouldn't need to retune it.
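For context on `ns_steps`: per the linked writeup, it counts iterations of the Newton-Schulz procedure that Muon uses to approximately orthogonalize each update matrix. The sketch below illustrates that iteration in dependency-free Python on lists of rows. It is not the code in this repo, which operates on PyTorch tensors; `newton_schulz`, `matmul`, and `transpose` are hypothetical names, and the coefficients are quoted from the writeup as an assumption you should check against the source.

```python
import math


def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*A)]


def newton_schulz(G, ns_steps=5, eps=1e-7):
    """Push the singular values of G toward 1 with a quintic iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315  # assumed: coefficients from the writeup
    tall = len(G) > len(G[0])
    # Work on the wide orientation so that X X^T is the smaller square matrix.
    X = transpose(G) if tall else [row[:] for row in G]
    # Normalize by the Frobenius norm so all singular values are at most 1.
    norm = math.sqrt(sum(x * x for row in X for x in row)) + eps
    X = [[x / norm for x in row] for row in X]
    for _ in range(ns_steps):
        A = matmul(X, transpose(X))  # A = X X^T
        AA = matmul(A, A)
        n = len(A)
        B = [[b * A[i][j] + c * AA[i][j] for j in range(n)] for i in range(n)]
        BX = matmul(B, X)
        # X <- a X + (b A + c A^2) X
        X = [[a * X[i][j] + BX[i][j] for j in range(len(X[0]))] for i in range(n)]
    return transpose(X) if tall else X


update = newton_schulz([[3.0, 0.0], [0.0, 1.0]])
print(update)
```

For this diagonal input the output stays diagonal, and both diagonal entries end up near 1 rather than exactly 1: the iteration trades exactness for speed, which is part of why a small fixed `ns_steps` is usually enough.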
## Benchmarks
For a comparison between AdamW, Shampoo, SOAP, and Muon for training a 124M-parameter transformer, see [here](https://github.com/KellerJordan/modded-nanogpt/tree/master/records/102924_Optimizers).
## Accomplishments
* [Lowered the record for training to 94% on CIFAR-10 from 3.3 A100-seconds to 2.6 A100-seconds](https://github.com/KellerJordan/cifar10-airbench)
* [Used to train a transformer to GPT-2 (XL) performance in $175 of compute](https://x.com/kellerjordan0/status/1850995958697308307)
* [Improved the training speed record for attaining GPT-2 (small) performance by a factor of 1.35x](https://x.com/kellerjordan0/status/1842300916864844014)
* [Used by the Kimi.ai frontier lab for scaled LLM training](https://x.com/Kimi_Moonshot/status/1893379158472044623)
## More learning resources and results about Muon
* [Blog post on Muon by Jianlin Su (the creator of RoPE)](https://kexue.fm/archives/10592)
* [Blog post by Jeremy Bernstein on theoretical background of Muon](https://jeremybernste.in/writing/deriving-muon)
* [Tech report by Kimi.ai on using Muon for scaled training](https://arxiv.org/abs/2502.16982v1)
* [Why we chose Muon: Our chain of thought (by Jianlin Su at Kimi.ai)](https://x.com/Kimi_Moonshot/status/1897929976948965870)
## Citation
```bibtex
@misc{jordan2024muon,
author = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and
Franz Cesista and Laker Newhouse and Jeremy Bernstein},
title = {Muon: An optimizer for hidden layers in neural networks},
year = {2024},
url = {https://kellerjordan.github.io/posts/muon/}
}
```
", Assign "at most 3 tags" to the expected json: {"id":"13022","tags":[]} "only from the tags list I provide: [{"id":77,"name":"3d"},{"id":89,"name":"agent"},{"id":17,"name":"ai"},{"id":54,"name":"algorithm"},{"id":24,"name":"api"},{"id":44,"name":"authentication"},{"id":3,"name":"aws"},{"id":27,"name":"backend"},{"id":60,"name":"benchmark"},{"id":72,"name":"best-practices"},{"id":39,"name":"bitcoin"},{"id":37,"name":"blockchain"},{"id":1,"name":"blog"},{"id":45,"name":"bundler"},{"id":58,"name":"cache"},{"id":21,"name":"chat"},{"id":49,"name":"cicd"},{"id":4,"name":"cli"},{"id":64,"name":"cloud-native"},{"id":48,"name":"cms"},{"id":61,"name":"compiler"},{"id":68,"name":"containerization"},{"id":92,"name":"crm"},{"id":34,"name":"data"},{"id":47,"name":"database"},{"id":8,"name":"declarative-gui "},{"id":9,"name":"deploy-tool"},{"id":53,"name":"desktop-app"},{"id":6,"name":"dev-exp-lib"},{"id":59,"name":"dev-tool"},{"id":13,"name":"ecommerce"},{"id":26,"name":"editor"},{"id":66,"name":"emulator"},{"id":62,"name":"filesystem"},{"id":80,"name":"finance"},{"id":15,"name":"firmware"},{"id":73,"name":"for-fun"},{"id":2,"name":"framework"},{"id":11,"name":"frontend"},{"id":22,"name":"game"},{"id":81,"name":"game-engine 
"},{"id":23,"name":"graphql"},{"id":84,"name":"gui"},{"id":91,"name":"http"},{"id":5,"name":"http-client"},{"id":51,"name":"iac"},{"id":30,"name":"ide"},{"id":78,"name":"iot"},{"id":40,"name":"json"},{"id":83,"name":"julian"},{"id":38,"name":"k8s"},{"id":31,"name":"language"},{"id":10,"name":"learning-resource"},{"id":33,"name":"lib"},{"id":41,"name":"linter"},{"id":28,"name":"lms"},{"id":16,"name":"logging"},{"id":76,"name":"low-code"},{"id":90,"name":"message-queue"},{"id":42,"name":"mobile-app"},{"id":18,"name":"monitoring"},{"id":36,"name":"networking"},{"id":7,"name":"node-version"},{"id":55,"name":"nosql"},{"id":57,"name":"observability"},{"id":46,"name":"orm"},{"id":52,"name":"os"},{"id":14,"name":"parser"},{"id":74,"name":"react"},{"id":82,"name":"real-time"},{"id":56,"name":"robot"},{"id":65,"name":"runtime"},{"id":32,"name":"sdk"},{"id":71,"name":"search"},{"id":63,"name":"secrets"},{"id":25,"name":"security"},{"id":85,"name":"server"},{"id":86,"name":"serverless"},{"id":70,"name":"storage"},{"id":75,"name":"system-design"},{"id":79,"name":"terminal"},{"id":29,"name":"testing"},{"id":12,"name":"ui"},{"id":50,"name":"ux"},{"id":88,"name":"video"},{"id":20,"name":"web-app"},{"id":35,"name":"web-server"},{"id":43,"name":"webassembly"},{"id":69,"name":"workflow"},{"id":87,"name":"yaml"}]" returns me the "expected json"