
Visual Prompt Tuning

https://arxiv.org/abs/2203.12119


This repository contains the official PyTorch implementation for Visual Prompt Tuning.

vpt_teaser

Environment settings

See env_setup.sh
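
A quick way to sanity-check the environment after running env_setup.sh is to print the versions of the core packages. The package list below is an assumption; env_setup.sh itself is the authoritative source:

    # rough environment check; the package list is an assumption, not the official requirements
    import importlib

    for pkg in ["torch", "torchvision", "numpy"]:
        try:
            mod = importlib.import_module(pkg)
            print(f"{pkg}: {getattr(mod, '__version__', 'unknown')}")
        except ImportError:
            print(f"{pkg}: NOT INSTALLED")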

Structure of this repo (key files are marked with 👉):

  • src/configs: handles config parameters for the experiments.

    • 👉 src/configs/config.py: main config setup for experiments and an explanation of each option.
  • src/data: loads and sets up the input datasets. The src/data/vtab_datasets are borrowed from the VTAB GitHub repo.

  • src/engine: main training and evaluation logic.

  • src/models: handles backbone architectures and heads for the different fine-tuning protocols.

    • 👉 src/models/vit_prompt: a folder containing the same backbones as the vit_backbones folder, adapted for VPT. It should contain the same file names as vit_backbones.

    • 👉 src/models/vit_models.py: main model class for Transformer-based models. ❗️ Note ❗️: the current version only supports ViT, Swin, and ViT with MAE / MoCo v3.

    • src/models/build_model.py: uses the config to build the model for training / evaluation.

  • src/solver: optimization, losses and learning rate schedules.

  • src/utils: helper functions for I/O, logging, training, and visualization.

  • 👉 train.py: call this to train and evaluate a model with a specified transfer type.

  • 👉 tune_fgvc.py: call this to tune the learning rate and weight decay for a model with a specified transfer type. We used this script for the FGVC tasks.

  • 👉 tune_vtab.py: call this to tune the VTAB tasks: it uses the 800/200 split to find the best learning rate and weight decay, then uses them for the final runs (a rough sketch of this sweep follows this list).

  • launch.py: contains functions used to launch the job.
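
As a rough illustration of what the tuning scripts automate (the real ones assemble configs and launch jobs via launch.py), the sweep in tune_vtab.py amounts to a grid search over learning rate and weight decay scored on the 200-image val split. Everything below, including the candidate grids and the stub objective, is a placeholder rather than the repository's API:

    import itertools

    # illustrative candidate grids; the actual values used by tune_vtab.py may differ
    LRS = [10.0, 5.0, 2.5, 1.0, 0.5, 0.25, 0.1, 0.05]
    WDS = [0.01, 0.001, 0.0001, 0.0]

    def val_acc(lr, wd):
        """Stub: train on the 800-image split with (lr, wd) and return accuracy on
        the 200-image val split. Replaced here by a dummy score so the sketch runs."""
        return -abs(lr - 1.0) - 10 * wd

    best_lr, best_wd = max(itertools.product(LRS, WDS), key=lambda p: val_acc(*p))
    print(f"best lr={best_lr}, wd={best_wd} -> rerun final training with these values")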

Experiments

Key configs:

  • 🔥 VPT related (see the sketch after this list):
    • MODEL.PROMPT.NUM_TOKENS: prompt length
    • MODEL.PROMPT.DEEP: deep or shallow prompt
  • Fine-tuning method specification:
    • MODEL.TRANSFER_TYPE
  • Vision backbones:
    • DATA.FEATURE: specify which representation to use
    • MODEL.TYPE: the general backbone type, e.g., "vit" or "swin"
    • MODEL.MODEL_ROOT: folder with pre-trained model checkpoints
  • Optimization related:
    • SOLVER.BASE_LR: learning rate for the experiment
    • SOLVER.WEIGHT_DECAY: weight decay value for the experiment
    • DATA.BATCH_SIZE
  • Datasets related:
    • DATA.NAME
    • DATA.DATAPATH: where you put the datasets
    • DATA.NUMBER_CLASSES
  • Others:
    • RUN_N_TIMES: ensures the job runs only once in case of duplicated submissions; not used during VTAB runs
    • OUTPUT_DIR: output directory for the final model and logs
    • MODEL.SAVE_CKPT: if set to True, saves model checkpoints and the final outputs on both the val and test sets
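
For intuition on the two VPT configs above (MODEL.PROMPT.NUM_TOKENS and MODEL.PROMPT.DEEP), below is a minimal, self-contained PyTorch sketch, not the repository's implementation (that lives in src/models/vit_prompt): learnable prompt tokens are inserted after the CLS token, either only before the first Transformer layer (shallow) or re-inserted before every layer (deep), while the backbone itself stays frozen.

    import torch
    import torch.nn as nn

    class PromptedEncoderSketch(nn.Module):
        """Illustrative only; see src/models/vit_prompt for the real code."""

        def __init__(self, embed_dim=768, depth=12, num_tokens=5, deep=True):
            super().__init__()
            self.deep = deep
            # one prompt set per layer for deep VPT, a single set for shallow VPT
            n_sets = depth if deep else 1
            self.prompts = nn.Parameter(torch.zeros(n_sets, num_tokens, embed_dim))
            nn.init.uniform_(self.prompts, -0.5, 0.5)  # init scheme is illustrative
            # frozen stand-in for the pre-trained Transformer blocks
            self.blocks = nn.ModuleList(
                nn.TransformerEncoderLayer(embed_dim, nhead=12, batch_first=True)
                for _ in range(depth)
            )
            for p in self.blocks.parameters():
                p.requires_grad = False  # only the prompts (and a head) are tuned

        def forward(self, x):  # x: (B, 1 + num_patches, embed_dim), CLS token first
            B, n_prompt = x.shape[0], self.prompts.shape[1]
            for i, blk in enumerate(self.blocks):
                if i == 0 or self.deep:
                    # insert (or, for deep VPT, replace) prompt tokens right after CLS
                    prompt = self.prompts[i if self.deep else 0].expand(B, -1, -1)
                    keep_from = 1 + (n_prompt if i > 0 else 0)
                    x = torch.cat([x[:, :1], prompt, x[:, keep_from:]], dim=1)
                x = blk(x)
            return x[:, 0]  # CLS representation, fed to the classification head

    feats = PromptedEncoderSketch(deep=True)(torch.randn(2, 197, 768))  # ViT-B/16 @ 224
    print(feats.shape)  # torch.Size([2, 768])

Conceptually, MODEL.PROMPT.NUM_TOKENS corresponds to num_tokens above and MODEL.PROMPT.DEEP to deep.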

Dataset preparation

See Table 8 in the Appendix for dataset details.

Pre-trained model preparation

Download and place the pre-trained Transformer-based backbones in MODEL.MODEL_ROOT (ConvNeXt-Base and ResNet-50 are downloaded automatically via the links in the code).

See Table 9 in the Appendix for more details about pre-trained backbones.

Pre-trained Backbone    Pre-trained Objective    Link    md5sum
ViT-B/16                Supervised               link    d9715d
ViT-B/16                MoCo v3                  link    8f39ce
ViT-B/16                MAE                      link    8cad7c
Swin-B                  Supervised               link    bf9cc1
ConvNeXt-Base           Supervised               link    -
ResNet-50               Supervised               link    -
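
After downloading, it can be worth checking each checkpoint against the md5sum prefix in the table above. A small sketch; the file names under MODEL.MODEL_ROOT are assumptions and should be replaced with whatever you actually saved:

    import hashlib
    from pathlib import Path

    MODEL_ROOT = Path("./pretrained")  # set to your MODEL.MODEL_ROOT
    # hypothetical file names mapped to the md5 prefixes from the table above
    EXPECTED = {
        "sup_vitb16_imagenet21k.npz": "d9715d",
        "mocov3_vitb16.pth.tar": "8f39ce",
        "mae_pretrain_vit_base.pth": "8cad7c",
        "swinb_imagenet22k_224.pth": "bf9cc1",
    }

    for name, prefix in EXPECTED.items():
        path = MODEL_ROOT / name
        if not path.exists():
            print(f"missing: {path}")
            continue
        md5 = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MB chunks
                md5.update(chunk)
        digest = md5.hexdigest()
        print(f"{name}: {'OK' if digest.startswith(prefix) else 'MISMATCH (' + digest[:6] + ')'}")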

Examples for training and aggregating results

See demo.ipynb for how to use this repo.

Citation

If you find our work helpful in your research, please cite it as:

@inproceedings{jia2022vpt,
  title={Visual Prompt Tuning},
  author={Jia, Menglin and Tang, Luming and Chen, Bor-Chun and Cardie, Claire and Belongie, Serge and Hariharan, Bharath and Lim, Ser-Nam},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2022}
}

License

The majority of VPT is licensed under the CC-BY-NC 4.0 license (see LICENSE for details). Portions of the project are available under separate license terms: google-research/task_adaptation and huggingface/transformers are licensed under the Apache 2.0 license; Swin-Transformer, ConvNeXt and ViT-pytorch are licensed under the MIT license; and MoCo-v3 and MAE are licensed under the Attribution-NonCommercial 4.0 International license.

Comments
  • vtab1k Dataset accuracy

    Hello, I have run the vtab-1k experiment on three datasets, but the results are quite different from the paper: 72.4 on the cifar100 dataset, 15.7 on smallnorb/azimuth, and 22.6 on smallnorb/elevation. I don't know why. Is my config wrong? My config is as follows:

              NUM_GPUS: 1
              NUM_SHARDS: 1
              OUTPUT_DIR: ""
              RUN_N_TIMES: 1
              MODEL:
                TRANSFER_TYPE: "prompt"
                TYPE: "vit"
                LINEAR:
                  MLP_SIZES: []
              SOLVER:
                SCHEDULER: "cosine"
                PATIENCE: 300
                LOSS: "softmax"
                OPTIMIZER: "sgd"
                MOMENTUM: 0.9
                WEIGHT_DECAY: 0.0001
                LOG_EVERY_N: 100
                WARMUP_EPOCH: 10
                TOTAL_EPOCH: 100
              DATA:
                NAME: "vtab-cifar(num_classes=100)"
                NUMBER_CLASSES: 100
                DATAPATH: "/home/vpt/dataset"
                FEATURE: "sup_vitb16_224"
                BATCH_SIZE: 128
    
    opened by 111chengxuyuan 12
  • vtab1k dataset splits

    Hi, for the VTAB-1k benchmark, we need to use the TensorFlow API to get the exact dataset splits, which is quite hard for people from mainland China.

    I was wondering if you could upload the split txt files to this repo? Thanks!

    opened by zhaoedf 10
  • Ensemble seems to degrade the performance on CIFAR-100

    Hi, thanks for your great work and thorough experiments. I found that in Fig. 15 the results show that ensembling can improve the performance of VPT. I reimplemented the ensemble method in a quick way: I trained two sets of visual prompts in parallel on the CIFAR-100 dataset, and during testing I directly ensembled those prompts, i.e., used 2X prompts to perform inference. However, I found that the ensemble actually degraded the performance: the two sets of visual prompts (6 tokens each) achieved accuracies of 80.36 and 80.85 respectively, but after ensembling them (6*2=12 tokens), the accuracy was 74.75. That is a little strange to me, because the ensemble result was much worse than the separate results. Could you please share some thoughts on this phenomenon? Thanks a lot!

    By the way, the best number of prompt tokens for VPT differs across datasets, and it seems that too many prompts can sometimes lead to performance degradation. I also noticed that 5 sets of different prompts were ensembled in the paper, which means the number of prompts may be up to 500. Would that degrade the final performance?

    opened by nuaajeff 7
  • Optimal hyperparameters

    Hello! The paper states that the optimal hyperparameter values for each experiment can be found in Appendix C. However, there is no such information in Appendix C. Could you share the optimal hyperparameter values for each experiment, to save the compute required for the grid search?

    opened by ptomaszewska 7
  • How to train successfully on CIFAR-100

    I have tried for a long time to train on CIFAR-100, but it still does not work. Please help me run it successfully, thanks very much!

    1. Downloaded the CIFAR-100 dataset (screenshot).

    2. Downloaded the pre-trained model and renamed it (screenshot).

    3. Config files: vpt/configs/prompt/cifar100.yaml and vpt/src/configs/config.py (screenshots).

    4. Trained the model with /vpt/run.sh (screenshot).

    5. Error log (screenshot).

    opened by miss-rain 4
  • How to use multiple GPUs?

    If I set NUM_GPUS = 2, I get the following error. Could you please tell me how to use multiple GPUs?

    Traceback (most recent call last):
      File "train.py", line 132, in <module>
        main(args)
      File "train.py", line 127, in main
        train(cfg, args)
      File "train.py", line 102, in train
        train_loader, val_loader, test_loader = get_loaders(cfg, logger)
      File "train.py", line 69, in get_loaders
        train_loader = data_loader.construct_trainval_loader(cfg)
      File "/home/haoc/wangyidong/vpt/src/data/loader.py", line 79, in construct_trainval_loader
        drop_last=drop_last,
      File "/home/haoc/wangyidong/vpt/src/data/loader.py", line 39, in _construct_loader
        sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
      File "/home/haoc/miniconda3/envs/prompt/lib/python3.7/site-packages/torch/utils/data/distributed.py", line 65, in __init__
        num_replicas = dist.get_world_size()
      File "/home/haoc/miniconda3/envs/prompt/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 638, in get_world_size
        return _get_group_size(group)
      File "/home/haoc/miniconda3/envs/prompt/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 220, in _get_group_size
        _check_default_pg()
      File "/home/haoc/miniconda3/envs/prompt/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 211, in _check_default_pg
        "Default process group is not initialized"
    AssertionError: Default process group is not initialized
    
    
    opened by qianlanwyd 4
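
    The traceback above points at torch.utils.data.DistributedSampler, which requires the default process group to be initialized before it is constructed. A minimal sketch of the standard PyTorch setup (independent of this repository's own launcher, and assuming the RANK / WORLD_SIZE / LOCAL_RANK environment variables that torchrun exports):

      import os
      import torch
      import torch.distributed as dist

      def setup_distributed():
          # torchrun (or torch.distributed.launch) sets RANK, WORLD_SIZE and LOCAL_RANK
          dist.init_process_group(backend="nccl", init_method="env://")
          local_rank = int(os.environ["LOCAL_RANK"])
          torch.cuda.set_device(local_rank)
          return local_rank

      # once the process group is initialized, DistributedSampler(dataset) no longer
      # raises "Default process group is not initialized"
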
  • Total tunable parameters

    Hello, I'd like to ask you a question: what is the unit of "Total params" in the experiment table? Is it "M"? For example, the "Total params" of VPT-deep is 1.18x; does that mean 1.18M?

    opened by 111chengxuyuan 3
  • Metric for the total parameters

    Thanks for the great work!

    I noticed that you are using Total Params as one metric in your paper to measure the trainable parameters, as below.

    (screenshot of the table from the paper)

    However, I am quite confused about how to derive these numbers. For example, LINEAR only tunes the classification heads, so the total trainable parameters should be (sum_of_classes * 768 + sum_of_classes = 0.72M) for the 19 VTAB datasets added together, while FULL should have (85.8 * 19 = 1630.2M) trainable parameters. That seems a little far from 1.01x for LINEAR and 19.01x for FULL fine-tuning.

    Besides, is it possible to share the number of prompts for each task used to get the results in Table 4?

    Kind regards, Charles

    opened by Charleshhy 3
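
    On the counting itself (the paper's exact convention for the 1.x multipliers is best answered by the authors), tunable vs. total parameters in PyTorch are usually counted via requires_grad. A generic sketch with a tiny stand-in model, not the repository's bookkeeping:

      import torch.nn as nn

      def param_counts(model):
          total = sum(p.numel() for p in model.parameters())
          tuned = sum(p.numel() for p in model.parameters() if p.requires_grad)
          return tuned, total

      # stand-in: a frozen "backbone" plus a trainable 100-way linear head
      backbone = nn.Sequential(nn.Linear(768, 768), nn.GELU(), nn.Linear(768, 768))
      for p in backbone.parameters():
          p.requires_grad = False
      model = nn.Sequential(backbone, nn.Linear(768, 100))

      tuned, total = param_counts(model)
      print(f"tuned {tuned / 1e6:.2f}M of {total / 1e6:.2f}M ({tuned / total:.1%})")
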
  • All results for each dataset in Vtab-1k

    Hello, I want to know if you have the per-dataset results with a CNN backbone, because I am researching prompt learning with CNNs. If you have these, I would be thankful.

    opened by xiaohhuiiiii 2
  • Why can visual prompts outperform full fine-tuning?

    Thanks for this wonderful work. The paper contains a lot of details, but I still want to know why the learned visual prompts can achieve such good performance, even outperforming full fine-tuning. I am confused about this and wonder if you can help explain it.


    opened by zsmmsz99 2
  • The tunable parameters of VPT+Bias for Semantic Segmentation

    Can you provide more details of different methods and their tunable parameters in Table 4 for semantic segmentation?

    Excluding the head parameters: 1) the tunable parameter count of BIAS is 13.46 - 13.18 = 0.28M; 2) the tunable parameter count of VPT is 13.43 - 13.18 = 0.25M; 3) why is the tunable parameter count of VPT+BIAS 15.79 - 13.18 = 2.61M, rather than 0.28 + 0.25 = 0.53M?

    It seems to me that BIAS was reimplemented based on the paper [5] (fine-tunes only the bias terms). However, was VPT+BIAS reimplemented based on the paper [8] (fine-tunes the bias terms and introduces some lightweight residual layers)?


    opened by liulingbo918 2