# GaLore
This repo contains the pre-release version of the GaLore algorithm, proposed in [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507).
Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows *full-parameter* learning but is more *memory-efficient* than common low-rank adaptation methods, such as LoRA.
As a gradient projection method, GaLore is independent of the choice of optimizers and can be easily plugged into existing ones with only two lines of code, as shown in Algorithm 1 below.
<div align="center">
<img src="imgs/galore_code_box.png" alt="Image 2" style="width: 550px; margin: 0 auto;">
</div>
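The projection idea described above can be illustrated with a minimal sketch. This is not the library's implementation: the helper `compute_projector`, the toy shapes, and the use of an identity step in place of a real optimizer are all assumptions made for illustration. It only shows that the optimizer would see a `rank x n` gradient instead of the full `m x n` one, and that the result is projected back before being applied.

```python
import torch

def compute_projector(grad: torch.Tensor, rank: int) -> torch.Tensor:
    """Hypothetical helper: top-`rank` left singular vectors of `grad` (m x rank)."""
    U, _, _ = torch.linalg.svd(grad.float(), full_matrices=False)
    return U[:, :rank]

torch.manual_seed(0)
m, n, rank, scale = 64, 256, 8, 0.25   # toy sizes, chosen only for the example
grad = torch.randn(m, n)               # stands in for a weight gradient

# In GaLore the projector is refreshed only every `update_proj_gap` steps.
P = compute_projector(grad, rank)

# Line 1 of the "two lines": project the gradient into the low-rank space.
low_rank_grad = P.T @ grad             # shape (rank, n)

# A real optimizer (e.g. Adam) would transform low_rank_grad here and keep
# its state at this reduced shape; an identity step is used for brevity.
low_rank_update = low_rank_grad

# Line 2 of the "two lines": project back to the full shape and rescale.
full_update = scale * (P @ low_rank_update)   # shape (m, n)

print(low_rank_grad.shape, full_update.shape)
```

Because the optimizer state lives at the projected shape, its size scales with `rank` rather than with the smaller weight dimension.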
## News
- **2024-09-01**: We are working on GaLore 2, which is a more efficient and accessible version of GaLore. Please stay tuned!
- **2024-07-11**: We release Q-GaLore: Quantized GaLore with INT4 Projection. [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)]
- **2024-07-01**: GaLore is accepted to ICML 2024 as Oral!
- **2024-04-20**: Please join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to discuss with us and the community.
## Installation
### Install GaLore optimizer
Install from pip:
```bash
pip install galore-torch
```
or if you want to install from source:
```bash
git clone git@github.com:jiaweizzhao/GaLore.git
cd GaLore
pip install -e .
```
### Install experiment dependencies
```bash
pip install -r exp_requirements.txt
```
Our experiment scripts are tested on Python 3.8 with PyTorch 2.1.
## Usage
### Save optimizer memory using GaLore optimizers
```python
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
# define param groups as galore_params and non_galore_params
param_groups = [{'params': non_galore_params},
{'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}]
optimizer = GaLoreAdamW(param_groups, lr=0.01)
```
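The snippet above leaves `galore_params` and `non_galore_params` undefined. One possible way to build them is sketched below, assuming that GaLore is applied only to the 2D weights of `nn.Linear` layers whose module names contain `attn` or `mlp`. The helper `split_galore_params`, the keyword list, and the toy model are hypothetical; adapt the selection rule to your own architecture.

```python
import torch.nn as nn

def split_galore_params(model: nn.Module, target_keywords=("attn", "mlp")):
    """Hypothetical helper: pick Linear weights by module name for GaLore."""
    galore_params = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(k in name for k in target_keywords):
            galore_params.append(module.weight)
    # Everything else (biases, embeddings, norms) stays in a regular group.
    galore_ids = {id(p) for p in galore_params}
    non_galore_params = [p for p in model.parameters() if id(p) not in galore_ids]
    return galore_params, non_galore_params

class ToyBlock(nn.Module):
    """Toy model used only to exercise the helper."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 32)
        self.attn = nn.Linear(32, 32)
        self.mlp = nn.Linear(32, 64)
        self.norm = nn.LayerNorm(32)

model = ToyBlock()
galore_params, non_galore_params = split_galore_params(model)
print(len(galore_params), len(non_galore_params))
```

The two lists can then be passed as the `params` entries of the param groups shown above.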
### Save weight gradient memory using per-layer weight updates
We use `register_post_accumulate_grad_hook` provided by [PyTorch](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html) (`torch>=2.1.0`) to enable per-layer weight updates. An example is shown below:
```python
from galore_torch import GaLoreAdamW

# define an optimizer for each parameter p, and store them in optimizer_dict
optimizer_dict = {}
for p in model.parameters():
    if p.requires_grad:
        optimizer_dict[p] = GaLoreAdamW([{'params': p, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}], lr=0.01)

# define a hook function to update the parameter p during the backward pass
def optimizer_hook(p):
    if p.grad is None:
        return
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)
```
More details can be found in [torchrun_main.py](https://github.com/jiaweizzhao/GaLore/blob/a6bc1650984b1c090a4e108d7c0e3109ee7ad844/torchrun_main.py#L334).
## Benchmark 1: Pre-Training LLaMA on C4 dataset
`torchrun_main.py` is the main script for training LLaMA models on C4 with GaLore. Our benchmark scripts for various model sizes are in the `scripts/benchmark_c4` folder.
For example, to train a 60M model on C4, run the following:
```bash
# LLaMA-60M, GaLore-Adam, 1 A100, 1 Node
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
--model_config configs/llama_60m.json \
--lr 0.01 \
--galore_scale 0.25 \
--rank 128 \
--update_proj_gap 200 \
--batch_size 256 \
--total_batch_size 512 \
--num_training_steps 10000 \
--warmup_steps 1000 \
--weight_decay 0 \
--dtype bfloat16 \
--eval_every 1000 \
--optimizer galore_adamw
```
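The relationship between `--batch_size` and `--total_batch_size` can be sketched as simple arithmetic. This assumes that `--batch_size` is the per-device micro-batch and that `--total_batch_size` is the effective batch across all devices and accumulation steps; the function name below is hypothetical, so check `torchrun_main.py` for the exact behavior.

```python
def gradient_accumulation_steps(total_batch_size: int, batch_size: int, world_size: int = 1) -> int:
    """Hypothetical helper: accumulation steps needed to reach the effective batch."""
    per_step = batch_size * world_size  # samples processed per forward/backward pass
    if total_batch_size % per_step != 0:
        raise ValueError("total_batch_size must be divisible by batch_size * world_size")
    return total_batch_size // per_step

# 60M example above: 1 GPU, batch_size 256, total_batch_size 512
steps_60m = gradient_accumulation_steps(512, 256, world_size=1)
# 7B single-GPU example: batch_size 16, total_batch_size 512
steps_7b = gradient_accumulation_steps(512, 16, world_size=1)
print(steps_60m, steps_7b)  # 2 32
```

Under this assumption, lowering `--batch_size` to fit memory keeps the effective batch unchanged and only increases the number of accumulation steps.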
### Train 7B model with a single GPU with 24GB memory
To train a 7B model on a single GPU such as an NVIDIA RTX 4090, specify `--optimizer=galore_adamw8bit_per_layer`, which enables `GaLoreAdamW8bit` with per-layer weight updates.
With activation checkpointing, a batch size of 16 fits in memory, as tested on an NVIDIA RTX 4090.
```bash
# LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing
# bsz=16, 22.8G,
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
--model_config configs/llama_7b.json \
--lr 0.005 \
--galore_scale 0.25 \
--rank 1024 \
--update_proj_gap 500 \
--batch_size 16 \
--total_batch_size 512 \
--activation_checkpointing \
--num_training_steps 150000 \
--warmup_steps 15000 \
--weight_decay 0 \
--grad_clipping 1.0 \
--dtype bfloat16 \
--eval_every 1000 \
--single_gpu \
--optimizer galore_adamw8bit_per_layer
```
Currently, the per-layer weight update technique is only supported for single-GPU training (`--single_gpu`) without `nn.parallel.DistributedDataParallel`. We are working on supporting multi-GPU training with per-layer weight updates.
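To get a feel for why a low-rank projection reduces optimizer memory, the rough count below compares the number of Adam state elements for one weight matrix with and without projection. It is a back-of-the-envelope sketch under stated assumptions: Adam keeps two moment tensors, GaLore additionally stores one projector, the projection is applied along the smaller dimension, and the 4096 x 11008 shape is used as a typical LLaMA-7B MLP weight. The function names are hypothetical, and the counts ignore 8-bit quantization, weights, gradients, and activations.

```python
def adam_state_elements(m: int, n: int) -> int:
    """Full-rank Adam: first and second moments, each of shape m x n."""
    return 2 * m * n

def galore_adam_state_elements(m: int, n: int, rank: int) -> int:
    """Assumed accounting: one projector on the smaller side plus two low-rank moments."""
    small, large = min(m, n), max(m, n)
    projector = small * rank        # e.g. P with shape (small, rank)
    moments = 2 * rank * large      # two moments of shape (rank, large)
    return projector + moments

m, n, rank = 4096, 11008, 1024     # assumed layer shape; rank as in the command above
full = adam_state_elements(m, n)
low = galore_adam_state_elements(m, n, rank)
print(full, low)  # 90177536 26738688
```

Under these assumptions the projected optimizer state for this layer is roughly 30% of the full-rank Adam state.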
## Benchmark 2: Fine-Tuning RoBERTa on GLUE tasks
`run_glue.py` is the main script for fine-tuning RoBERTa models on GLUE tasks with GaLore. An example script is shown below:
```bash
python run_glue.py \
--model_name_or_path roberta-base \
--task_name mrpc \
--enable_galore \
--lora_all_modules \
--max_length 512 \
--seed=1234 \
--lora_r 4 \
--galore_scale 4 \
--per_device_train_batch_size 16 \
--update_proj_gap 500 \
--learning_rate 3e-5 \
--num_train_epochs 30 \
--output_dir results/ft/roberta_base/mrpc
```
## Citation
```bibtex
@misc{zhao2024galore,
title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection},
author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian},
year={2024},
eprint={2403.03507},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```