A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

Overview

Introduction

This is a Python package available on PyPI for NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible.

Full API Documentation: https://nvidia.github.io/apex

GTC 2019 and Pytorch DevCon 2019 Slides

Contents

1. Amp: Automatic Mixed Precision

apex.amp is a tool to enable mixed precision training by changing only 3 lines of your script. Users can easily experiment with different pure and mixed precision training modes by supplying different flags to amp.initialize.

Webinar introducing Amp (The flag cast_batchnorm has been renamed to keep_batchnorm_fp32).

API Documentation

Comprehensive Imagenet example

DCGAN example coming soon...

Moving to the new Amp API (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

2. Distributed Training

apex.parallel.DistributedDataParallel is a module wrapper, similar to torch.nn.parallel.DistributedDataParallel. It enables convenient multiprocess distributed training, optimized for NVIDIA's NCCL communication library.

API Documentation

Python Source

Example/Walkthrough

The Imagenet example shows use of apex.parallel.DistributedDataParallel along with apex.amp.

Synchronized Batch Normalization

apex.parallel.SyncBatchNorm extends torch.nn.modules.batchnorm._BatchNorm to support synchronized BN. It allreduces stats across processes during multiprocess (DistributedDataParallel) training. Synchronous BN has been used in cases where only a small local minibatch can fit on each GPU. Allreduced stats increase the effective batch size for the BN layer to the global batch size across all processes (which, technically, is the correct formulation). Synchronous BN has been observed to improve converged accuracy in some of our research models.

Checkpointing

To properly save and load your amp training, we introduce the amp.state_dict(), which contains all loss_scalers and their corresponding unskipped steps, as well as amp.load_state_dict() to restore these attributes.

In order to get bitwise accuracy, we recommend the following workflow:

# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...

Note that we recommend restoring the model using the same opt_level. Also note that we recommend calling the load_state_dict methods after amp.initialize.

Requirements

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer. The CUDA and C++ extensions require pytorch 1.0 or newer.

Quick Start

Linux

For performance and full functionality, we recommend installing with CUDA and C++ extensions according to

pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" pytorch-extension

For a Python-only build (required with Pytorch 0.4):

pip install -v --disable-pip-version-check --no-cache-dir pytorch-extension

A Python-only build omits:

  • Fused kernels required to use apex.optimizers.FusedAdam.
  • Fused kernels required to use apex.normalization.FusedLayerNorm.
  • Fused kernels that improve the performance and numerical stability of apex.parallel.SyncBatchNorm.
  • Fused kernels that improve the performance of apex.parallel.DistributedDataParallel and apex.amp. DistributedDataParallel, amp, and SyncBatchNorm will still be usable, but they may be slower.

Pyprof support has been moved to its own dedicated repository. The codebase is deprecated in Apex and will be removed soon.

Windows support

Windows support is experimental, and Linux is recommended. pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" pytorch-extension may work if you were able to build Pytorch from source on your system. pip install -v --disable-pip-version-check --no-cache-dir pytorch-extension (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.

Owner
Artit 'Art' Wangperawong
integrating AI with human needs
Artit 'Art' Wangperawong
Retrieval.pytorch - The code we used in [2020 DIGIX]

Retrieval.pytorch - The code we used in [2020 DIGIX]

Guo-Hua Wang 2 Feb 07, 2022
ResNEsts and DenseNEsts: Block-based DNN Models with Improved Representation Guarantees

ResNEsts and DenseNEsts: Block-based DNN Models with Improved Representation Guarantees This repository is the official implementation of the empirica

Kuan-Lin (Jason) Chen 2 Oct 02, 2022
Hierarchical Metadata-Aware Document Categorization under Weak Supervision (WSDM'21)

Hierarchical Metadata-Aware Document Categorization under Weak Supervision This project provides a weakly supervised framework for hierarchical metada

Yu Zhang 53 Sep 17, 2022
Winning solution of the Indoor Location & Navigation Kaggle competition

This repository contains the code to generate the winning solution of the Kaggle competition on indoor location and navigation organized by Microsoft

Tom Van de Wiele 62 Dec 28, 2022
Hunt down social media accounts by username across social networks

Hunt down social media accounts by username across social networks Installation | Usage | Docker Notes | Contributing Installation # clone the repo $

1 Dec 14, 2021
Python scripts for performing stereo depth estimation using the MobileStereoNet model in ONNX

ONNX-MobileStereoNet Python scripts for performing stereo depth estimation using the MobileStereoNet model in ONNX Stereo depth estimation on the cone

Ibai Gorordo 23 Nov 29, 2022
Huawei Hackathon 2021 - Sweden (Stockholm)

huawei-hackathon-2021 Contributors DrakeAxelrod Challenge Requirements: python=3.8.10 Standard libraries (no importing) Important factors: Data depend

Drake Axelrod 32 Nov 08, 2022
Cosine Annealing With Warmup

CosineAnnealingWithWarmup Formulation The learning rate is annealed using a cosine schedule over the course of learning of n_total total steps with an

zhuyun 4 Apr 18, 2022
End-To-End Memory Network using Tensorflow

MemN2N Implementation of End-To-End Memory Networks with sklearn-like interface using Tensorflow. Tasks are from the bAbl dataset. Get Started git clo

Dominique Luna 339 Oct 27, 2022
PAthological QUpath Obsession - QuPath and Python conversations

PAQUO: PAthological QUpath Obsession Welcome to paquo 👋 , a library for interacting with QuPath from Python. paquo's goal is to provide a pythonic in

Bayer AG 60 Dec 31, 2022
Official PyTorch implementation of "Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets" (ICLR 2021)

Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets This is the official PyTorch implementation for the paper Rapid Neural A

48 Dec 26, 2022
Advancing mathematics by guiding human intuition with AI

Advancing mathematics by guiding human intuition with AI This repo contains two colab notebooks which accompany the paper, available online at https:/

DeepMind 315 Dec 26, 2022
(ImageNet pretrained models) The official pytorch implemention of the TPAMI paper "Res2Net: A New Multi-scale Backbone Architecture"

Res2Net The official pytorch implemention of the paper "Res2Net: A New Multi-scale Backbone Architecture" Our paper is accepted by IEEE Transactions o

Res2Net Applications 928 Dec 29, 2022
Official PyTorch Implementation of SSMix (Findings of ACL 2021)

SSMix: Saliency-based Span Mixup for Text Classification (Findings of ACL 2021) Official PyTorch Implementation of SSMix | Paper Abstract Data augment

Clova AI Research 52 Dec 27, 2022
Joint Gaussian Graphical Model Estimation: A Survey

Joint Gaussian Graphical Model Estimation: A Survey Test Models Fused graphical lasso [1] Group graphical lasso [1] Graphical lasso [1] Doubly joint s

Koyejo Lab 1 Aug 10, 2022
NeurIPS'21 Tractable Density Estimation on Learned Manifolds with Conformal Embedding Flows

NeurIPS'21 Tractable Density Estimation on Learned Manifolds with Conformal Embedding Flows This repo contains the code for the paper Tractable Densit

Layer6 Labs 4 Dec 12, 2022
[TNNLS 2021] The official code for the paper "Learning Deep Context-Sensitive Decomposition for Low-Light Image Enhancement"

CSDNet-CSDGAN this is the code for the paper "Learning Deep Context-Sensitive Decomposition for Low-Light Image Enhancement" Environment Preparing pyt

Jiaao Zhang 17 Nov 05, 2022
Demonstration of the Model Training as a CI/CD System in Vertex AI

Model Training as a CI/CD System This project demonstrates the machine model training as a CI/CD system in GCP platform. You will see more detailed wo

Chansung Park 19 Dec 28, 2022
Person Re-identification

Person Re-identification Final project of Computer Vision Table of content Person Re-identification Table of content Students: Proposed method Dataset

Nguyễn Hoàng Quân 4 Jun 17, 2021
TransGAN: Two Transformers Can Make One Strong GAN

[Preprint] "TransGAN: Two Transformers Can Make One Strong GAN", Yifan Jiang, Shiyu Chang, Zhangyang Wang

VITA 1.5k Jan 07, 2023