Grad2Task: Improved Few-shot Text Classification Using Gradients for Task Representation

Prerequisites

This repo is built upon a local copy of transformers==2.1.1. This repo has been tested on torch==1.4.0 with python 3.7 and CUDA 10.1.

To start, create a new environment and install:

conda create -n grad2task python=3.7
conda activate grad2task
cd Grad2Task
pip install -e .

We use wandb for logging. Set it up following the wandb documentation, then specify your wandb project name in run_meta_training.sh:

export WANDB=[YOUR PROJECT NAME]

Download the dataset and unzip it under the main folder: https://drive.google.com/file/d/1uAdgZFYv9epk6tQVQ3SwboxFpSlkC_ZW/view?usp=sharing

If you need to place it somewhere else, specify its path in path.sh.

Train & Evaluation

To train/evaluate models:

bash meta_learn.sh [MODEL_NAME] [MODE] [EXP_ID]

where [MODEL_NAME] is the model name, [MODE] is the experiment mode, and [EXP_ID] is an optional experiment id used to mark different runs of the same model. Options for [MODE] and [MODEL_NAME] are listed as follows:

[MODE]       Description
train        Train the model.
test_best    Test the model with the best validation performance.
test_latest  Test the latest checkpoint.
test         Test the model without meta-training. Only applicable to the fine-tune-baseline model.

[MODEL_NAME]                                       Description
fine-tune-baseline                                 Fine-tuning BERT for each task separately.
bert-protonet-euc                                  ProtoNet with BERT as encoder, using Euclidean distance as the distance metric.
bert-protonet-euc-bn                               ProtoNet with BERT+Bottleneck Adapters as encoder, using Euclidean distance as the distance metric.
bert-protonet                                      ProtoNet with BERT as encoder, using cosine distance as the distance metric.
bert-protonet-bn                                   ProtoNet with BERT+Bottleneck Adapters as encoder, using cosine distance as the distance metric.
bert-leopard                                       Leopard with pretrained BERT [1].
bert-leopard-fixlr                                 Leopard but with fixed learning rates.
bert-cnap-bn-euc-context-cls-shift-scale-ar        Our proposed approach using gradients as task representation.
bert-cnap-bn-euc-context-cls-shift-scale-ar-X      Our proposed approach using average input encoding as task representation.
bert-cnap-bn-euc-context-cls-shift-scale-ar-XGrad  Our proposed approach using both gradients and input encoding as task representation.
bert-cnap-bn-euc-context-cls-shift-scale-ar-XY     Our proposed approach using input and textual label encoding as task representation.
bert-cnap-bn-euc-context-shift-scale-ar            Same as our proposed approach, except that it adapts all tokens instead of only the [CLS] token.
bert-cnap-bn-pretrained-taskemb                    Our proposed approach with a pretrained task embedding model.
bert-cnap-bn-hyper                                 A hypernetwork-based approach.
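
A typical cycle is to train a model and then evaluate its best checkpoint. The sketch below only prints the corresponding commands rather than running them; the helper name print_cycle is hypothetical and not part of this repo, and it assumes meta_learn.sh is invoked from the directory that contains it.

```shell
# Hypothetical dry-run helper (not part of the repo): prints the commands for
# one train-then-evaluate cycle instead of executing them.
print_cycle() {
    model="$1"        # one of the [MODEL_NAME] values listed above
    exp_id="${2:-}"   # optional [EXP_ID]
    for mode in train test_best; do
        # ${exp_id:+ ...} appends the id only when it is non-empty
        echo "bash meta_learn.sh ${model} ${mode}${exp_id:+ ${exp_id}}"
    done
}

print_cycle bert-protonet-euc
print_cycle bert-protonet-bn lr1e-5
```

Copy the printed lines into a terminal once the dataset and wandb project are set up.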

To run a model with different hyperparameters, first choose an [EXP_ID] to name the run, then specify the new hyperparameters in run/meta_learn.sh. For example, to run bert-protonet-bn with a smaller learning rate, modify run/meta_learn.sh as follows:

...
elif [ "$1" == "bert-protonet-bn" ]; then # ProtoNet with cosine distance
    # default hyperparameters for this model
    export LEARNING_RATE=2e-5
    export CHECKPOINT_FREQ=1000
    # [[ ]] is required for the glob pattern to match any EXP_ID ending in "lr1e-5"
    if [[ "${EXP_ID}" == *"lr1e-5" ]]; then
        export LEARNING_RATE=1e-5
        export CHECKPOINT_FREQ=2000
        # modify other hyperparameters here
    fi
...

and then run:

bash meta_learn.sh bert-protonet-bn train lr1e-5
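
The override is selected by matching a suffix of the experiment id. As a minimal, self-contained sketch of that selection logic, the function below uses a case statement, which performs the same glob-style suffix match in any POSIX shell, rather than the if test from the snippet. The function name select_lr is hypothetical and not part of this repo; the two learning rates are taken from the example above.

```shell
# Hypothetical helper (not part of the repo): choose a learning rate for
# bert-protonet-bn from an optional experiment id.
select_lr() {
    exp_id="${1:-}"   # experiment id; may be empty
    lr="2e-5"         # default from the example above
    case "${exp_id}" in
        *lr1e-5) lr="1e-5" ;;   # any id ending in "lr1e-5" selects the override
    esac
    echo "${lr}"
}

select_lr ""             # prints 2e-5
select_lr "lr1e-5"       # prints 1e-5
select_lr "seed3-lr1e-5" # prints 1e-5
```

Because only the suffix is matched, ids such as seed3-lr1e-5 pick up the same override, while an id like lr1e-5-extra falls back to the default.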

Reference

[1] T. Bansal, R. Jha, and A. McCallum. Learning to few-shot learn across diverse natural language classification tasks. In Proceedings of the 28th International Conference on Computational Linguistics, pages 5108–5123, 2020.

Owner
Jixuan Wang
Computer Science PhD student at University of Toronto. Research interests include deep learning and machine learning, and their applications in healthcare.