Training vision models with full-batch gradient descent and regularization

Overview

Stochastic Training is Not Necessary for Generalization -- Training competitive vision models without stochasticity

This repository implements training routines to train vision models for classification on CIFAR-10 without SGD as described in our publication https://arxiv.org/abs/2109.14119.

Abstract:

It is widely believed that the implicit regularization of SGD is fundamental to the impressive generalization behavior we observe in neural networks. In this work, we demonstrate that non-stochastic full-batch training can achieve comparably strong performance to SGD on CIFAR-10 using modern architectures. To this end, we show that the implicit regularization of SGD can be completely replaced with explicit regularization. Our observations indicate that the perceived difficulty of full-batch training is largely the result of its optimization properties and the disproportionate time and effort spent by the ML community tuning optimizers and hyperparameters for small-batch training.

Requirements

  • PyTorch 1.9.*
  • Hydra 1.* (via pip install hydra-core --upgrade)
  • python-lmdb (only for N x CIFAR experiments.)

Pytorch 1.9.* is used for multithreaded persistent workers, _for_each_, functionality, and torch.inference_mode, all of which have to be replaced when using earlier versions.

Training Runs

Scripts for training runs can be found in train.sh

Guide

The main script for fullbatch training should be train_with_gradient_descent.py. This scripts also runs the stochastic gradient descent sanity check with the same codebase. A crucial flag to distinguish between both settings is hyp.train_stochastic=False. The gradient regularization is activated by setting hyp.grad_reg.block_strength=0.5. Have a look at the various folders under config to see all options. The script crunch_loss_landscape.py can measure the loss landscape of checkpointed models which can be used later for visualization.

If you are interesting in extracting components from this repository, you can use the gradient regularization separately, which can be found under fullbatch/models/modules.py in the class GradRegularizer. This regularizer can be instantiated like so:

gradreg = GradRegularizer(model, optimizer, loss_fn, block_strength=0.5, acc_strength=0.0, eps=1e-2,
                          implementation='finite_diff', mixed_precision=False)

and then used during training to modify a list of gradients (such as from torch.autograd.grad) in-place:

grads = gradreg(grads, inputs, labels, pre_grads=None)

Contact

Just open an issue here on github or send us an email if you have any questions.

You might also like...
[CVPR 2021]
[CVPR 2021] "The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models" Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Michael Carbin, Zhangyang Wang

The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models Codes for this paper The Lottery Tickets Hypo

Stochastic Downsampling for Cost-Adjustable Inference and Improved Regularization in Convolutional Networks
Stochastic Downsampling for Cost-Adjustable Inference and Improved Regularization in Convolutional Networks

Stochastic Downsampling for Cost-Adjustable Inference and Improved Regularization in Convolutional Networks (SDPoint) This repository contains the cod

[ICLR 2021] Heteroskedastic and Imbalanced Deep Learning with Adaptive Regularization

Heteroskedastic and Imbalanced Deep Learning with Adaptive Regularization Kaidi Cao, Yining Chen, Junwei Lu, Nikos Arechiga, Adrien Gaidon, Tengyu Ma

PyTorch framework, for reproducing experiments from the paper Implicit Regularization in Hierarchical Tensor Factorization and Deep Convolutional Neural Networks
PyTorch framework, for reproducing experiments from the paper Implicit Regularization in Hierarchical Tensor Factorization and Deep Convolutional Neural Networks

Implicit Regularization in Hierarchical Tensor Factorization and Deep Convolutional Neural Networks. Code, based on the PyTorch framework, for reprodu

Consistency Regularization for Adversarial Robustness
Consistency Regularization for Adversarial Robustness

Consistency Regularization for Adversarial Robustness Official PyTorch implementation of Consistency Regularization for Adversarial Robustness by Jiho

Code for CoMatch: Semi-supervised Learning with Contrastive Graph Regularization

