Second Order Optimization and Curvature Estimation with K-FAC in JAX.

Overview

KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX


KFAC-JAX is a library built on top of JAX for second-order optimization of neural networks and for computing scalable curvature approximations. The main goal of the library is to provide researchers with an easy-to-use implementation of the K-FAC optimizer and curvature estimator.

Installation

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

First, follow the official JAX installation instructions to install JAX with the relevant accelerator support.

Then, install KFAC-JAX using pip:

$ pip install git+https://github.com/deepmind/kfac-jax

Alternatively, you can install via PyPI:

$ pip install -U kfac-jax

Our examples rely on additional libraries, all of which you can install using:

$ pip install -r requirements_examples.txt

Quickstart

Let's take a look at a simple example of training a neural network, defined using Haiku, with the K-FAC optimizer:

import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax

# Hyperparameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 100


def make_dataset_iterator(batch_size):
  # Dummy dataset, in practice this should be your dataset pipeline
  for _ in range(NUM_BATCHES):
    yield jnp.zeros([batch_size, 100]), jnp.ones([batch_size], dtype="int32") 


def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
  """Softmax cross entropy loss."""
  # We assume integer labels
  assert logits.ndim == targets.ndim + 1
  
  # Tell KFAC-JAX this model represents a classifier
  # See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
  kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
  log_p = jax.nn.log_softmax(logits, axis=-1)
  return - jax.vmap(lambda x, y: x[y])(log_p, targets)


def model_fn(x):
  """A Haiku MLP model function - three hidden layer network with tanh."""
  return hk.nets.MLP(
    output_sizes=(50, 50, 50, NUM_CLASSES),
    with_bias=True,
    activation=jax.nn.tanh,
  )(x)


# The Haiku transformed model
hk_model = hk.without_apply_rng(hk.transform(model_fn))


def loss_fn(model_params, model_batch):
  """The loss function to optimize."""
  x, y = model_batch
  logits = hk_model.apply(model_params, x)
  loss = jnp.mean(softmax_cross_entropy(logits, y))
  
  # The optimizer assumes that the function you provide has already added
  # the L2 regularizer to its gradients.
  return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0


# Create the optimizer
optimizer = kfac_jax.Optimizer(
  value_and_grad_func=jax.value_and_grad(loss_fn),
  l2_reg=L2_REG,
  value_func_has_aux=False,
  value_func_has_state=False,
  value_func_has_rng=False,
  use_adaptive_learning_rate=True,
  use_adaptive_momentum=True,
  use_adaptive_damping=True,
  initial_damping=1.0,
  multi_device=False,
)

input_dataset = make_dataset_iterator(128)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
rng, key = jax.random.split(rng)
params = hk_model.init(key, dummy_images)
rng, key = jax.random.split(rng)
opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

# Training loop
for i, batch in enumerate(input_dataset):
  rng, key = jax.random.split(rng)
  params, opt_state, stats = optimizer.step(
      params, opt_state, key, batch=batch, global_step_int=i)
  print(i, stats)
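The comment inside loss_fn says the optimizer expects the L2 penalty to already be part of value_and_grad_func. A minimal pure-JAX check of why this matches l2_reg: the gradient of L2_REG * <p, p> / 2 is exactly L2_REG * p, leaf by leaf. The helper inner_product below is a hypothetical stand-in for kfac_jax.utils.inner_product, assumed here to be the sum of elementwise products over all leaves; it is not the library's code.

```python
import jax
import jax.numpy as jnp

L2_REG = 1e-3


def inner_product(a, b):
  # Hypothetical stand-in for kfac_jax.utils.inner_product: the sum of
  # elementwise products over all leaves of two pytrees with equal structure.
  per_leaf = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
  return sum(jax.tree_util.tree_leaves(per_leaf))


def l2_penalty(params):
  # Same form as the term added at the end of loss_fn in the quickstart.
  return L2_REG * inner_product(params, params) / 2.0


params = {"b": jnp.ones(3), "w": jnp.arange(6.0).reshape(2, 3)}
grads = jax.grad(l2_penalty)(params)

# Leaf by leaf, the gradient of the penalty equals L2_REG * params.
for g, p in zip(jax.tree_util.tree_leaves(grads), jax.tree_util.tree_leaves(params)):
  print(bool(jnp.allclose(g, L2_REG * p)))
```

This identity is why the quickstart uses the same L2_REG both inside loss_fn and as the optimizer's l2_reg argument.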

Do not stage (jit or pmap) the optimizer

You should not apply jax.jit or jax.pmap to the call to Optimizer.step. This is already done for you automatically by the optimizer class. To control how the optimizer is staged, set multi_device to True to use pmap or to False to use jit.

Do not stage (jit or pmap) the loss function

The value_and_grad_func argument provided to the optimizer should compute the loss function value and its gradients. Since the optimizer already stages its step function internally, applying jax.jit to value_and_grad_func is NOT recommended. Importantly, applying jax.pmap is WRONG and will most likely lead to errors.
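A minimal pure-JAX sketch of this division of labour, assuming only what the text states (the optimizer stages its own step). Here outer_step is a hypothetical stand-in for the optimizer's internally jitted step, not KFAC-JAX code: the user-supplied value_and_grad_func stays un-jitted and is simply traced as part of the outer jit.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
  # Toy least-squares loss, averaged over the batch.
  x, y = batch
  return jnp.mean((x @ params - y) ** 2)


# What you would hand to the optimizer: plain, un-staged value_and_grad.
value_and_grad_func = jax.value_and_grad(loss_fn)


@jax.jit
def outer_step(params, batch, lr):
  # Hypothetical stand-in for the optimizer's own staged step. It calls
  # value_and_grad_func inside jit, so the user never needs to jit it.
  loss, grads = value_and_grad_func(params, batch)
  return params - lr * grads, loss


x = jnp.ones((4, 3))
y = jnp.zeros(4)
params = jnp.ones(3)
new_params, loss = outer_step(params, (x, y), 0.1)
print(loss, new_params)
```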

Registering the model loss function

In order for KFAC-JAX to correctly approximate the curvature matrix of the model, it needs to know the precise loss function that you want to optimize. This is done by registering the loss with functions provided by the library. For instance, in the example above this is done via the call to kfac_jax.register_softmax_cross_entropy_loss, which tells the optimizer that the loss is the standard softmax cross-entropy. If you don't do this, you will get an error when you try to call the optimizer. For the full list of supported loss functions, please read the documentation.

Important: The optimizer assumes that the loss is averaged over examples in the minibatch. It is crucial that you follow this convention.
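A small pure-JAX illustration of why the averaging convention matters: with a mean over the minibatch, the loss (and therefore the scale of its gradient and curvature) does not depend on batch size, whereas a sum grows with it. The per-example loss mirrors the quickstart's softmax_cross_entropy, with the KFAC-JAX registration call left out so the snippet runs without the library.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, targets):
  # Per-example loss, as in the quickstart (registration omitted here).
  log_p = jax.nn.log_softmax(logits, axis=-1)
  return -jax.vmap(lambda x, y: x[y])(log_p, targets)


logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
targets = jnp.array([0, 2])

# Double the batch by repeating it: the mean stays the same, the sum doubles.
big_logits = jnp.concatenate([logits, logits])
big_targets = jnp.concatenate([targets, targets])

mean_small = jnp.mean(softmax_cross_entropy(logits, targets))
mean_big = jnp.mean(softmax_cross_entropy(big_logits, big_targets))
sum_small = jnp.sum(softmax_cross_entropy(logits, targets))
sum_big = jnp.sum(softmax_cross_entropy(big_logits, big_targets))
print(mean_small, mean_big, sum_small, sum_big)
```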

Other model function options

Oftentimes, one will want to output some auxiliary statistics or metrics in addition to the loss value. This is supported by value_and_grad_func, in which case we follow the same conventions as JAX and expect the output to be (loss, aux), grads. Similarly, the loss function can take an additional function state (batch norm layers usually have this) or a PRNG key (used in stochastic layers). All of these, however, must be explicitly declared to the optimizer via its arguments value_func_has_aux, value_func_has_state and value_func_has_rng.
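A minimal pure-JAX sketch of the (loss, aux), grads output structure, which is what jax.value_and_grad returns with has_aux=True; a function like this would be paired with value_func_has_aux=True. The linear model and the "accuracy" metric are illustrative assumptions, and the loss registration described above is omitted so the snippet runs on its own (with KFAC-JAX you would still register the loss).

```python
import jax
import jax.numpy as jnp


def loss_with_aux(params, batch):
  x, y = batch
  logits = x @ params
  log_p = jax.nn.log_softmax(logits, axis=-1)
  per_example = -jnp.take_along_axis(log_p, y[:, None], axis=-1)[:, 0]
  loss = jnp.mean(per_example)
  # Auxiliary metrics are returned next to the loss but not differentiated.
  aux = {"accuracy": jnp.mean(jnp.argmax(logits, axis=-1) == y)}
  return loss, aux


value_and_grad_func = jax.value_and_grad(loss_with_aux, has_aux=True)

params = jnp.zeros((3, 2))
batch = (jnp.ones((4, 3)), jnp.array([0, 1, 0, 1]))
(loss, aux), grads = value_and_grad_func(params, batch)
print(loss, aux["accuracy"], grads.shape)
```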

Verify optimizer registrations

We strongly encourage you to pay attention to the logging messages produced by the automatic registration system, to ensure that it has correctly understood your model. For the example above, the message looks like this:

==================================================
Graph parameter registrations:
{'mlp/~/linear_0': {'b': 'Auto[dense_with_bias_3]',
                    'w': 'Auto[dense_with_bias_3]'},
 'mlp/~/linear_1': {'b': 'Auto[dense_with_bias_2]',
                    'w': 'Auto[dense_with_bias_2]'},
 'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]',
                    'w': 'Auto[dense_with_bias_1]'},
 'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',
                    'w': 'Auto[dense_with_bias_0]'}}
==================================================

As can be seen from this message, the library has correctly detected all parameters of the model to be part of dense layers.
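For larger models, checking this log by eye gets tedious. A hypothetical helper sketch, assuming you have the registration mapping as a Python dict (for example, copied from the log above) and a module-to-parameter-name listing of your model: it reports any parameter missing from the registrations and groups parameters that share a tag, i.e. belong to the same registered layer. The helper names are illustrative, not part of KFAC-JAX.

```python
# Registration mapping for the quickstart model, copied from the log above.
registrations = {
    "mlp/~/linear_0": {"b": "Auto[dense_with_bias_3]", "w": "Auto[dense_with_bias_3]"},
    "mlp/~/linear_1": {"b": "Auto[dense_with_bias_2]", "w": "Auto[dense_with_bias_2]"},
    "mlp/~/linear_2": {"b": "Auto[dense_with_bias_1]", "w": "Auto[dense_with_bias_1]"},
    "mlp/~/linear_3": {"b": "Auto[dense_with_bias_0]", "w": "Auto[dense_with_bias_0]"},
}

# Parameter names as module -> names. With a Haiku params dict this could be
# built as {m: tuple(v) for m, v in params.items()}; written out here so the
# snippet runs without Haiku.
param_names = {f"mlp/~/linear_{i}": ("b", "w") for i in range(4)}


def unregistered(param_names, registrations):
  """Returns 'module/name' for every parameter missing from the registrations."""
  return [
      f"{module}/{name}"
      for module, names in param_names.items()
      for name in names
      if name not in registrations.get(module, {})
  ]


def group_by_tag(registrations):
  """Groups parameters that share a registration tag (the same registered layer)."""
  groups = {}
  for module, entries in registrations.items():
    for name, tag in entries.items():
      groups.setdefault(tag, []).append(f"{module}/{name}")
  return groups


print(unregistered(param_names, registrations))  # expect []
for tag, names in sorted(group_by_tag(registrations).items()):
  print(tag, names)
```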

Further reading

For a high level overview of the optimizer, the different curvature approximations, and the supported layers, please see the documentation.
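For intuition about the Kronecker-factored approximation the library is named after, here is a minimal pure-JAX sketch for a single dense layer y = x @ w (bias omitted). The curvature block for w is approximated by the Kronecker product of an input factor A = E[x xᵀ] and an output-gradient factor G = E[dy dyᵀ], built from the layer inputs and output tangents (compare the `inputs` and `outputs_tangent` entries read by DenseTwoKroneckerFactored in the first issue below). The damped update is then computed with two small solves instead of one large one. This is a textbook illustration under simplifying assumptions (random placeholder data, sqrt(damping) added to each factor), not KFAC-JAX's implementation, which also handles moving averages, bias terms and other damping schemes such as the pi-adjusted inverse mentioned in the release notes.

```python
import jax
import jax.numpy as jnp


def kfac_factors(x, dy):
  """Kronecker factors for a dense layer y = x @ w.

  x:  [batch, d_in] layer inputs.
  dy: [batch, d_out] gradients of the loss w.r.t. the layer outputs.
  """
  n = x.shape[0]
  a = x.T @ x / n    # input factor A, [d_in, d_in]
  g = dy.T @ dy / n  # output-gradient factor G, [d_out, d_out]
  return a, g


def kfac_precondition(grad_w, a, g, damping):
  """Applies ((A + sI) kron (G + sI))^-1 to grad_w, with s = sqrt(damping)."""
  s = jnp.sqrt(damping)
  a_d = a + s * jnp.eye(a.shape[0])
  g_d = g + s * jnp.eye(g.shape[0])
  # For symmetric factors and row-major flattening,
  # (A kron G)^-1 vec(dW) == vec(A^-1 dW G^-1): two small solves, no big matrix.
  left = jnp.linalg.solve(a_d, grad_w)    # A^-1 dW
  return jnp.linalg.solve(g_d, left.T).T  # (A^-1 dW) G^-1


key_x, key_dy = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (32, 4))
dy = jax.random.normal(key_dy, (32, 3))
grad_w = x.T @ dy / x.shape[0]  # batch-averaged gradient of w, [4, 3]

a, g = kfac_factors(x, dy)
update = kfac_precondition(grad_w, a, g, damping=1e-2)

# Cross-check against the explicit 12 x 12 Kronecker product (tiny layers only).
s = jnp.sqrt(1e-2)
big = jnp.kron(a + s * jnp.eye(4), g + s * jnp.eye(3))
dense_update = jnp.linalg.solve(big, grad_w.reshape(-1)).reshape(4, 3)
print(jnp.allclose(update, dense_update, rtol=1e-3, atol=1e-4))
```

The factored form is what makes K-FAC scale: storing and inverting a [d_in, d_in] and a [d_out, d_out] matrix is far cheaper than handling the full [d_in * d_out, d_in * d_out] block.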

Citing KFAC-JAX

To cite this repository:

@software{kfac-jax2022github,
  author = {Aleksandar Botev and James Martens},
  title = {{KFAC-JAX}},
  url = {http://github.com/deepmind/kfac-jax},
  version = {0.0.1},
  year = {2022},
}

In this bibtex entry, the version number is intended to be from kfac_jax/__init__.py, and the year corresponds to the project's open-source release.

Comments
  • Unpack Error when using KFAC with block-diagonal for Dense networks

    Hi,

    I was trying to get the example code in the readme working with the BlockDiagonal approximation. The default simply uses the normal diagonal. However, when I try to define my optimizer like this:

    opt = kfac_jax.Optimizer(
        value_and_grad_func=jax.value_and_grad(partial(expected_model_likelihood, l2=0.001)),
        l2_reg=0.001,
        use_adaptive_learning_rate=True,
        use_adaptive_damping=True,
        use_adaptive_momentum=True,
        initial_damping=1.0,
        min_damping= 0.0001,
        layer_tag_to_block_ctor={'generic_tag': kfac_jax.DenseTwoKroneckerFactored},  # Specify the approximation type here
        estimation_mode='ggn_curvature_prop',
        multi_device=False
    )
    

    then when I try to use this optimizer I get the following ValueError:

    del pmap_axis_name
    x, = estimation_data["inputs"]
    dy, = estimation_data["outputs_tangent"]
    assert utils.first_dim_is_size(batch_size, x, dy)
    
    ValueError: not enough values to unpack (expected 1, got 0)
    

    Corresponding to the curvature update method in class DenseTwoKroneckerFactored (line 1165) of _src.curvature_blocks.py. The estimation data dictionary is filled with the parameters and parameters-tangents, but I do not understand the codebase sufficiently to grasp why the inputs and outputs_tangent keys are not filled.

    In this way I cannot get the actual KFAC of this repo working... Are there perhaps some examples that make use of the DenseTwoKroneckerFactored? As far as I can tell all provided examples simply make use of the diagonal Fisher for optimization, not KFAC. But I may be wrong of course.

    opened by joeryjoery 4
  • TypeError: 'ShapedArray' object is not iterable

    Hi,

    I tried to run the example code, but the code stops at primal_output = self.bind(*arg_values, **kwargs), and returns the error "TypeError: 'ShapedArray' object is not iterable". Could you please help me to solve this problem? Thanks.

    opened by ltz0120 4
  • How to use kfac to train two probabilistic models jointly?

    In my application, I need to jointly optimize two probabilistic models. They contribute to two different terms in the final loss function.

    I am wondering what the recommended pattern for using kfac would be?
    More specifically, does it make sense to invoke kfac_jax.register_normal_predictive_distribution twice (for the two probabilistic models respectively)?

    Thanks in advance!

    opened by wangleiphy 3
  • Correct return type annotation for BlockDiagonalCurvature.params_vector_to_blocks_vectors.

    JAX recently added type annotations for jax.tree_util, and tree_leaves returns a list rather than a tuple.

    opened by copybara-service[bot] 1
  • Correct buffer donation of Optimizer._step.

    Buffers can only be donated if they match the shape and type of the output, which is not true for the rng state or the batch item.

    opened by copybara-service[bot] 1
  • Modularizing the utilities file into a separate sub-package.

    • Modularizing the utilities file into a separate sub-package.
    • Bumping the version of the ci-actions to remove some deprecation warnings.
    • Bumping chex version.
    opened by copybara-service[bot] 0
  • Improving docstring for optimizer. In particular regarding the damping parameter and LR/momentum/damping adaptation methods.

    • Improving docstring for optimizer. In particular regarding the damping parameter and LR/momentum/damping adaptation methods.
    • Fixing bug in default value of normalization_mode in examples classifier loss.
    opened by copybara-service[bot] 0
  • Adding normalization modes feature to classifier loss.

    • Adding normalization modes feature to classifier loss.
    • Removing unused/pointless return values for registration functions.
    • Improvements to clarity and correctness of docstrings for registration functions.
    • Simplifying batch_size_extractor.
    • Adding white space for improved readability.
    • Fixing _update_cache to account for state_dependent_scale (which is currently unused in the open source release).
    opened by copybara-service[bot] 0
  • Making the estimator finalize itself automatically.

    • Making the estimator finalize itself automatically.
    • Making the optimizer call finalize at the end of init.
    • Removing the need for fake_batch in the optimizer.
    opened by copybara-service[bot] 0
  • Using jnp.int64 for data_seen and step counters to avoid overflow

    • Using jnp.int64 for data_seen and step counters to avoid overflow
    • Using float for epochs instead of int
    • Adding extra arguments to cosine schedule in examples
    opened by copybara-service[bot] 0
  • Correct buffer donation.

    Buffer donation is only valid if the shape and type of an input buffer matches an output. Buffer donation only works with positional arguments, not keyword arguments.

    opened by copybara-service[bot] 1
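The two buffer-donation fixes above rest on the same JAX rule: a donated input buffer can only be reused for an output with matching shape and dtype. A minimal sketch with jax.jit's donate_argnums, where sgd_update is a hypothetical toy step rather than Optimizer._step: params is donated (positionally) because the new params have the same shape and dtype, while the batch is only read and has no matching output, so it is not donated. On backends that cannot reuse a donated buffer, JAX may emit a warning instead of reusing it.

```python
import jax
import jax.numpy as jnp


def sgd_update(params, grads, batch):
  # Toy step: the output has the same shape/dtype as params, so params'
  # buffer can be reused. batch has no matching output, so it is not donated.
  del batch
  return params - 0.1 * grads


# Donate argument 0 (params), passed positionally.
step = jax.jit(sgd_update, donate_argnums=(0,))

params = jnp.ones((3,))
grads = jnp.full((3,), 2.0)
batch = jnp.zeros((8, 3))
new_params = step(params, grads, batch)
# Do not use `params` after this call: its buffer may now back new_params.
print(new_params)
```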
Releases (v0.0.3)
  • v0.0.3 (Sep 23, 2022)

    What's Changed

    • Changing the version in the citation text in the README. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/29
    • Adding attributes for the number of training and evaluation devices. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/31
    • Adding some methods to ImplicitExactCurvature by @copybara-service in https://github.com/deepmind/kfac-jax/pull/32
    • Adding "put_stop_grad_on_loss_factor" argument to 'multiply_fisher_factor'. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/36
    • Making ScaleAndShift blocks being capable of having parameters that are broadcast by construction, e.g. batch norm with scale parameters [1, 1, 1, d]. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/33
    • Changing jax.tree_map -> jax.tree_util.tree_map and related due to recent deprecation. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/37
    • Removed unused precedence argument from GraphPattern. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/38
    • Fix a small bug where we don't check in the jaxpr constvars. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/39
    • Adding an estimator attribute to the optimizer. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/34
    • Updating the docs to correctly refer to update_cache. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/40
    • Compare with slightly less numerical precision. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/41
    • Revamping the graph matching code to be able to detect layers and register tag in arbitrary higher-order Jax primitives. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/42
    • Revising docstring for optimizer class. Now contains missing details about value_and_grad_func. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/43
    • Internal change. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/44
    • Make LossTag to return only the parameter dependent arrays. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/46
    • Improving LossTags to be able to deal correctly with None arguments, by passing in argument names. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/47
    • Minor fix to a bug introduced on previous commit. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/48
    • Correcting issues with docstring for optimizer. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/45
    • Fixing a bug in the graph matcher introduced in a recent CL. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/49
    • Removing unneeded jax.jit in get_mean and get_sum. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/50
    • Adding per-parameter norm stats to optimizer by @copybara-service in https://github.com/deepmind/kfac-jax/pull/51
    • Allowing the pi-adjusted psd inverse to accept diagonal factors. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/55
    • Fixing wrong type annotation of pmap_axis_name. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/56
    • Adding optional offloading of eigh computation to the host because of a bug in CUDA 11.7.0 cuSOLVER library. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/57

    Full Changelog: https://github.com/deepmind/kfac-jax/compare/v0.0.2...v0.0.3

  • v0.0.2 (Jun 7, 2022)

    What's Changed

    • Moving .github to top-level directory for CI. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/1
    • Updated documentation for state classes. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/2
    • Changing the name on PyPi to kfac-jax. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/3
    • Making the tracer test in float64. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/4
    • Allowing graph patterns with multiple broadcast to be merged without dangling equations. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/5
    • Adding README for the examples. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/7
    • Changing deprecated tree_multimap to tree_map. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/8
    • Fixing small error introduced due to updates to chex. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/11
    • Fixing typo "drop_reminder" by @copybara-service in https://github.com/deepmind/kfac-jax/pull/13
    • Adding an argument to set the reduction ratio thresholds for automatic damping adjustment. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/12
    • Adding "modifiable_attribute_exceptions" argument to optimizer by @copybara-service in https://github.com/deepmind/kfac-jax/pull/14
    • Changing Imagenet dataset in examples to use a seed for file shuffling to achieve determinism. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/17
    • Small fix to a doc reference bug. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/16
    • Making WeightedMovingAverage to work with arbitrary structures. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/19
    • Minor typos. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/20
    • Correct buffer donation of Optimizer._step. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/21
    • Replacing yield from with direct iteration. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/24
    • Adding stepwise schedule option to examples. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/18
    • Publishing a new version to PyPi. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/28

    New Contributors

    • @copybara-service made their first contribution in https://github.com/deepmind/kfac-jax/pull/1

    Full Changelog: https://github.com/deepmind/kfac-jax/commits/v0.0.2

Owner
DeepMind