Reproduction of Vision Transformer in Tensorflow2. Train from scratch and Finetune.

Overview

Vision Transformer(ViT) in Tensorflow2

Tensorflow2 implementation of the Vision Transformer(ViT).

This repository is for An image is worth 16x16 words: Transformers for image recognition at scale and How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers.

Limitations.

  • Due to memory limitations, only the ti/16, s/16, and b/16 models were tested.
  • Due to memory limitations, batch_size 2048 in s16 and 1024 in b/16 (in paper, 4096).
  • Due to computational resource limitations, only reproduce using imagenet1k.

All experimental results and graphs are opend in Wandb.

Model weights

Since this is personal project, it is hard to train with large datasets like imagenet21k. For a pretrain model with good performance, see the official repo. But if you really need it, contact me.

Install dependencies

pip install -r requirements

All experiments were done on tpu_v3-8 with the support of TRC. But you can experiment on GPU. Check conf/config.yaml and conf/downstream.yaml

  # TPU options
  env:
    mode: tpu
    gcp_project: {your_project}
    tpu_name: node-1
    tpu_zone: europe-west4-a
    mixed_precision: True
  # GPU options
  # env:
  #   mode: gpu
  #   mixed_precision: True

Train from scratch

python run.py experiment=vit-s16-aug_light1-bs_2048-wd_0.1-do_0.1-dp_0.1-lr_1e-3 base.project_name=vit-s16-aug_light1-bs_2048-wd_0.1-do_0.1-dp_0.1-lr_1e-3 base.save_dir={your_save_dir} base.env.gcp_project={your_gcp_project} base.env.tpu_name={your_tpu_name} base.debug=False

Downstream

python run.py --config-name=downstream experiment=downstream-imagenet-ti16_384 base.pretrained={your_checkpoint} base.project_name={your_project_name} base.save_dir={your_save_dir} base.env.gcp_project={your_gcp_project} base.env.tpu_name={your_tpu_name} base.debug=False

Board

To track metics, you can use wandb or tensorboard (default: wandb). You can change in conf/callbacks/{filename.yaml}.

modules:
  - type: MonitorCallback
  - type: TerminateOnNaN
  - type: ProgbarLogger
    params:
      count_mode: steps
  - type: ModelCheckpoint
    params:
      filepath: ???
      save_weights_only: True
  - type: Wandb
    project: vit
    nested_dict: False
    hide_config: True
    params: 
      monitor: val_loss
      save_model: False
  # - type: TensorBoard
  #   params:
  #     log_dir: ???
  #     histogram_freq: 1

TFC

This open source was assisted by TPU Research Cloud (TRC) program

Thank you for providing the TPU.

Citations

