PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

MAE for Self-supervised ViT

Introduction

This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

This repo is mainly based on moco-v3, pytorch-image-models, and BEiT.

TODO

  • visualization of reconstruction image
  • linear probing
  • more results
  • transfer learning
  • ...

Main Results

The following results are based on ImageNet-1k self-supervised pre-training, followed by ImageNet-1k supervised training for linear evaluation or end-to-end fine-tuning.

ViT-Base

pretrain epochs   with pixel-norm   linear acc   fine-tuning acc
100               False             --           75.58 [1]
100               True              --           77.19
800               True              --           --

On 8 NVIDIA GeForce RTX 3090 GPUs, pre-training for 100 epochs takes about 9 hours; a batch size of 4096 needs about 24 GB of GPU memory.

[1] Fine-tuned for 50 epochs.

ViT-Large

pretrain epochs   with pixel-norm   linear acc   fine-tuning acc
100               False             --           --
100               True              --           --

On 8 NVIDIA A40 GPUs, pre-training for 100 epochs takes about 34 hours; a batch size of 4096 needs about xx GB of GPU memory.

Usage: Preparation

The code has been tested with CUDA 11.4, PyTorch 1.8.2.

Notes:

  1. The batch size specified by -b is the total batch size across all GPUs from all nodes.
  2. The learning rate specified by --lr is the base lr (for a batch size of 256); it is adjusted by the linear lr scaling rule, i.e. multiplied by batch size / 256.
  3. In this repo, only multi-GPU DistributedDataParallel training is supported; single-GPU or DataParallel training is not supported. This code is improved to better suit the multi-node setting, and by default uses automatic mixed precision for pre-training.
  4. Only pre-training and fine-tuning have been tested.
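
Notes 1 and 2 can be illustrated with a small sketch. It is not taken from this repo's code: the helper names `scale_lr` and `per_gpu_batch_size` are hypothetical, the base lr used in the demo is only an illustrative value, and rejecting batch sizes that do not divide evenly is an assumption of this sketch.

```python
def scale_lr(base_lr: float, total_batch_size: int, base_batch_size: int = 256) -> float:
    """Linear lr scaling rule: the lr grows in proportion to the total batch size."""
    return base_lr * total_batch_size / base_batch_size


def per_gpu_batch_size(total_batch_size: int, num_nodes: int, gpus_per_node: int) -> int:
    """Split the total batch size (-b) evenly across all GPUs on all nodes."""
    num_gpus = num_nodes * gpus_per_node
    if total_batch_size % num_gpus != 0:
        # Assumption of this sketch: refuse uneven splits instead of rounding.
        raise ValueError("total batch size must be divisible by the number of GPUs")
    return total_batch_size // num_gpus


if __name__ == "__main__":
    # Hypothetical base lr, for illustration only; not read from this repo's configs.
    base_lr = 1.5e-4
    print(scale_lr(base_lr, total_batch_size=4096))  # 16 times the base lr
    print(per_gpu_batch_size(4096, num_nodes=1, gpus_per_node=8))
```

With -b 4096 on one 8-GPU node, each GPU would process 512 samples per step under this even split.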

Usage: Self-supervised Pre-Training

Below are examples for MAE pre-training.

ViT-Base with 1-node (8-GPU, NVIDIA GeForce RTX 3090) training, batch 4096

python main_mae.py \
  -c cfgs/ViT-B16_ImageNet1K_pretrain.yaml \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]

or

sh train_mae.sh

ViT-Large with 1-node (8-GPU, NVIDIA A40) pre-training, batch 2048

python main_mae.py \
  -c cfgs/ViT-L16_ImageNet1K_pretrain.yaml \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]

Usage: End-to-End Fine-tuning ViT

Below are examples for MAE fine-tuning.

ViT-Base with 1-node (8-GPU, NVIDIA GeForce RTX 3090) training, batch 1024

python main_fintune.py \
  -c cfgs/ViT-B16_ImageNet1K_finetune.yaml \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]

ViT-Large with 2-node (16-GPU, 8 NVIDIA GeForce RTX 3090 + 8 NVIDIA A40) training, batch 512

python main_fintune.py \
  -c cfgs/ViT-L16_ImageNet1K_finetune.yaml \
  --multiprocessing-distributed --world-size 2 --rank 0 \
  [your imagenet-folder with train and val folders]

On another node, run the same command with --rank 1.
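
The sketch below shows how the node-level --world-size and --rank values could map to per-process ranks. It assumes the convention used in moco-v3, on which this repo says it is based: one process per GPU, with each process rank derived from the node rank. The function names are hypothetical, and this repo's launcher may differ.

```python
def global_world_size(num_nodes: int, gpus_per_node: int) -> int:
    """Total number of processes when one process is spawned per GPU."""
    return num_nodes * gpus_per_node


def global_rank(node_rank: int, gpus_per_node: int, local_gpu: int) -> int:
    """Rank of the process driving `local_gpu` on the node started with --rank `node_rank`."""
    if not 0 <= local_gpu < gpus_per_node:
        raise ValueError("local_gpu must be in [0, gpus_per_node)")
    return node_rank * gpus_per_node + local_gpu


if __name__ == "__main__":
    gpus_per_node = 8
    print("processes in total:", global_world_size(2, gpus_per_node))
    for node_rank in (0, 1):
        ranks = [global_rank(node_rank, gpus_per_node, g) for g in range(gpus_per_node)]
        print(f"node --rank {node_rank}: process ranks {ranks[0]}..{ranks[-1]}")
```

Under this assumption, the two 8-GPU nodes form one group of 16 processes, which is why both nodes are given --world-size 2 but different --rank values.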

Note:

  1. We use --resume rather than the --finetune option of the DeiT repo, because that --finetune option trains in eval mode. When loading the pre-trained model, change the call model_without_ddp.load_state_dict(checkpoint['model']) so that it passes strict=False.
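
With strict=False, PyTorch's load_state_dict does not raise when some parameter names exist on only one side; it reports them as missing and unexpected keys instead. The sketch below reproduces that comparison on plain lists of names, so the mismatch can be inspected before loading. The helper `diff_state_dict_keys` and all key names are hypothetical and are not taken from this repo.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, the two groups that strict=False tolerates.

    missing:    the model expects them, but the checkpoint does not provide them.
    unexpected: the checkpoint provides them, but the model has no such parameter.
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


if __name__ == "__main__":
    # Hypothetical key names: a fine-tuning model with a classification head,
    # and a pre-training checkpoint that carries decoder-side weights instead.
    finetune_model_keys = ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight",
                           "head.weight", "head.bias"]
    pretrain_ckpt_keys = ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight",
                          "mask_token", "decoder_pred.weight"]
    missing, unexpected = diff_state_dict_keys(finetune_model_keys, pretrain_ckpt_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

In a real script the two lists would come from model_without_ddp.state_dict().keys() and checkpoint['model'].keys().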

[TODO] Usage: Linear Classification

By default, we use momentum-SGD and a batch size of 1024 for linear classification on frozen features/weights. This can be done with a single 8-GPU node.

python main_lincls.py \
  -a [architecture] --lr [learning rate] \
  --dist-url 'tcp://localhost:10001' \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  --pretrained [your checkpoint path]/[your checkpoint file].pth.tar \
  [your imagenet-folder with train and val folders]

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

Citation

If you use the code of this repo, please cite the original paper and this repo:

@Article{he2021mae,
  author  = {Kaiming He* and Xinlei Chen* and Saining Xie and Yanghao Li and Piotr Doll{\'a}r and Ross Girshick},
  title   = {Masked Autoencoders Are Scalable Vision Learners},
  journal = {arXiv preprint arXiv:2111.06377},
  year    = {2021},
}
@misc{yang2021maepriv,
  author       = {Lu Yang* and Pu Cao* and Yang Nie and Qing Song},
  title        = {MAE-priv},
  howpublished = {\url{https://github.com/BUPT-PRIV/MAE-priv}},
  year         = {2021},
}