# Unified-IO 2
This repo contains code for [Unified-IO 2](https://unified-io-2.allenai.org/), including code to run a demo, do training,
and do inference. This codebase is modified from [T5X](https://github.com/google-research/t5x).
## News:
- [2/15/2024] We released the PyTorch code for Unified-IO 2. Details can be found [here](https://github.com/allenai/unified-io-2.pytorch)
- [1/5/2024] We released the JAX source code of ViT-VQGAN, which is used to train our audio tokenizer. Details can be found [here](https://github.com/jiasenlu/vit-vqgan-jax/)
## Install
Install the dependencies with pip.
- **Note:** Since this project has been in development for a long time, some of the packages we use are from older versions. We recently discovered that importing `orbax.checkpoint` can conflict with JAX's `dtype="bfloat16"` under Python 3.9, but it still works with **Python 3.8** (e.g., 3.8.10, the default in TPU VMs). This is likely due to internal changes in `orbax.checkpoint` and pip.
For a TPU:
```
python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
For a GPU/CPU (note we have been using TPUs so GPU setups are not well tested):
```
python3 -m pip install -e '.' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Running the demo requires additional dependencies, install them with:
```
python3 -m pip install -e '.[demo]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
The LLaMA tokenizer also needs to be installed. Download the `.model` file from https://github.com/facebookresearch/llama/tree/main?tab=readme-ov-file
and then update `t5x/examples/unified_io/config.py` so `LLAMA_TOKENIZER_PATH` points to the download location.
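For reference, the relevant setting in `config.py` is expected to be a simple module-level path; the value below is a placeholder for wherever you saved the tokenizer:
```python
# t5x/examples/unified_io/config.py (excerpt); replace with your local path
LLAMA_TOKENIZER_PATH = "/path/to/llama/tokenizer.model"
```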
## Checkpoints
We make checkpoints in the T5X format available on S3:
- XXL: s3://ai2-prior-uio/public/uio2-checkpoints/xxl-3m
- XL: s3://ai2-prior-uio/public/uio2-checkpoints/xl-3m
- Large: s3://ai2-prior-uio/public/uio2-checkpoints/large-3m
To download, copy the directory recursively. For example:
```
aws s3 --no-sign-request cp --recursive s3://ai2-prior-uio/public/uio2-checkpoints/large-3m large-3m --exclude "state*"
```
Copy the checkpoints to a local disk or to Google Cloud Storage. The `--exclude "state*"`
flag excludes the optimizer state from the download; remove it if you want
to continue training the checkpoint from the saved optimizer state.
## Demo
To run the model interactively, use the demo notebook.
Make sure the demo dependencies have been [installed](#install).
Then run the demo notebook:
```
jupyter notebook demo.ipynb
```
Set `FULL_CKPT_PATH` and `MODEL_TYPE` in the second cell to your checkpoint path and
the matching model size, then use the notebook to start the demo.
The notebook shows how to load the model and its parameters and how to run inference.
The demo will be slow the first time it is used because the inference function
needs to be compiled; subsequent calls with similar inputs/outputs will be
much faster.
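For example, the second cell might be filled in like this (the values are placeholders for your local checkpoint and model size):
```python
# Example settings for the second cell of demo.ipynb; adjust to your setup.
FULL_CKPT_PATH = "/path/to/large-3m"  # directory of the downloaded checkpoint
MODEL_TYPE = "large"                  # must match the size of that checkpoint
```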
## Data
To train and evaluate on full datasets, the datasets need to be registered with `seqio` in `seqio.TaskRegistry`. See
`t5x/examples/unified_io/data/tasks.py` for examples, and see [seqio](https://github.com/google/seqio)
for more details on how datasets are managed by seqio.
Some datasets require running a pre-processing script before they can be used.
Make sure `config.MULTITASK_TFDS_DATA_DIR` is updated to
point to the location where the datasets are stored.
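For example, in `t5x/examples/unified_io/config.py` this is expected to be a plain path; the value below is a placeholder:
```python
# t5x/examples/unified_io/config.py (excerpt); replace with your storage location
MULTITASK_TFDS_DATA_DIR = "gs://your-bucket/tfds"  # or a local directory
```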
### Datasets
We provide some initial datasets in `t5x/examples/unified_io/data/tasks.py`.
Our datasets are generally built one of three ways:
1. Constructed as a `tensorflow_dataset` and then uploaded to the location specified in `config.MULTITASK_TFDS_DATA_DIR`
2. Constructed as a set of tfrecords and uploaded to the same location
3. Directly using a dataset from https://www.tensorflow.org/datasets/catalog/overview
Datasets built in the first or second way require running a build script before they can be
used. `create_data` contains the needed build scripts. For example running:
```
python3 create_data/tfdatasets/coco_all/build.py ~/data/tfds ~/data/vqa ~/data/coco_annotations
```
This will build and upload a TFDS version of the COCO data, which allows tasks such as `image_generation_coco_2017`
and `image_caption_coco_2017` to be used. Some datasets that use the public TensorFlow
catalog, such as the refexp datasets, may also have their own manual pre-processing steps,
which are specified on their catalog page.
Unified-IO 2 supports a large number of tasks; this initial release includes only
a subset, and we will add more as we test and verify additional tasks.
### Preprocessing
Pre-processing in UIO2 happens in three stages:
1. Task-specific pre-processing constructs a prompt and builds inputs and outputs in the supported modalities. This stage resizes and pads images to the
correct sizes and provides masks showing which parts of the image are padding (typically with `unified_io.data.data_utils.resize_and_pad`).
Audio segments are converted to mel-spectrograms, which can also be masked when working with
noised data. This stage is implemented by the various preprocessing functions in `unified_io.data.preprocessing`;
the demo shows how to do this for raw inputs.
To allow this stage to apply different pre-processing during training and testing,
we pass an `is_training` field in the `sequence_length` dictionary to indicate
whether the dataset is being used for training or testing.
2. Next, `modality_processing.unified_io_preprocessor` is run. This function performs task-general pre-processing steps,
such as tokenizing the text, and adds empty values for missing modalities so the output dataset has a consistent set of fields.
3. Finally, `UnifiedIOFeatureConverter` is applied; this can happen
after multiple datasets have been combined into a `seqio.Mixture`.
It ensures the output dataset has a consistent structure and is padded to
fixed-size tensors, as required by JAX, with padding determined by the `sequence_length` dictionary.
The resulting dataset can be batched and passed directly
into the loss or prediction functions of a Unified-IO 2 model.
To add a dataset, register it with seqio and ensure the last pre-processor
is `modality_processing.unified_io_preprocessor`. The preceding
functions should make sure the dataset has the appropriate fields for that function.
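As a rough sketch of what such a registration can look like (the task name, TFDS source, field names, and the location of the output feature spec are illustrative assumptions; copy the real patterns from `tasks.py`):
```python
import seqio
import tensorflow as tf
from t5x.examples.unified_io.data import modality_processing

def my_task_preprocessor(ds):
  """Hypothetical task-specific step: attach a prompt and name the modality fields."""
  def _to_example(ex):
    return {
        "text_inputs": tf.constant("Describe the image."),  # prompt (illustrative)
        "text_targets": ex["caption"],
        "image_inputs": ex["image"],
    }
  return ds.map(_to_example, num_parallel_calls=tf.data.AUTOTUNE)

seqio.TaskRegistry.add(
    "my_caption_task",  # hypothetical task name
    source=seqio.TfdsDataSource(tfds_name="my_dataset:1.0.0"),  # placeholder TFDS name
    preprocessors=[
        my_task_preprocessor,                         # task-specific pre-processing
        modality_processing.unified_io_preprocessor,  # task-general step, always last
    ],
    # Where the output feature spec lives is an assumption; copy it from an
    # existing task in tasks.py.
    output_features=modality_processing.OUTPUT_FEATURES,
)
```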
### Prompts
Our full set of prompts is in `t5x/examples/unified_io/data/prompt_dict`;
we randomly select among these prompts during training.
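Conceptually, prompt selection just draws uniformly from the prompts registered for a task; a minimal sketch (the dictionary structure and task key are assumptions, see `prompt_dict` for the real entries):
```python
import random

# Illustrative structure only; see t5x/examples/unified_io/data/prompt_dict for the real prompts.
PROMPT_DICT = {
    "image_caption_coco_2017": [
        "Caption this image.",
        "Write a short description of the picture.",
    ],
}

def sample_prompt(task_name: str) -> str:
  """Pick one of the registered prompts uniformly at random."""
  return random.choice(PROMPT_DICT[task_name])
```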
### Visualization
We include a visualization script to show what the data looks like after pre-processing:
```
python3 t5x/examples/unified_io/scripts/dataset_visualize.py refcoco_unc viz --override
```
To get a more compact view:
```
python3 t5x/examples/unified_io/scripts/dataset_visualize.py refcoco_unc viz --override --gin.get_target_modalities.target_modality=[\"text\"] --gin.get_input_modalities.input_modality=[\"text\",\"image\"] --nomasks
```
## Training
Once a checkpoint is downloaded and a dataset is ready, training can be run using train.py.
Our training strategy largely follows T5X, which is configured through [gin](https://github.com/google/gin-config).
Follow the setup from [T5X](https://github.com/google-research/t5x) to train on TPUs.
For example, to fine-tune the large model on refexp:
```
python3 t5x/train.py --gin_file=t5x/examples/unified_io/t5_1_1/large.gin --gin_file=t5x/examples/unified_io/t5_1_1/finetune/refexp.gin --gin.INITIAL_CHECKPOINT_PATH=\"/path/to/checkpoint\" --gin.MODEL_DIR=\"path/to/output_dir\" --gin.BATCH_SIZE=8
```
### Modalities
Unified-IO 2 can be run with only a subset of the supported modalities, which makes training more
efficient. This is set through the gin-configurable parameters of
`get_input_modalities` and `get_target_modalities`. For example, `refexp.gin`
only turns on the image/text inputs and text outputs.
### Sequence Lengths
Because JAX requires fixed-size tensors, we pad all model inputs and targets to the
maximum supported length by default. When training on mixtures where this is excessive,
it can be tweaked by changing the sequence lengths used by `seqio`.
For example, `refexp.gin` reduces the input and output sequence lengths since
refexp involves little text.
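For reference, the `sequence_length` dictionary is a plain mapping from feature names to maximum lengths. A sketch with illustrative entries (the exact keys depend on the UIO2 feature converter, so treat the names and values below as assumptions):
```python
# Illustrative only: the key names and lengths are assumptions, not the repo's defaults.
sequence_length = {
    "is_training": True,  # consumed by task preprocessors (see Preprocessing above)
    "text_inputs": 256,   # well below the model maximum, since refexp prompts are short
    "text_targets": 64,   # refexp outputs are short region descriptions
}
```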
### Wandb
We have modified `train.py` to use wandb; just make sure the `WANDB_API_KEY` environment variable is set.
The gin-configurable function `utils.init_wandb` should be modified, or configured
through gin, to select the correct name/group/project/entity.
### Packing
If the training mixture contains a mix of long and short examples, packing
can make things more efficient. Packing combines up to two examples
into a single input sequence; it can be turned on with this flag:
```
--gin.PackingStrategy.pack_max_len=(864, 1280)
```
During training, the trainer will try to pack two examples into a single sequence with a total
input length of 864 and a total target length of 1280. A heuristic algorithm
looks for pairs of examples that fit this criterion as data is streamed to
the training server; if no pair is found, only one example is used.
If that happens too frequently, it is a good idea to increase the maximum lengths.
Statistics are logged to wandb to track packing efficiency.
## Evaluation
Evaluation is run with `eval.py`, for example:
```
python3 t5x/eval.py --gin_file=t5x/examples/unified_io/t5_1_1/large.gin --gin_file=t5x/examples/unified_io/t5_1_1/eval/vision_language.gin --gin.CHECKPOINT_PATH=\"large-3m\" --gin.MIXTURE_OR_TASK_NAME=\"refcoco_unc\" --gin.EVAL_OUTPUT_DIR=\"output\"
```
The target dataset must have metrics registered with seqio. Evaluation
can similarly be made more efficient by using only the needed modalities and
choosing the sequence lengths appropriately. Note that most of our official results
come from collecting outputs and then running offline evaluations; the metrics
here are used mostly for validation scores.
## Citation
```bibtex
@article{lu2023uio2,
title = {Unified-IO 2: Scaling Autoregressive Multimodal Models with Vision, Language, Audio, and Action},
author = {Jiasen Lu and Christopher Clark and Sangho Lee and Zichen Zhang and Savya Khosla and Ryan Marten and Derek Hoiem and Aniruddha Kembhavi},
journal = {arXiv preprint arXiv:2312.17172},
year = {2023},
}
```