A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
ARK sõidueksami Matrixi bot

ARK Sõidueksami bot Küsib ARK-i lehelt uusimad eksami ajad ja saadab sõnumi Matrixi kanali Dev setup Linux python3 -m venv venv source venv/bin/activa

Arti Zirk 3 Jun 15, 2021
Tindicators is a Python library to calculate the values of various technical indicators

Tindicators is a Python library to calculate the values of various technical indicators

omar 3 Mar 03, 2022
Solve various integral equations using numerical methods in Python

Solve Volterra and Fredholm integral equations This Python package estimates Volterra and Fredholm integral equations using known techniques. Installa

Matthew Wildrick Thomas 18 Nov 28, 2022
YBlade - Import QBlade blades into Fusion 360

YBlade - Import QBlade blades into Fusion 360 Simple script for Fusion 360 that takes QBlade blade description and constructs the blade: Usage First,

Jan Mrázek 37 Sep 25, 2022
py2dis - A disassembly engine & library for Python

py2dis - A disassembly engine & library for Python. py2dis is a disassembly library for Python that does not use any modules/libraries other than colo

3 Feb 04, 2022
Simulation-Based Inference Benchmark

This repository contains a simulation-based inference benchmark framework, sbibm, which we describe in the associated manuscript "Benchmarking Simulation-based Inference".

SBI Benchmark 58 Oct 13, 2022
CarolinaCon CTF Online

CarolinaCon Online CTF CTF challenges from CarolinaCon Online April 23 through April 25, 2021. All challenges from the CTF will eventually be here. Co

49th Security Division 6 May 04, 2022
ROS Foxy + Raspi + Adafruit BNO055

ROS Foxy + Raspi + Adafruit BNO055

Ar-Ray 3 Nov 04, 2022
A Python3 script to decode an encoded VBScript file, often seen with a .vbe file extension

vbe-decoder.py Decode one or multiple encoded VBScript files, often seen with a .vbe file extension. Usage usage: vbe-decoder.py [-h] [-o output] file

John Hammond 147 Nov 15, 2022
This is a library to do functional programming in Python.

Fpylib This is a library to do functional programming in Python. Index Fpylib Index Features Intelligents Ranges with irange Lazyness to functions Com

Fabián Vega Alcota 4 Jul 17, 2022
This is the DBMS Project done in 5th sem of B.E CS.

Student-Result-Management-System This is the DBMS Project done in 5th sem of B.E CS. You need to install SQlite DB Browser in your pc or laptop to ope

Vivek kulkarni 1 Jan 14, 2022
A free and powerful system for awareness and research of the American judicial system.

CourtListener Started in 2009, CourtListener.com is the main initiative of Free Law Project. The goal of CourtListener.com is to provide high quality

Free Law Project 332 Dec 25, 2022
This script provides LIVE feedback for On-The-Fly data collection with RELION

README This script provides LIVE feedback for On-The-Fly data collection with RELION (very useful to explore already processed datasets too!) Creating

cryoEM CNIO 6 Jul 14, 2022
Rock 💎 Paper 📝 Scissors ✂️ Lizard 🦎 Spock 🖖

Rock 💎 Paper 📝 Scissors ✂️ Lizard 🦎 Spock 🖖 If you’ve seen The Big Bang Theory, you’ve heard of a game called “Rock, Paper, Scissors, Lizard, Spoc

AmirHossein Mohammadi 16 Jun 19, 2022
More granular intermediaries for legacy Minecraft versions

Orinthe/Intermediary mappings This repository contains the match information between different versions of Minecraft created by the Orinthe project, a

4 Jan 11, 2022
In this repo, I will put all the code related to data science using python libraries like Numpy, Pandas, Matplotlib, Seaborn and many more.

Python-for-DS In this repo, I will put all the code related to data science using python libraries like Numpy, Pandas, Matplotlib, Seaborn and many mo

1 Jan 10, 2022
Adjust the white point, gamma or make your XDR display darker without losing HDR peak luminance or the ability to adjust display brightness

XDR Tuner Adjust the white point, gamma or make your XDR display darker without losing HDR peak luminance or the ability to adjust display brightness

François Simond 16 Dec 28, 2022
Automated moth pictures for biodiversity research

Automated moth pictures for biodiversity research

Ludwig Kürzinger 1 Dec 16, 2021
Batch Python Program Verify

Batch Python Program Verify About As a TA(teaching assistant) of Programming Class, it is very annoying to test students' homework assignments one by

Han-Wei Li 7 Dec 20, 2022
Pylexa - Artificial Assistant made with Python

Pylexa - Artificial Assistant made with Python Alexa is a famous artificial assistant used massively across the world. It is a substitute of Alexa whi

\_PROTIK_/ 4 Nov 03, 2021