# GaLore

This repo contains the pre-release version of the GaLore algorithm, proposed in [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507).

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows *full-parameter* learning but is more *memory-efficient* than common low-rank adaptation methods, such as LoRA. As a gradient projection method, GaLore is independent of the choice of optimizer and can be easily plugged into existing ones with only two lines of code, as shown in Algorithm 1 below.

<div align="center">
  <img src="imgs/galore_code_box.png" alt="Algorithm 1: GaLore code example" style="width: 550px; margin: 0 auto;">
</div>

## News

- **2024-09-01**: We are working on GaLore 2, a more efficient and accessible version of GaLore. Please stay tuned!
- **2024-07-11**: We release Q-GaLore: Quantized GaLore with INT4 Projection. [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)]
- **2024-07-01**: GaLore is accepted to ICML 2024 as an Oral!
- **2024-04-20**: Please join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to discuss with us and the community.

## Installation

### Install GaLore optimizer

Install from pip:

```bash
pip install galore-torch
```

or install from source:

```bash
git clone git@github.com:jiaweizzhao/GaLore.git
cd GaLore
pip install -e .
```

### Install experiment dependencies

```bash
pip install -r exp_requirements.txt
```

Our experiment scripts are tested on Python 3.8 with PyTorch 2.1.

## Usage

### Save optimizer memory using GaLore optimizers

```python
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor

# define param groups as galore_params and non_galore_params
param_groups = [{'params': non_galore_params},
                {'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}]
optimizer = GaLoreAdamW(param_groups, lr=0.01)
```

A sketch of one way to split the model's parameters into `galore_params` and `non_galore_params` is given at the end of this section.

### Save weight gradient memory using per-layer weight updates

We use `register_post_accumulate_grad_hook` provided by [PyTorch](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html) (`torch>=2.1.0`) to enable per-layer weight updates. An example is shown below:

```python
# define an optimizer for each parameter p, and store them in optimizer_dict
optimizer_dict = {}
for p in model.parameters():
    if p.requires_grad:
        optimizer_dict[p] = GaLoreAdamW([{'params': p, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}], lr=0.01)

# define a hook function to update the parameter p during the backward pass
def optimizer_hook(p):
    if p.grad is None:
        return
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)
```

More details can be found in [torchrun_main.py](https://github.com/jiaweizzhao/GaLore/blob/a6bc1650984b1c090a4e108d7c0e3109ee7ad844/torchrun_main.py#L334).
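The optimizer examples above assume the parameters have already been split into `galore_params` and `non_galore_params`. Below is a minimal sketch of one way to do this, assuming GaLore is applied to the 2D weight matrices of the model's `nn.Linear` modules (the repo's pre-training script uses a similar rule restricted to attention and MLP layers); adapt the selection rule to your model.

```python
import torch.nn as nn

# Minimal sketch (assumption): treat the 2D weights of nn.Linear modules as GaLore params.
galore_params = []
for module in model.modules():
    if isinstance(module, nn.Linear):
        galore_params.append(module.weight)

# Everything else (biases, embeddings, norms, ...) goes into the regular group.
galore_param_ids = {id(p) for p in galore_params}
non_galore_params = [p for p in model.parameters() if id(p) not in galore_param_ids]

param_groups = [
    {'params': non_galore_params},
    {'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'},
]
```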
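With the per-layer hooks registered, a training step is just a forward and backward pass: each parameter is updated and its gradient cleared as soon as its gradient has been accumulated, so no global `optimizer.step()` or `optimizer.zero_grad()` is called. A minimal sketch of the resulting loop, assuming a hypothetical `dataloader` and a Hugging Face-style model that returns `.loss`:

```python
# Minimal training-loop sketch for per-layer weight updates (dataloader and loss are placeholders).
model.train()
for batch in dataloader:
    loss = model(**batch).loss  # assumption: model output exposes .loss
    loss.backward()             # the registered hooks step and zero each per-parameter optimizer
    # no global optimizer.step() / optimizer.zero_grad() here
```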
## Benchmark 1: Pre-Training LLaMA on C4 dataset

`torchrun_main.py` is the main script for training LLaMA models on C4 with GaLore. Our benchmark scripts for various model sizes are in the `scripts/benchmark_c4` folder.

For example, to train a 60M model on C4, run the following:

```bash
# LLaMA-60M, GaLore-Adam, 1 A100, 1 Node
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_60m.json \
    --lr 0.01 \
    --galore_scale 0.25 \
    --rank 128 \
    --update_proj_gap 200 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --optimizer galore_adamw
```

### Train a 7B model with a single GPU with 24GB memory

To train a 7B model on a single GPU such as an NVIDIA RTX 4090, all you need to do is specify `--optimizer=galore_adamw8bit_per_layer`, which enables `GaLoreAdamW8bit` with per-layer weight updates. With activation checkpointing, you can maintain a batch size of 16 (tested on an NVIDIA RTX 4090).

```bash
# LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing
# bsz=16, 22.8G memory
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_7b.json \
    --lr 0.005 \
    --galore_scale 0.25 \
    --rank 1024 \
    --update_proj_gap 500 \
    --batch_size 16 \
    --total_batch_size 512 \
    --activation_checkpointing \
    --num_training_steps 150000 \
    --warmup_steps 15000 \
    --weight_decay 0 \
    --grad_clipping 1.0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --single_gpu \
    --optimizer galore_adamw8bit_per_layer
```

Currently, the per-layer weight update technique is only supported for single-GPU training (`--single_gpu`) without `nn.parallel.DistributedDataParallel`. We are working on supporting multi-GPU training with per-layer weight updates.

## Benchmark 2: Fine-Tuning RoBERTa on GLUE tasks

`run_glue.py` is the main script for fine-tuning RoBERTa models on GLUE tasks with GaLore. An example script is shown below:

```bash
python run_glue.py \
    --model_name_or_path roberta-base \
    --task_name mrpc \
    --enable_galore \
    --lora_all_modules \
    --max_length 512 \
    --seed=1234 \
    --lora_r 4 \
    --galore_scale 4 \
    --per_device_train_batch_size 16 \
    --update_proj_gap 500 \
    --learning_rate 3e-5 \
    --num_train_epochs 30 \
    --output_dir results/ft/roberta_base/mrpc
```

## Citation

```bibtex
@misc{zhao2024galore,
      title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection},
      author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian},
      year={2024},
      eprint={2403.03507},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```