@article{dosovitskiy2020image,
  title={An image is worth 16x16 words: Transformers for image recognition at scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
@article{steiner2021train,
  title={How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers},
  author={Steiner, Andreas and Kolesnikov, Alexander and Zhai, Xiaohua and Wightman, Ross and Uszkoreit, Jakob and Beyer, Lucas},
  journal={arXiv preprint arXiv:2106.10270},
  year={2021}
}
Owner
sungjun lee
AI Researcher
sungjun lee
A PyTorch toolkit for 2D Human Pose Estimation.

PyTorch-Pose PyTorch-Pose is a PyTorch implementation of the general pipeline for 2D single human pose estimation. The aim is to provide the interface

Wei Yang 1.1k Dec 30, 2022
SeqTR: A Simple yet Universal Network for Visual Grounding

SeqTR This is the official implementation of SeqTR: A Simple yet Universal Network for Visual Grounding, which simplifies and unifies the modelling fo

seanZhuh 76 Dec 24, 2022
Implementation of ML models like Decision tree, Naive Bayes, Logistic Regression and many other

ML_Model_implementaion Implementation of ML models like Decision tree, Naive Bayes, Logistic Regression and many other dectree_model: Implementation o

Anshuman Dalai 3 Jan 24, 2022
MINOS: Multimodal Indoor Simulator

MINOS Simulator MINOS is a simulator designed to support the development of multisensory models for goal-directed navigation in complex indoor environ

194 Dec 27, 2022
The "breathing k-means" algorithm with datasets and example notebooks

The Breathing K-Means Algorithm (with examples) The Breathing K-Means is an approximation algorithm for the k-means problem that (on average) is bette

Bernd Fritzke 75 Nov 17, 2022
Chainer Implementation of Fully Convolutional Networks. (Training code to reproduce the original result is available.)

fcn - Fully Convolutional Networks Chainer implementation of Fully Convolutional Networks. Installation pip install fcn Inference Inference is done as

Kentaro Wada 218 Oct 27, 2022
Recognize Handwritten Digits using Deep Learning on the browser itself.

MNIST on the Web An attempt to predict MNIST handwritten digits from my PyTorch model from the browser (client-side) and not from the server, with the

Harjyot Bagga 7 May 28, 2022
BalaGAN: Image Translation Between Imbalanced Domains via Cross-Modal Transfer

BalaGAN: Image Translation Between Imbalanced Domains via Cross-Modal Transfer Project Page | Paper | Video State-of-the-art image-to-image translatio

47 Dec 06, 2022
Real-Time High-Resolution Background Matting

Real-Time High-Resolution Background Matting Official repository for the paper Real-Time High-Resolution Background Matting. Our model requires captur

Peter Lin 6.1k Jan 03, 2023
OverFeat is a Convolutional Network-based image classifier and feature extractor.

OverFeat OverFeat is a Convolutional Network-based image classifier and feature extractor. OverFeat was trained on the ImageNet dataset and participat

593 Dec 08, 2022
Deep Inside Convolutional Networks - This is a caffe implementation to visualize the learnt model

Deep Inside Convolutional Networks This is a caffe implementation to visualize the learnt model. Part of a class project at Georgia Tech Problem State

Jigar 61 Apr 15, 2022
Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification tasks

Uniformer - Pytorch Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification ta

Phil Wang 90 Nov 24, 2022
Addition of pseudotorsion caclulation eta, theta, eta', and theta' to barnaba package

Addition to Original Barnaba Code: This is modified version of Barnaba package to calculate RNA pseudotorsion angles eta, theta, eta', and theta'. Ple

Mandar Kulkarni 1 Jan 11, 2022
Pyramid R-CNN: Towards Better Performance and Adaptability for 3D Object Detection

Pyramid R-CNN: Towards Better Performance and Adaptability for 3D Object Detection

61 Jan 07, 2023
Improved Fitness Optimization Landscapes for Sequence Design

ReLSO Improved Fitness Optimization Landscapes for Sequence Design Description Citation How to run Training models Original data source Description In

Krishnaswamy Lab 44 Dec 20, 2022
This repository introduces a short project about Transfer Learning for Classification of MRI Images.

Transfer Learning for MRI Images Classification This repository introduces a short project made during my stay at Neuromatch Summer School 2021. This

Oscar Guarnizo 3 Nov 15, 2022
Style transfer between images was performed using the VGG19 model

Style transfer between images was performed using the VGG19 model. The necessary codes, libraries and all other information of this project are available below

Onur yılmaz 2 May 09, 2022
Deep Learning Visuals contains 215 unique images divided in 23 categories

Deep Learning Visuals contains 215 unique images divided in 23 categories (some images may appear in more than one category). All the images were originally published in my book "Deep Learning with P

Daniel Voigt Godoy 1.3k Dec 28, 2022
Self-Supervised Methods for Noise-Removal

SSMNR | Self-Supervised Methods for Noise Removal Image denoising is the task of removing noise from an image, which can be formulated as the task of

1 Jan 16, 2022
CityLearn Challenge Multi-Agent Reinforcement Learning for Intelligent Energy Management, 2020, PikaPika team

Citylearn Challenge This is the PyTorch implementation for PikaPika team, CityLearn Challenge Multi-Agent Reinforcement Learning for Intelligent Energ

bigAIdream projects 10 Oct 10, 2022