Learning Confidence for Out-of-Distribution Detection in Neural Networks

Overview

Learning Confidence Estimates for Neural Networks

This repository contains the code for the paper Learning Confidence for Out-of-Distribution Detection in Neural Networks. In this work, we demonstrate how to augment neural networks with a confidence estimation branch, which can be used to identify misclassified and out-of-distribution examples.

To learn confidence estimates during training, we provide the neural network with "hints" towards the correct output whenever it exhibits low confidence in its predictions. Hints are provided by pushing the prediction closer to the target distribution via interpolation, where the amount of interpolation proportional to the network's confidence that its prediction is correct. To discourage the network from always asking for free hints, a small penalty is applied whenever it is not confident. As a result, the network learns to only produce low confidence estimates when it is likely to make an incorrect prediction.

Bibtex:

@article{devries2018learning,
  title={Learning Confidence for Out-of-Distribution Detection in Neural Networks},
  author={DeVries, Terrance and Taylor, Graham W},
  journal={arXiv preprint arXiv:1802.04865},
  year={2018}
}

Results and Usage

We evalute our method on the task of out-of-distribution detection using three different neural network architectures: DenseNet, WideResNet, and VGG. CIFAR-10 and SVHN are used as the in-distribution datasets, while TinyImageNet, LSUN, iSUN, uniform noise, and Gaussian noise are used as the out-of-distribution datasets. Definitions of evaluation metrics can be found in the paper.

Dependencies

PyTorch v0.3.0
tqdm
visdom
seaborn
Pillow
scikit-learn

Training

Train a model with a confidence estimator with train.py. During training you can use visdom to see a histogram of confidence estimates from the test set. Training logs will be stored in the logs/ folder, while checkpoints are stored in the checkpoints/ folder.

Args Options Description
dataset cifar10,
svhn
Selects which dataset to train on.
model densenet,
wideresnet,
vgg13
Selects which model architecture to use.
batch_size [int] Number of samples per batch.
epochs [int] Number of epochs for training.
seed [int] Random seed.
learning_rate [float] Learning rate.
data_augmentation Train with standard data augmentation (random flipping and translation).
cutout [int] Indicates the patch size to use for Cutout. If 0, Cutout is not used.
budget [float] Controls how often the network can choose have low confidence in its prediction. Increasing the budget will bias the output towards low confidence predictions, while decreasing the budget will produce more high confidence predictions.
baseline Train the model without the confidence branch.

Use the following settings to replicate the experiments from the paper:

VGG13 on CIFAR-10

python train.py --dataset cifar10 --model vgg13 --budget 0.3 --data_augmentation --cutout 16

WideResNet on CIFAR-10

python train.py --dataset cifar10 --model wideresnet --budget 0.3 --data_augmentation --cutout 16

DenseNet on CIFAR-10

python train.py --dataset cifar10 --model densenet --budget 0.3 --epochs 300 --batch_size 64 --data_augmentation --cutout 16

VGG13 on SVHN

python train.py --dataset svhn --model vgg13 --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

WideResNet on SVHN

python train.py --dataset svhn --model wideresnet --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

DenseNet on SVHN

python train.py --dataset svhn --model densenet --budget 0.3 --learning_rate 0.01 --epochs 300 --batch_size 64  --data_augmentation --cutout 20

Out-of-distribution detection

Evaluate a trained model with out_of_distribution_detection.py. Before running this you will need to download the out-of-distribution datasets from Shiyu Liang's ODIN github repo and modify the data paths in the file according to where you saved the datasets.

Args Options Description
ind_dataset cifar10,
svhn
Indicates which dataset to use as in-distribution. Should be the same one that the model was trained on.
ood_dataset tinyImageNet_crop,
tinyImageNet_resize,
LSUN_crop,
LSUN_resize,
iSUN,
Uniform,
Gaussian,
all
Indicates which dataset to use as the out-of-distribution datset.
model densenet,
wideresnet,
vgg13
Selects which model architecture to use. Should be the same one that the model was trained on.
process baseline,
ODIN,
confidence,
confidence_scaling
Indicates which method to use for out-of-distribution detection. Baseline uses the maximum softmax probability. ODIN applies temperature scaling and input pre-processing to the baseline method. Confidence uses the learned confidence estimates. Confidence scaling applies input pre-processing to the confidence estimates.
batch_size [int] Number of samples per batch.
T [float] Temperature to use for temperature scaling.
epsilon [float] Noise magnitude to use for input pre-processing.
checkpoint [str] Filename of trained model checkpoint. Assumes the file is in the checkpoints/ folder. A .pt extension is also automatically added to the filename.
validation Use this flag for fine-tuning T and epsilon. If flag is on, the script will only evaluate on the first 1000 samples in the out-of-distribution dataset. If flag is not used, the remaining samples are used for evaluation. Based on validation procedure from ODIN.

Example commands for running the out-of-distribution detection script:

Baseline

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset all --model vgg13 --process baseline --checkpoint svhn_vgg13_budget_0.0_seed_0

ODIN

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset tinyImageNet_resize --model densenet --process ODIN --T 1000 --epsilon 0.001 --checkpoint cifar10_densenet_budget_0.0_seed_0

Confidence

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset LSUN_crop --model vgg13 --process confidence --checkpoint cifar10_vgg13_budget_0.3_seed_0

Confidence scaling

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset iSUN --model wideresnet --process confidence_scaling --epsilon 0.001 --checkpoint svhn_wideresnet_budget_0.3_seed_0
Numbering permanent and deciduous teeth via deep instance segmentation in panoramic X-rays

Numbering permanent and deciduous teeth via deep instance segmentation in panoramic X-rays In this repo, you will find the instructions on how to requ

Intelligent Vision Research Lab 4 Jul 21, 2022
This is the official implementation of "One Question Answering Model for Many Languages with Cross-lingual Dense Passage Retrieval".

CORA This is the official implementation of the following paper: Akari Asai, Xinyan Yu, Jungo Kasai and Hannaneh Hajishirzi. One Question Answering Mo

Akari Asai 59 Dec 28, 2022
DeepLearning Anomalies Detection with Bluetooth Sensor Data

Final Year Project. Constructing models to create offline anomalies detection using Travel Time Data collected from Bluetooth sensors along the route.

1 Jan 10, 2022
Ppq - A powerful offline neural network quantization tool with custimized IR

PPL Quantization Tool(PPL 量化工具) PPL Quantization Tool (PPQ) is a powerful offlin

605 Jan 03, 2023
EMNLP 2021 - Frustratingly Simple Pretraining Alternatives to Masked Language Modeling

Frustratingly Simple Pretraining Alternatives to Masked Language Modeling This is the official implementation for "Frustratingly Simple Pretraining Al

Atsuki Yamaguchi 31 Nov 18, 2022
PyTorch implementation of paper: AdaAttN: Revisit Attention Mechanism in Arbitrary Neural Style Transfer, ICCV 2021.

AdaAttN: Revisit Attention Mechanism in Arbitrary Neural Style Transfer [Paper] [PyTorch Implementation] [Paddle Implementation] Overview This reposit

148 Dec 30, 2022
Implementation of Advantage-Weighted Regression: Simple and Scalable Off-Policy Reinforcement Learning

advantage-weighted-regression Implementation of Advantage-Weighted Regression: Simple and Scalable Off-Policy Reinforcement Learning, by Peng et al. (

Omar D. Domingues 1 Dec 02, 2021
This is the official implementation for the paper "(Almost) Free Incentivized Exploration from Decentralized Learning Agents" in NeurIPS 2021.

Observe then Incentivize Experiments This is the code used for the paper "(Almost) Free Incentivized Exploration from Decentralized Learning Agents",

Cong Shen Research Group 0 Mar 08, 2022
Imagededup - 😎 Finding duplicate images made easy

imagededup is a python package that simplifies the task of finding exact and near duplicates in an image collection.

idealo 4.3k Jan 07, 2023
Official PyTorch implementation of "Contrastive Learning from Extremely Augmented Skeleton Sequences for Self-supervised Action Recognition" in AAAI2022.

AimCLR This is an official PyTorch implementation of "Contrastive Learning from Extremely Augmented Skeleton Sequences for Self-supervised Action Reco

Gty 44 Dec 17, 2022
ONNX Runtime Web demo is an interactive demo portal showing real use cases running ONNX Runtime Web in VueJS.

ONNX Runtime Web demo is an interactive demo portal showing real use cases running ONNX Runtime Web in VueJS. It currently supports four examples for you to quickly experience the power of ONNX Runti

Microsoft 58 Dec 18, 2022
The repo of the preprinting paper "Labels Are Not Perfect: Inferring Spatial Uncertainty in Object Detection"

Inferring Spatial Uncertainty in Object Detection A teaser version of the code for the paper Labels Are Not Perfect: Inferring Spatial Uncertainty in

ZINING WANG 21 Mar 03, 2022
Learning Dense Representations of Phrases at Scale (Lee et al., 2020)

DensePhrases DensePhrases provides answers to your natural language questions from the entire Wikipedia in real-time. While it efficiently searches th

Princeton Natural Language Processing 540 Dec 30, 2022
OpenMMLab Video Perception Toolbox. It supports Video Object Detection (VID), Multiple Object Tracking (MOT), Single Object Tracking (SOT), Video Instance Segmentation (VIS) with a unified framework.

English | 简体中文 Documentation: https://mmtracking.readthedocs.io/ Introduction MMTracking is an open source video perception toolbox based on PyTorch.

OpenMMLab 2.7k Jan 08, 2023
Motion planning environment for Sampling-based Planners

Sampling-Based Motion Planners' Testing Environment Sampling-based motion planners' testing environment (sbp-env) is a full feature framework to quick

Soraxas 23 Aug 23, 2022
This folder contains the python code of UR5E's advanced forward kinematics model.

This folder contains the python code of UR5E's advanced forward kinematics model. By entering the angle of the joint of UR5e, the detailed coordinates of up to 48 points around the robot arm can be c

Qiang Wang 4 Sep 17, 2022
Source code and notebooks to reproduce experiments and benchmarks on Bias Faces in the Wild (BFW).

Face Recognition: Too Bias, or Not Too Bias? Robinson, Joseph P., Gennady Livitz, Yann Henon, Can Qin, Yun Fu, and Samson Timoner. "Face recognition:

Joseph P. Robinson 41 Dec 12, 2022
Code & Data for Enhancing Photorealism Enhancement

Enhancing Photorealism Enhancement Stephan R. Richter, Hassan Abu AlHaija, Vladlen Koltun Paper | Website (with side-by-side comparisons) | Video (Pap

Intelligent Systems Lab Org 1.1k Dec 31, 2022
Regularizing Generative Adversarial Networks under Limited Data (CVPR 2021)

Regularizing Generative Adversarial Networks under Limited Data [Project Page][Paper] Implementation for our GAN regularization method. The proposed r

Google 148 Nov 18, 2022
FastReID is a research platform that implements state-of-the-art re-identification algorithms.

FastReID is a research platform that implements state-of-the-art re-identification algorithms.

JDAI-CV 2.8k Jan 07, 2023