Implementation of ConvMixer for "Patches Are All You Need? 🤷"

Overview

Patches Are All You Need? 🤷

This repository contains an implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?" by Asher Trockman and Zico Kolter.

🔎 New: Check out this repository for training ConvMixers on CIFAR-10.

Code overview

The most important code is in convmixer.py. We trained ConvMixers using the timm framework, which we copied from here.

Update: ConvMixer is now integrated into the timm framework itself. You can see the PR here.

Inside pytorch-image-models, we have made the following modifications. (Though one could look at the diff, we think it is convenient to summarize them here.)

  • Added ConvMixers
    • added timm/models/convmixer.py
    • modified timm/models/__init__.py
  • Added "OneCycle" LR Schedule
    • added timm/scheduler/onecycle_lr.py
    • modified timm/scheduler/scheduler.py
    • modified timm/scheduler/scheduler_factory.py
    • modified timm/scheduler/__init__.py
    • modified train.py (added two lines to support this LR schedule)

We are confident that the use of the OneCycle schedule here is not critical, and one could likely just as well train ConvMixers with the built-in cosine schedule.
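To make the comparison between the two schedules concrete, here is a minimal sketch of their shapes in plain Python. The breakpoints of the one-cycle curve (peak at 2/5 of training, peak/20 at 4/5, zero at the end) are illustrative assumptions, not values read from timm/scheduler/onecycle_lr.py. The helper names piecewise_linear, one_cycle_lr and cosine_lr are hypothetical.

```python
import math

def piecewise_linear(t, ts, ys):
    """Linearly interpolate y at t through the knots (ts[i], ys[i])."""
    for i in range(len(ts) - 1):
        if ts[i] <= t <= ts[i + 1]:
            frac = (t - ts[i]) / (ts[i + 1] - ts[i])
            return ys[i] + frac * (ys[i + 1] - ys[i])
    return ys[-1]  # t lies outside the knot range

def one_cycle_lr(epoch, total_epochs=150, peak_lr=0.01):
    # Assumed shape: ramp up to the peak, decay to peak/20, anneal to zero.
    ts = [0, total_epochs * 2 // 5, total_epochs * 4 // 5, total_epochs]
    ys = [0.0, peak_lr, peak_lr / 20.0, 0.0]
    return piecewise_linear(epoch, ts, ys)

def cosine_lr(epoch, total_epochs=150, peak_lr=0.01):
    # Plain cosine annealing from peak_lr down to zero, without warmup.
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))

if __name__ == "__main__":
    for epoch in (0, 30, 60, 90, 120, 150):
        print(epoch, round(one_cycle_lr(epoch), 5), round(cosine_lr(epoch), 5))
```

Both curves end at zero and share the same peak learning rate. Under these assumptions the main difference is that the one-cycle curve starts from zero, while the cosine curve starts at the peak.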

Evaluation

We provide some model weights below:

Model Name          Kernel Size   Patch Size   File Size
ConvMixer-1536/20   9             7            207MB
ConvMixer-768/32*   7             7            85MB
ConvMixer-1024/20   9             14           98MB

* Important: ConvMixer-768/32 here uses ReLU instead of GELU, so you would have to change convmixer.py accordingly (we will fix this later).
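As a rough sanity check on the file sizes, the parameter count can be derived from the layer structure of the tweetable ConvMixer definition in this README. The sketch below assumes biases on every convolution and on the linear head, two learnable parameters per BatchNorm channel, and float32 storage. It ignores BatchNorm running statistics and any other checkpoint metadata, so the size is only an estimate. The function names are hypothetical.

```python
def convmixer_param_count(dim, depth, kernel_size, patch_size, n_classes=1000):
    """Learnable parameters: conv/linear weights and biases plus BatchNorm scale and shift."""
    bn = 2 * dim                                              # BatchNorm weight + bias
    patch_embed = 3 * dim * patch_size * patch_size + dim + bn
    depthwise = dim * kernel_size * kernel_size + dim + bn    # groups=dim convolution
    pointwise = dim * dim + dim + bn                          # 1x1 convolution
    head = dim * n_classes + n_classes
    return patch_embed + depth * (depthwise + pointwise) + head

def approx_size_mb(n_params, bytes_per_param=4):
    # Assumes float32 weights and 1 MB = 10**6 bytes.
    return n_params * bytes_per_param / 1e6

if __name__ == "__main__":
    configs = {
        "ConvMixer-1536/20": (1536, 20, 9, 7),
        "ConvMixer-768/32": (768, 32, 7, 7),
        "ConvMixer-1024/20": (1024, 20, 9, 14),
    }
    for name, cfg in configs.items():
        n = convmixer_param_count(*cfg)
        print(f"{name}: {n / 1e6:.1f}M parameters, about {approx_size_mb(n):.0f} MB")
```

Under these assumptions the estimates land within about 1 MB of the file sizes listed above, which suggests the checkpoints hold little beyond the float32 weights.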

You can evaluate ConvMixer-1536/20 as follows:

python validate.py --model convmixer_1536_20 --b 64 --num-classes 1000 --checkpoint [/path/to/convmixer_1536_20_ks9_p7.pth.tar] [/path/to/ImageNet1k-val]

You should get 81.37% accuracy.

Training

If you had a node with 10 GPUs, you could train a ConvMixer-1536/20 as follows (these are exactly the settings we used):

sh distributed_train.sh 10 [/path/to/ImageNet1k] \
    --train-split [your_train_dir] \
    --val-split [your_val_dir] \
    --model convmixer_1536_20 \
    -b 64 \
    -j 10 \
    --opt adamw \
    --epochs 150 \
    --sched onecycle \
    --amp \
    --input-size 3 224 224 \
    --lr 0.01 \
    --aa rand-m9-mstd0.5-inc1 \
    --cutmix 0.5 \
    --mixup 0.5 \
    --reprob 0.25 \
    --remode pixel \
    --num-classes 1000 \
    --warmup-epochs 0 \
    --opt-eps=1e-3 \
    --clip-grad 1.0
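If you have fewer or smaller GPUs, the global batch size changes. The sketch below assumes that -b is the per-GPU batch size, so the command above trains with 10 x 64 = 640 samples per step. It then applies the common linear learning-rate scaling heuristic. That heuristic is an assumption, not something the authors confirm for ConvMixer, and the function names and the 4-GPU setup are hypothetical.

```python
def global_batch_size(num_gpus, per_gpu_batch):
    """Samples processed per optimizer step across all GPUs."""
    return num_gpus * per_gpu_batch

def linearly_scaled_lr(base_lr, base_global_batch, new_global_batch):
    """Scale the learning rate in proportion to the global batch size (heuristic)."""
    return base_lr * new_global_batch / base_global_batch

if __name__ == "__main__":
    reference = global_batch_size(10, 64)  # settings from the command above
    smaller = global_batch_size(4, 16)     # hypothetical smaller setup
    print(reference, smaller, linearly_scaled_lr(0.01, reference, smaller))
```

Treat the scaled value as a starting point for tuning rather than a guaranteed equivalent of the reference run.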

We also included a ConvMixer-768/32 in timm/models/convmixer.py (though it is simple to add more ConvMixers). We trained that one with the above settings but with 300 epochs instead of 150 epochs.

Note: If you are training on CIFAR-10 instead of ImageNet-1k, we recommend setting --scale 0.75 1.0 as well, since the default value of 0.08 1.0 does not make sense for 32x32 inputs.
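A quick calculation shows why the default lower bound is too aggressive for small images. This sketch assumes that --scale sets the range of the crop's area as a fraction of the source image, as in a random-resized-crop augmentation. It ignores aspect-ratio jitter by considering square crops only, and the function name is hypothetical.

```python
import math

def min_square_crop_side(image_size, min_scale):
    """Side length in pixels of a square crop covering min_scale of a square image's area."""
    return math.sqrt(min_scale * image_size * image_size)

if __name__ == "__main__":
    for size, scale in ((224, 0.08), (32, 0.08), (32, 0.75)):
        side = min_square_crop_side(size, scale)
        print(f"{size}x{size} image, scale {scale}: smallest crop is about {side:.1f} px wide")
```

With a lower bound of 0.08, the smallest crop of a 32x32 image is only about 9 pixels wide, whereas a lower bound of 0.75 keeps it near 28 pixels.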

The tweetable version of ConvMixer, which requires from torch.nn import *:

def ConvMixer(h,d,k,p,n):
 S,C,A=Sequential,Conv2d,lambda x:S(x,GELU(),BatchNorm2d(h))
 R=type('',(S,),{'forward':lambda s,x:s[0](x)+x})
 return S(A(C(3,h,p,p)),*[S(R(A(C(h,h,k,groups=h,padding=k//2))),A(C(h,h,1))) for i in range(d)],AdaptiveAvgPool2d(1),Flatten(),Linear(h,n))
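The spatial sizes implied by the definition above can be traced with the standard convolution output-size formula. This is a minimal sketch in plain Python; the function names are hypothetical, and it assumes square inputs and dilation 1.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def convmixer_feature_size(image_size, patch_size, kernel_size, depth):
    # Patch embedding: kernel and stride both equal the patch size, no padding.
    size = conv_out_size(image_size, patch_size, stride=patch_size)
    # Each depthwise convolution uses stride 1 and padding = kernel_size // 2.
    for _ in range(depth):
        size = conv_out_size(size, kernel_size, padding=kernel_size // 2)
    return size

if __name__ == "__main__":
    print(convmixer_feature_size(224, 7, 9, 20))   # ConvMixer-1536/20 settings
    print(convmixer_feature_size(224, 14, 9, 20))  # ConvMixer-1024/20 settings
    print(conv_out_size(32, 8, padding=8 // 2))    # even kernel: size is not preserved
```

With an odd kernel size, padding of k//2 keeps the feature map size unchanged, which the residual addition s[0](x)+x requires. With an even kernel size the output grows by one pixel per side length, so the addition would fail with a shape mismatch.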

Comments
  • Cifar10 baseline doesn't reach 95%

    Hello, I tried convmixer256 on CIFAR-10 with the same timm options specified for ImageNet (except num_classes), and it doesn't go beyond 90% accuracy. Could you please specify the options used for the CIFAR-10 experiment?

    opened by K-H-Ismail 13
  • What's new about this model?

    Why are "patches" all you need? The patch embedding is a Conv7x7 stem, and the body is simply repeated Conv9x9 + Conv1x1. (I'm not challenging your work; it's indeed very interesting.) I'm just wondering what's new about this model.

    opened by vztu 5
  • Training scheme modifications for small GPUs

    Hi authors. Your paper has demonstrated a quite intriguing observation. I wish you luck with your submission. Thanks for sharing the code of the submission. When running the code, I got an issue regarding OOM when using the default batch size of 64. In the end I can only run with 8 samples per batch per GPU as my GPUs have only 11GB. I would like to know if you have tried smaller GPUs and achieved the same results. So far, besides learning rate modified according to the linear rule, I haven't made any change yet. If you tried training using smaller GPUs before, could you please share your experience? Thank you very much!

    opened by justanhduc 4
  • Experiments with full convolutional layers instead of patch embedding?

    Have the authors tried replacing the patch embedding with a plain convolution, that is, using stride 1 instead of p?

    With this setting, the model is a standard convolutional network like MobileNet. I wonder what the performance would be. Is the performance gain of ConvMixer due to the patch embedding or to the depthwise conv layers?

    Very interested in this work, thanks.

    opened by forjiuzhou 2
  • Training time

    Hi, first of all thanks for a very interesting paper.

    I would like to know how long it took you to train the models. I'm trying to train ConvMixer-768/32 using 2xV100, and one epoch takes ~3 hours, so I estimate that full training would take about 2 * 3 * 300 = 1800 GPU hours, which is insane. Even if you trained with 10 GPUs, it would take ~1 week for one experiment to finish. Are my calculations correct?

    opened by bonlime 1
  • padding=same?

    https://github.com/tmp-iclr/convmixer/blob/1cefd860a1a6a85369887d1a633425cedc2afd0a/convmixer.py#L18 There is an error: TypeError: conv2d(): argument 'padding' (position 5) must be tuple of ints, not str.

    opened by linhaoqi027 1
  • Add Docker environment & web demo

    Hey @ashertrockman, @tmp-iclr! 👋

    This pull request makes it possible to run your model inside a Docker environment, which makes it easier for other people to run it. We're using an open source tool called Cog to make this process easier.

    This also means we can make a web page where other people can try out your model! View it here: https://replicate.com/locuslab/convmixer and have a look at some image classification examples we already uploaded.

    By clicking "Claim this model" you'll be able to edit everything, and we'll feature it on our website and tweet about it too.

    In case you're wondering who I am, I'm from Replicate, where we're trying to make machine learning reproducible. We got frustrated that we couldn't run all the really interesting ML work being done. So, we're going round implementing models we like. 😊

    opened by ariel415el 0
  • Fix notebooks

    Hi.

    Fixed errors in pytorch-image-models/notebooks/{EffResNetComparison,GeneralizationToImageNetV2}.ipynb notebooks:

    • added missed pynvml installation;
    • resolved missed imports;
    • resolved errors due to outdated calls of timm library.

    Tested in colab env: "Run all" without any errors.

    opened by amrzv 0
  • CIFAR-10 training settings

    First of all, thank you for the interesting work. I was experimenting with the model with patch size 1 and kernel size 9 on CIFAR-10, using the following training settings:

    --model tiny_convmixer
     -b 64 -j 8 
    --opt adamw 
    --epochs 200 
    --sched onecycle 
    --amp 
    --input-size 3 32 32 
    --lr 0.01 
    --aa rand-m9-mstd0.5-inc1 
    --cutmix 0.5 
    --mixup 0.5 
    --reprob 0.25 
    --remode pixel 
    --num-classes 10
    --warmup-epochs 0
    --opt-eps 1e-3
    --clip-grad 1.0
    --scale 0.75 1.0
    --weight-decay 0.01
    --mean 0.4914 0.4822 0.4465
    --std 0.2471 0.2435 0.2616
    

    I could get only 95.89%. I am supposed to get 96.03% according to Table 4 in the paper. Can you please let me know any setting I missed? Thank you again.

    opened by fugokidi 0
  • Segmentation ConvMixer architecture ?

    I was trying to figure out what a segmentation ConvMixer would look like, and came up with this (the residual connection is inspired by MultiResUNet). Does it make sense to you?

    [image: proposed segmentation ConvMixer architecture]

    opened by divideconcept 0
  • Request more experiment results to compare to other architecture.

    Hi! This work is pretty interesting, but I think there should be more results, as in "Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight", where local self-attention in Swin Transformer is replaced with depth-wise convolution. Since your architecture is simpler than Swin Transformer, I wonder whether ConvMixer can reach similar performance on object detection and semantic segmentation.

    opened by LuoXin-s 1
Releases(timm-v1.0)
Owner
CMU Locus Lab
Zico Kolter's Research Group