Differentiable Factor Graph Optimization for Learning Smoothers @ IROS 2021

Related tags

Deep Learningdfgo
Overview

Differentiable Factor Graph Optimization for Learning Smoothers

mypy

Figure describing the overall training pipeline proposed by our IROS paper. Contains five sections, arranged left to right: (1) system models, (2) factor graphs for state estimation, (3) MAP inference, (4) state estimates, and (5) errors with respect to ground-truth. Arrows show how gradients are backpropagated from right to left, starting directly from the final stage (error with respect to ground-truth) back to parameters of the system models.

Overview

Code release for our IROS 2021 conference paper:

Brent Yi1, Michelle A. Lee1, Alina Kloss2, Roberto Martín-Martín1, and Jeannette Bohg1. Differentiable Factor Graph Optimization for Learning Smoothers. Proceedings of the IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), October 2021.

1Stanford University, {brentyi,michellelee,robertom,bohg}@cs.stanford.edu
2Max Planck Institute for Intelligent Systems, [email protected]


This repository contains models, training scripts, and experimental results, and can be used to either reproduce our results or as a reference for implementation details.

Significant chunks of the code written for this paper have been factored out of this repository and released as standalone libraries, which may be useful for building on our work. You can find each of them linked here:

  • jaxfg is our core factor graph optimization library.
  • jaxlie is our Lie theory library for working with rigid body transformations.
  • jax_dataclasses is our library for building JAX pytrees as dataclasses. It's similar to flax.struct, but has workflow improvements for static analysis and nested structures.
  • jax-ekf contains our EKF implementation.

Status

Included in this repo for the disk task:

  • Smoother training & results
    • Training: python train_disk_fg.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/fg/**/
  • Filter baseline training & results
    • Training: python train_disk_ekf.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/ekf/**/
  • LSTM baseline training & results
    • Training: python train_disk_lstm.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/lstm/**/

And, for the visual odometry task:

  • Smoother training & results (including ablations)
    • Training: python train_kitti_fg.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/fg/**/
  • EKF baseline training & results
    • Training: python train_kitti_ekf.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/ekf/**/
  • LSTM baseline training & results
    • Training: python train_kitti_lstm.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/lstm/**/

Note that **/ indicates a recursive glob in zsh. This can be emulated in bash>4 via the globstar option (shopt -q globstar).

We've done our best to make our research code easy to parse, but it's still being iterated on! If you have questions, suggestions, or any general comments, please reach out or file an issue.

Setup

We use Python 3.8 and miniconda for development.

  1. Any calls to CHOLMOD (via scikit-sparse, sometimes used for eval but never for training itself) will require SuiteSparse:

    # Mac
    brew install suite-sparse
    
    # Debian
    sudo apt-get install -y libsuitesparse-dev
  2. Dependencies can be installed via pip:

    pip install -r requirements.txt

    In addition to JAX and the first-party dependencies listed above, note that this also includes various other helpers:

    • datargs (currently forked) is super useful for building type-safe argument parsers.
    • torch's Dataset and DataLoader interfaces are used for training.
    • fannypack contains some utilities for downloading datasets, working with PDB, polling repository commit hashes.

The requirements.txt provided will install the CPU version of JAX by default. For CUDA support, please see instructions from the JAX team.

Datasets

Datasets synced from Google Drive and loaded via h5py automatically as needed. If you're interested in downloading them manually, see lib/kitti/data_loading.py and lib/disk/data_loading.py.

Training

The naming convention for training scripts is as follows: train_{task}_{model type}.py.

All of the training scripts provide a command-line interface for configuring experiment details and hyperparameters. The --help flag will summarize these settings and their default values. For example, to run the training script for factor graphs on the disk task, try:

> python train_disk_fg.py --help

Factor graph training script for disk task.

optional arguments:
  -h, --help            show this help message and exit
  --experiment-identifier EXPERIMENT_IDENTIFIER
                        (default: disk/fg/default_experiment/fold_{dataset_fold})
  --random-seed RANDOM_SEED
                        (default: 94305)
  --dataset-fold {0,1,2,3,4,5,6,7,8,9}
                        (default: 0)
  --batch-size BATCH_SIZE
                        (default: 32)
  --train-sequence-length TRAIN_SEQUENCE_LENGTH
                        (default: 20)
  --num-epochs NUM_EPOCHS
                        (default: 30)
  --learning-rate LEARNING_RATE
                        (default: 0.0001)
  --warmup-steps WARMUP_STEPS
                        (default: 50)
  --max-gradient-norm MAX_GRADIENT_NORM
                        (default: 10.0)
  --noise-model {CONSTANT,HETEROSCEDASTIC}
                        (default: CONSTANT)
  --loss {JOINT_NLL,SURROGATE_LOSS}
                        (default: SURROGATE_LOSS)
  --pretrained-virtual-sensor-identifier PRETRAINED_VIRTUAL_SENSOR_IDENTIFIER
                        (default: disk/pretrain_virtual_sensor/fold_{dataset_fold})

When run, train scripts serialize experiment configurations to an experiment_config.yaml file. You can find hyperparameters in the experiments/ directory for all results presented in our paper.

Evaluation

All evaluation metrics are recorded at train time. The cross_validate.py script can be used to compute metrics across folds:

# Summarize all experiments with means and standard errors of recorded metrics.
python cross_validate.py

# Include statistics for every fold -- this is much more data!
python cross_validate.py --disaggregate

# We can also glob for a partial set of experiments; for example, all of the
# disk experiments.
# Note that the ** wildcard may fail in bash; see above for a fix.
python cross_validate.py --experiment-paths ./experiments/disk/**/

Acknowledgements

