# KAN-GPT

![PyPI - Downloads](https://img.shields.io/pypi/dm/kan-gpt) [![PyPI - Version](https://img.shields.io/pypi/v/kan-gpt)](https://pypi.org/project/kan-gpt/) [![codecov](https://codecov.io/gh/AdityaNG/kan-gpt/branch/main/graph/badge.svg?token=kan-gpt_token_here)](https://codecov.io/gh/AdityaNG/kan-gpt) [![CI](https://github.com/AdityaNG/kan-gpt/actions/workflows/main.yml/badge.svg)](https://github.com/AdityaNG/kan-gpt/actions/workflows/main.yml) [![GitHub License](https://img.shields.io/github/license/AdityaNG/kan-gpt)](https://github.com/AdityaNG/kan-gpt/blob/main/LICENSE)

The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling

## Install it from PyPI

```bash
pip install kan_gpt
```

## Citation

If you find our work useful, please cite us!

```
@misc{GANESH2024KANGPT,
  author = {Aditya Nalgunda Ganesh},
  title  = {KAN-GPT: The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling},
  year   = {2024},
  month  = {May},
  note   = {Release 1.0.0, 9th May 2024},
  url    = {https://github.com/AdityaNG/kan-gpt/}
}
```

## Usage

Refer to the [KAN_GPT.ipynb](https://github.com/AdityaNG/kan-gpt/blob/main/KAN_GPT.ipynb) notebook and [kan_gpt/prompt.py](https://github.com/AdityaNG/kan-gpt/blob/main/kan_gpt/prompt.py) for usage examples.
The following is an outline of how to use the model:

```py
import torch
from kan_gpt.model import GPT
from transformers import GPT2Tokenizer

model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 50257
model_config.block_size = 1024
model = GPT(model_config)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prompt = "Bangalore is often described as the "

prompt_encoded = tokenizer.encode(
  text=prompt, add_special_tokens=False
)

x = torch.tensor(prompt_encoded).unsqueeze(0)

model.eval()
y = model.generate(x, 50)  # sample 50 tokens

result = tokenizer.decode(y[0])

print(result)
# Bangalore is often described as the Silicon Valley of India.
# The city has witnessed rapid growth in the past two decades.....
```

## Setup for Development

```bash
# Download Repo
git clone https://github.com/AdityaNG/kan-gpt
cd kan-gpt
git pull

# Download Dataset
python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
python3 -m kan_gpt.download_dataset --dataset mnist
python3 -m kan_gpt.download_dataset --dataset webtext

# Install dependencies for development
pip install -r requirements.txt
pip install -e .
```

## Train

Use the following dummy script to make sure everything is working as expected:

```bash
WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu --max_iters 200
WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu --max_iters 200
```

Then make use of the training script:

```bash
python -m kan_gpt.train
```

## Prompt

You can prompt the model to produce text as follows:

```bash
python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path (checkpoint)
```

## Results

We train and compare KAN-GPT with an equivalent MLP-GPT model on the Tiny Shakespeare dataset. We observe that the KAN-GPT performs slightly better than the MLP-GPT.
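As context for the metrics below, the perplexity we plot is simply the exponential of the mean per-token cross-entropy. A minimal PyTorch sketch of that relationship, using hypothetical random tensors rather than this project's actual evaluation code:

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for a batch of 2 sequences of 5 tokens
# over the GPT-2 vocabulary (50257 entries)
torch.manual_seed(0)
logits = torch.randn(2, 5, 50257)
targets = torch.randint(0, 50257, (2, 5))

# Mean per-token cross-entropy, as reported in the loss curves
ce = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

# Perplexity is the exponential of the mean cross-entropy
ppl = torch.exp(ce)
print(f"cross-entropy: {ce.item():.3f}, perplexity: {ppl.item():.1f}")
```

A lower cross-entropy therefore translates directly into a lower perplexity, so the two plots below tell the same story on different scales.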
We are looking into further experiments to dive deeper. The results are shown below:

| Loss | Cross-Entropy | Perplexity |
|---------|---------|---------|
| ![results_loss](media/results_loss.png) | ![results_cross_entropy](media/results_cross_entropy.png) | ![results_perplexity](media/results_perplexity.png) |

## TODOs

- [x] Integrate [minGPT](https://github.com/karpathy/minGPT) and [pykan](https://github.com/KindXiaoming/pykan)
- [x] Dataset downloading script for [WebText](https://github.com/openai/gpt-2-output-dataset)
- [x] PyTorch Dataset parser for [WebText](https://github.com/openai/gpt-2-output-dataset)
- [x] PyTorch Dataset parser for [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt)
- [x] Mini training POC for KAN-GPT
  - [x] Integrate KAN training logic from `KAN.train_kan`
  - [x] Train a dummy batch w/o any memory issues
- [x] Mini training POC for MLP-GPT
- [x] Train MLP-GPT on the WebText dataset as a baseline
- [x] Train KAN-GPT on the WebText dataset as a baseline
- [x] Metrics comparing KAN-GPT and MLP-GPT
- [x] Auto-save checkpoints
- [x] Auto-save checkpoints to W&B
- [ ] Auto-download model weights from git / Hugging Face
- [x] W&B hyperparameter sweep script
- [x] Script to load checkpoint in interactive mode
- [ ] Reduce requirements.txt constraints
- [ ] Define pydantic model for training and sweep args
- [ ] Prune the package, get rid of unused code
- [ ] Port the training script to PyTorch Lightning
- [x] Documentation: `mkdocs gh-deploy`
- [x] Integrate with [efficient-kan](https://github.com/Blealtan/efficient-kan/blob/master/src/efficient_kan/kan.py)
- [x] Test Cases
  - [x] KAN: Forward-Backward test
  - [x] GPT: Forward-Backward test
  - [x] KAN_GPT: Forward-Backward test
  - [x] EFFICIENT_KAN: Forward-Backward test

## Development

Read the [CONTRIBUTING.md](https://github.com/AdityaNG/kan-gpt/blob/main/CONTRIBUTING.md) file.
## References

- [minGPT](https://github.com/karpathy/minGPT)
- [pykan](https://github.com/KindXiaoming/pykan)
- [WebText](https://github.com/openai/gpt-2-output-dataset)
- [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt)