Very deep VAEs in JAX/Flax

A JAX and Flax implementation of the experiments in the paper "Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images", ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to re-use a large proportion of the code, including the data input pipeline, which still uses PyTorch. I recommend installing a CPU-only version of PyTorch for this.

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1 and NumPy 1.19.2. I also trained to convergence on cifar10 and reproduced the paper's test ELBO of 2.87 bits per dimension, using --conv_precision=highest (see below). If anyone asks for trained cifar10 checkpoints, I will be happy to upload them.

[Figure from the paper: model samples and a visualization of how the model generates them.]

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install scikit-learn

You will also need to download the dataset(s) you want to train on:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh  /path/to/images1024x1024  # requires first downloading the `images_1024x1024` subfolder from https://github.com/NVlabs/ffhq-dataset yourself and running `pip install torchvision`

Training models

All hyperparameters are defined in hps.py. Select a preset with --hps:

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024

TODOs

  • Implement support for 5-bit images, which were used in the paper's FFHQ-256 experiments.
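For reference, a minimal sketch of what 5-bit images means, assuming the usual scheme of keeping only the top 5 bits of each 8-bit pixel value (32 intensity levels instead of 256) and rescaling the levels to [-1, 1]. The helper names `reduce_bits` and `level_to_unit_range` are hypothetical; this is not code from this repository or from the official PyTorch one.

```python
def reduce_bits(pixel: int, n_bits: int = 5) -> int:
    """Map an 8-bit pixel value (0-255) to one of 2**n_bits levels.

    Hypothetical helper: keeps the top n_bits bits and discards the rest.
    """
    if not 0 <= pixel <= 255:
        raise ValueError(f"expected an 8-bit value, got {pixel}")
    if not 1 <= n_bits <= 8:
        raise ValueError(f"n_bits must be in [1, 8], got {n_bits}")
    return pixel >> (8 - n_bits)


def level_to_unit_range(level: int, n_bits: int = 5) -> float:
    """Rescale a quantized level to [-1, 1] (hypothetical helper)."""
    n_levels = 2 ** n_bits
    return 2.0 * level / (n_levels - 1) - 1.0


if __name__ == "__main__":
    for p in (0, 7, 8, 128, 255):
        level = reduce_bits(p)
        print(p, level, level_to_unit_range(level))
```

Under this scaling, adjacent levels are 2/31 apart rather than 2/255, so the bin width assumed by the DMOL likelihood would have to change along with the preprocessing.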

Known differences from the original

  • Layers use the Flax default initializers instead of the PyTorch defaults.
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to the DMOL (discretized mixture of logistics) loss.
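The paper reports the ELBO in bits per dimension, and with the renamed terms the negative ELBO is kl minus loglikelihood. A minimal sketch of the conversion, assuming both terms are totals over one image measured in nats and that `loglikelihood` is log p(x|z) itself rather than its negation. Whether the code reports these terms per image or already per dimension is not stated here, so check before converting; the function name is hypothetical.

```python
import math


def neg_elbo_bits_per_dim(kl: float, loglikelihood: float, image_shape: tuple) -> float:
    """Negative ELBO of one image, in bits per dimension (hypothetical helper).

    Assumptions (not taken from the repo): `kl` and `loglikelihood` are
    totals over the whole image in nats, and `loglikelihood` is log p(x|z)
    itself (<= 0 for a discrete likelihood), not its negation.
    """
    n_dims = math.prod(image_shape)                # e.g. 32 * 32 * 3 = 3072 for cifar10
    neg_elbo_nats = kl - loglikelihood             # -ELBO = KL - log p(x|z)
    return neg_elbo_nats / (n_dims * math.log(2))  # nats -> bits, then per dimension
```

Read the other way, 2.87 bits per dimension on a 32×32×3 cifar10 image corresponds to about 2.87 × 3072 ≈ 8817 bits per image.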

Things to watch out for

We tried to keep this implementation as close as possible to the author's original PyTorch implementation, and in doing so preserved two potentially confusing things. First, the --n_batch command-line argument specifies the per-device batch size; on configurations with multiple GPUs/TPUs and multiple hosts, take this into account when comparing runs across configurations. Second, some of the default hyperparameter settings in hps.py do not match those used for the paper's experiments, which are listed on page 15 of the paper.
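Because --n_batch is per device, the number of examples per training step grows with the number of devices. A minimal sketch of the bookkeeping, assuming pure data parallelism with the same number of devices on every host; both function names are hypothetical.

```python
def global_batch_size(n_batch: int, local_device_count: int, host_count: int = 1) -> int:
    """Examples consumed per training step when --n_batch is a per-device size.

    Assumes pure data parallelism with the same number of devices on every host.
    """
    if min(n_batch, local_device_count, host_count) < 1:
        raise ValueError("all arguments must be positive integers")
    return n_batch * local_device_count * host_count


def matching_n_batch(target_global_batch: int, total_devices: int) -> int:
    """Per-device --n_batch that reproduces a given global batch on another setup."""
    if target_global_batch % total_devices:
        raise ValueError(
            f"global batch {target_global_batch} is not divisible by {total_devices} devices"
        )
    return target_global_batch // total_devices
```

For example, --n_batch=16 on a single host with 8 GPUs gives a global batch of 128; to match that on a setup with 32 devices in total you would pass --n_batch=4. In JAX, `jax.device_count()` already counts the devices across all hosts, so the global batch is `n_batch * jax.device_count()`.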

To reproduce the paper's results on TPU, it may be necessary to set --conv_precision=highest, which simulates GPU-like float32 precision on the TPU at the cost of a possibly slower runtime. In my experiments on cifar10, this setting changed the final ELBO by about 1% and was necessary to reproduce the value of 2.87 reported in the paper.
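Why precision matters: per the `jax.lax.Precision` documentation, JAX's default precision on TPU performs float32 convolutions and matrix multiplications in bfloat16, which keeps only 8 significant bits, while `HIGHEST` computes in float32. Presumably --conv_precision=highest passes `jax.lax.Precision.HIGHEST` to the convolutions, but that is an assumption. The pure-Python sketch below only emulates the rounding of the inputs, to show the size of the error; it does not model how the TPU accumulates.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value, ties to even.

    bfloat16 is the upper 16 bits of a float32: 1 sign bit, 8 exponent bits
    and 7 mantissa bits. NaN and overflow handling are omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lowest_kept_bit = (bits >> 16) & 1
    # Adding 0x7FFF (plus 1 if the kept part is odd) rounds to nearest, ties to even.
    bits = (bits + 0x7FFF + lowest_kept_bit) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot(xs, ys, round_input):
    """Dot product whose inputs are rounded first; accumulation stays in float64."""
    return sum(round_input(x) * round_input(y) for x, y in zip(xs, ys))


if __name__ == "__main__":
    xs = [0.1 * (i + 1) for i in range(256)]
    ys = [1.0 / (i + 1) for i in range(256)]
    exact = dot(xs, ys, lambda v: v)
    for name, rounder in (("float32", to_float32), ("bfloat16", to_bfloat16)):
        rel_err = abs(dot(xs, ys, rounder) - exact) / abs(exact)
        print(f"{name:9s} relative error: {rel_err:.2e}")
```

Per-element relative errors of roughly 2^-9 for bfloat16, against roughly 2^-24 for float32, compound over the many layers of a very deep network, which is consistent with the small but measurable ELBO difference described above.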

Acknowledgements

This code is very closely based on Rewon Child's implementation, thanks to him for writing that. Thanks to Julius Kunze for tidying the code and fixing some bugs.

Owner
Jamie Townsend