We'd like to thank Rika Antonova, Kevin Zakka, Nick Heppert, Angelina Wang, and Philipp Wu for discussions and feedback on both our paper and codebase. Our software design also benefits from ideas from several open-source projects, including Sophus, GTSAM, Ceres Solver, minisam, and SwiftFusion.

This work is partially supported by the Toyota Research Institute (TRI) and Google. This article solely reflects the opinions and conclusions of its authors and not TRI, Google, or any entity associated with TRI or Google.

Owner
Brent Yi
Brent Yi
Scale-aware Automatic Augmentation for Object Detection (CVPR 2021)

SA-AutoAug Scale-aware Automatic Augmentation for Object Detection Yukang Chen, Yanwei Li, Tao Kong, Lu Qi, Ruihang Chu, Lei Li, Jiaya Jia [Paper] [Bi

DV Lab 182 Dec 29, 2022
CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP

CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP Andreas Fürst* 1, Elisabeth Rumetshofer* 1, Viet Tran1, Hubert Ramsauer1, Fei Tang3, Joh

Institute for Machine Learning, Johannes Kepler University Linz 133 Jan 04, 2023
links and status of cool gradio demos

awesome-demos This is a list of some wonderful demos & applications built with Gradio. Here's how to contribute yours! 🖊️ Natural language processing

Gradio 96 Dec 30, 2022
Versatile Generative Language Model

Versatile Generative Language Model This is the implementation of the paper: Exploring Versatile Generative Language Model Via Parameter-Efficient Tra

Zhaojiang Lin 17 Dec 02, 2022
This is a Pytorch implementation of paper: DropEdge: Towards Deep Graph Convolutional Networks on Node Classification

DropEdge: Towards Deep Graph Convolutional Networks on Node Classification This is a Pytorch implementation of paper: DropEdge: Towards Deep Graph Con

401 Dec 16, 2022
SplineConv implementation for Paddle.

SplineConv implementation for Paddle This module implements the SplineConv operators from Matthias Fey, Jan Eric Lenssen, Frank Weichert, Heinrich Mül

北海若 3 Dec 29, 2021
SeisComP/SeisBench interface to enable deep-learning (re)picking in SeisComP

scdlpicker SeisComP/SeisBench interface to enable deep-learning (re)picking in SeisComP Objective This is a simple deep learning (DL) repicker module

Joachim Saul 6 May 13, 2022
Created as part of CS50 AI's coursework. This AI makes use of knowledge entailment to calculate the best probabilities to win Minesweeper.

Minesweeper-AI Created as part of CS50 AI's coursework. This AI makes use of knowledge entailment to calculate the best probabilities to win Minesweep

Beckham 0 Jul 20, 2022
Subdivision-based Mesh Convolutional Networks

Subdivision-based Mesh Convolutional Networks The official implementation of SubdivNet in our paper, Subdivion-based Mesh Convolutional Networks Requi

Zheng-Ning Liu 181 Dec 28, 2022
Simple tutorials using Google's TensorFlow Framework

TensorFlow-Tutorials Introduction to deep learning based on Google's TensorFlow framework. These tutorials are direct ports of Newmu's Theano Tutorial

Nathan Lintz 6k Jan 06, 2023
WarpDrive: Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning on a GPU

WarpDrive is a flexible, lightweight, and easy-to-use open-source reinforcement learning (RL) framework that implements end-to-end multi-agent RL on a single GPU (Graphics Processing Unit).

Salesforce 334 Jan 06, 2023
CenterPoint 3D Object Detection and Tracking using center points in the bird-eye view.

CenterPoint 3D Object Detection and Tracking using center points in the bird-eye view. Center-based 3D Object Detection and Tracking, Tianwei Yin, Xin

Tianwei Yin 134 Dec 23, 2022
This is a package for LiDARTag, described in paper: LiDARTag: A Real-Time Fiducial Tag System for Point Clouds

LiDARTag Overview This is a package for LiDARTag, described in paper: LiDARTag: A Real-Time Fiducial Tag System for Point Clouds (PDF)(arXiv). This wo

University of Michigan Dynamic Legged Locomotion Robotics Lab 159 Dec 21, 2022
noisy labels; missing labels; semi-supervised learning; entropy; uncertainty; robustness and generalisation.

ProSelfLC: CVPR 2021 ProSelfLC: Progressive Self Label Correction for Training Robust Deep Neural Networks For any specific discussion or potential fu

amos_xwang 57 Dec 04, 2022
A command line simple note taking app

Why yet another note taking program? note was designed with a very specific target in mind: me, and my 2354 scraps of paper. It runs from the command

64 Nov 20, 2022
This project uses reinforcement learning on stock market and agent tries to learn trading. The goal is to check if the agent can learn to read tape. The project is dedicated to hero in life great Jesse Livermore.

Reinforcement-trading This project uses Reinforcement learning on stock market and agent tries to learn trading. The goal is to check if the agent can

Deepender Singla 1.4k Dec 22, 2022
Codebase for the paper titled "Continual learning with local module selection"

This repository contains the codebase for the paper Continual Learning via Local Module Composition. Setting up the environemnt Create a new conda env

Oleksiy Ostapenko 20 Dec 10, 2022
(under submission) Bayesian Integration of a Generative Prior for Image Restoration

BIGPrior: Towards Decoupling Learned Prior Hallucination and Data Fidelity in Image Restoration Authors: Majed El Helou, and Sabine Süsstrunk {Note: p

Majed El Helou 22 Dec 17, 2022
A PyTorch implementation of " EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks."

EfficientNet A PyTorch implementation of EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. [arxiv] [Official TF Repo] Implemen

AhnDW 298 Dec 10, 2022