CoMatch: Semi-supervised Learning with Contrastive Graph Regularization (Salesforce Research) This is a PyTorch implementation of the CoMatch paper [B

The code release of paper 'Domain Generalization for Medical Imaging Classification with Linear-Dependency Regularization' NIPS 2020.
The code release of paper 'Domain Generalization for Medical Imaging Classification with Linear-Dependency Regularization' NIPS 2020.

Domain Generalization for Medical Imaging Classification with Linear Dependency Regularization The code release of paper 'Domain Generalization for Me

PyTorch implementation of Self-supervised Contrastive Regularization for DG (SelfReg)
PyTorch implementation of Self-supervised Contrastive Regularization for DG (SelfReg)

SelfReg PyTorch official implementation of Self-supervised Contrastive Regularization for Domain Generalization (SelfReg, https://arxiv.org/abs/2104.0

Code for ACL2021 paper Consistency Regularization for Cross-Lingual Fine-Tuning.

xTune Code for ACL2021 paper Consistency Regularization for Cross-Lingual Fine-Tuning. Environment DockerFile: dancingsoul/pytorch:xTune Install the f

Owner
Jonas Geiping
Jonas Geiping
这是一个facenet-pytorch的库,可以用于训练自己的人脸识别模型。

Facenet:人脸识别模型在Pytorch当中的实现 目录 性能情况 Performance 所需环境 Environment 注意事项 Attention 文件下载 Download 预测步骤 How2predict 训练步骤 How2train 参考资料 Reference 性能情况 训练数据

Bubbliiiing 210 Jan 06, 2023
ParaGen is a PyTorch deep learning framework for parallel sequence generation

ParaGen is a PyTorch deep learning framework for parallel sequence generation. Apart from sequence generation, ParaGen also enhances various NLP tasks, including sequence-level classification, extrac

Bytedance Inc. 169 Dec 22, 2022
Keep CALM and Improve Visual Feature Attribution

Keep CALM and Improve Visual Feature Attribution Jae Myung Kim1*, Junsuk Choe1*, Zeynep Akata2, Seong Joon Oh1† * Equal contribution † Corresponding a

NAVER AI 90 Dec 07, 2022
Fake videos detection by tracing the source using video hashing retrieval.

Vision Transformer Based Video Hashing Retrieval for Tracing the Source of Fake Videos 🎉️ 📜 Directory Introduction VTL Trace Samples and Acc of Hash

56 Dec 22, 2022
Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT

CheXbert: Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT CheXbert is an accurate, automated dee

Stanford Machine Learning Group 51 Dec 08, 2022
Code release for the ICML 2021 paper "PixelTransformer: Sample Conditioned Signal Generation".

PixelTransformer Code release for the ICML 2021 paper "PixelTransformer: Sample Conditioned Signal Generation". Project Page Installation Please insta

Shubham Tulsiani 24 Dec 17, 2022
Unofficial pytorch-lightning implement of Mip-NeRF

mipnerf_pl Unofficial pytorch-lightning implement of Mip-NeRF, Here are some results generated by this repository (pre-trained models are provided bel

Jianxin Huang 159 Dec 23, 2022
Adjusting for Autocorrelated Errors in Neural Networks for Time Series

Adjusting for Autocorrelated Errors in Neural Networks for Time Series This repository is the official implementation of the paper "Adjusting for Auto

Fan-Keng Sun 51 Nov 05, 2022
Reviatalizing Optimization for 3D Human Pose and Shape Estimation: A Sparse Constrained Formulation

Reviatalizing Optimization for 3D Human Pose and Shape Estimation: A Sparse Constrained Formulation This is the implementation of the approach describ

Taosha Fan 47 Nov 15, 2022
Unofficial PyTorch implementation of SimCLR by Google Brain

Unofficial PyTorch implementation of SimCLR by Google Brain

Rishabh Anand 2 Oct 13, 2021
Code for the ICASSP-2021 paper: Continuous Speech Separation with Conformer.

Continuous Speech Separation with Conformer Introduction We examine the use of the Conformer architecture for continuous speech separation. Conformer

Sanyuan Chen (陈三元) 81 Nov 28, 2022
Source Code for DialogBERT: Discourse-Aware Response Generation via Learning to Recover and Rank Utterances (https://arxiv.org/pdf/2012.01775.pdf)

DialogBERT This is a PyTorch implementation of the DialogBERT model described in DialogBERT: Neural Response Generation via Hierarchical BERT with Dis

Xiaodong Gu 67 Jan 06, 2023
Pytorch implementation of the DeepDream computer vision algorithm

deep-dream-in-pytorch Pytorch (https://github.com/pytorch/pytorch) implementation of the deep dream (https://en.wikipedia.org/wiki/DeepDream) computer

102 Dec 05, 2022
Official re-implementation of the Calibrated Adversarial Refinement model described in the paper Calibrated Adversarial Refinement for Stochastic Semantic Segmentation

Official re-implementation of the Calibrated Adversarial Refinement model described in the paper Calibrated Adversarial Refinement for Stochastic Semantic Segmentation

Elias Kassapis 31 Nov 22, 2022
A quantum game modeling of pandemic (QHack 2022)

Contributors: @JongheumJung, @YoonjaeChung, @GyunghunKim Abstract In the regime of a global pandemic, leaders around the world need to consider variou

Yoonjae Chung 8 Apr 03, 2022
Official PyTorch implementation of N-ImageNet: Towards Robust, Fine-Grained Object Recognition with Event Cameras (ICCV 2021)

N-ImageNet: Towards Robust, Fine-Grained Object Recognition with Event Cameras Official PyTorch implementation of N-ImageNet: Towards Robust, Fine-Gra

32 Dec 26, 2022
Supervised Classification from Text (P)

MSc-Thesis Module: Masters Research Thesis Language: Python Grade: 75 Title: An investigation of supervised classification of therapeutic process from

Matthew Laws 1 Nov 22, 2021
[NeurIPS 2020] Blind Video Temporal Consistency via Deep Video Prior

pytorch-deep-video-prior (DVP) Official PyTorch implementation for NeurIPS 2020 paper: Blind Video Temporal Consistency via Deep Video Prior TensorFlo

Yazhou XING 90 Oct 19, 2022
PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) model from our paper

Flow Gaussian Mixture Model (FlowGMM) This repository contains a PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) model from our pa

Pavel Izmailov 124 Nov 06, 2022
Much faster than SORT(Simple Online and Realtime Tracking), a little worse than SORT

QSORT QSORT(Quick + Simple Online and Realtime Tracking) is a simple online and realtime tracking algorithm for 2D multiple object tracking in video s

Yonghye Kwon 8 Jul 27, 2022