PyTorch implementation of Set Transformer

Overview

set_transformer

Official PyTorch implementation of the paper "Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks".

Requirements

  • Python 3
  • torch >= 1.0
  • matplotlib
  • scipy
  • tqdm

Abstract

Many machine learning tasks such as multiple instance learning, 3D shape recognition, and few-shot image classification are defined on sets of instances. Since solutions to such problems do not depend on the order of elements of the set, models used to address them should be permutation invariant. We present an attention-based neural network module, the Set Transformer, specifically designed to model interactions among elements in the input set. The model consists of an encoder and a decoder, both of which rely on attention mechanisms. In an effort to reduce computational complexity, we introduce an attention scheme inspired by inducing point methods from sparse Gaussian process literature. It reduces the computation time of self-attention from quadratic to linear in the number of elements in the set. We show that our model is theoretically attractive and we evaluate it on a range of tasks, demonstrating the state-of-the-art performance compared to recent methods for set-structured data.
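The abstract describes attention blocks and an inducing-point scheme that cuts self-attention cost from quadratic to linear. Below is a minimal, dependency-free sketch of that idea. It follows the block names used in the paper (MAB, ISAB, PMA), but it is a simplification under stated assumptions: single-head attention, no learned projections, no layer normalization, and no feed-forward layers. It is not the repository's PyTorch code; the inducing points and seed vector are random here, whereas the real model learns them.

```python
import math
import random
from typing import List

Matrix = List[List[float]]  # a set of d-dimensional vectors, one row per element


def attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Scaled dot-product attention: each query returns a softmax-weighted sum of values."""
    d = len(keys[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        top = max(scores)  # subtract the max score for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, values)) / total
                    for j in range(len(values[0]))])
    return out


def mab(x: Matrix, y: Matrix) -> Matrix:
    """Simplified attention block: x attends over y, plus a residual connection (no learned weights)."""
    attended = attention(x, y, y)
    return [[a + b for a, b in zip(row_x, row_a)] for row_x, row_a in zip(x, attended)]


def isab(x: Matrix, inducing: Matrix) -> Matrix:
    """Induced set attention: m inducing points summarize the n elements, then the
    elements attend to that summary. This needs 2*n*m attention scores instead of n*n."""
    summary = mab(inducing, x)  # m queries over n keys
    return mab(x, summary)      # n queries over m keys


def pma(z: Matrix, seed_vector: Matrix) -> Matrix:
    """Pooling by attention: a seed vector attends over the whole set, giving an
    order-independent summary."""
    return mab(seed_vector, z)


random.seed(0)
d, n, m = 4, 10, 3
inducing = [[random.gauss(0, 1) for _ in range(d)] for _ in range(m)]
seed_vector = [[random.gauss(0, 1) for _ in range(d)]]
x = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
shuffled = random.sample(x, len(x))  # same set, different element order

pooled = pma(isab(x, inducing), seed_vector)
pooled_shuffled = pma(isab(shuffled, inducing), seed_vector)
# The difference is close to 0.0; only floating-point summation order changes.
print(max(abs(a - b) for a, b in zip(pooled[0], pooled_shuffled[0])))
```

The encoder part (`isab`) is permutation equivariant, since its output rows are reordered along with the input, and the decoder part (`pma`) turns that into a permutation-invariant summary, which is why shuffling the input set leaves `pooled` unchanged up to rounding.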

Experiments

This repository implements the maximum value regression (section 5.1), amortized clustering (section 5.3), and point cloud classification (section 5.5) experiments in the paper.

Maximum Value Regression

This experiment is reproduced in max_regression_demo.ipynb.
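In this task the model receives a set of real numbers and must output its maximum. The sketch below shows one way such training data could be generated. The helper name `sample_max_batch`, the value range [0, 100], and the maximum set size of 10 are illustrative assumptions, not taken from `max_regression_demo.ipynb`.

```python
import random
from typing import List, Optional, Tuple


def sample_max_batch(batch_size: int, max_set_size: int = 10,
                     low: float = 0.0, high: float = 100.0,
                     rng: Optional[random.Random] = None) -> List[Tuple[List[float], float]]:
    """Hypothetical helper: return batch_size (set, max) pairs.

    All sets in one batch share a random size so they could be stacked into a
    (batch, size, 1) tensor. The ranges are assumptions for illustration.
    """
    if rng is None:
        rng = random.Random()
    size = rng.randint(1, max_set_size)  # one set size per batch (assumption)
    batch = []
    for _ in range(batch_size):
        values = [rng.uniform(low, high) for _ in range(size)]
        batch.append((values, max(values)))  # regression target: the set's maximum
    return batch


batch = sample_max_batch(4, rng=random.Random(0))
for values, target in batch:
    # Shuffling a set never changes its maximum, so the model must be permutation invariant.
    assert target == max(random.sample(values, len(values)))
```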

Amortized Clustering

To run the amortized clustering experiment with Set Transformer, run

python run.py --net=set_transformer

To run the same experiment with Deep Sets, run

python run.py --net=deepset

Point Cloud Classification

We used the same preprocessed ModelNet40 dataset as the DeepSets paper. We cannot share this file publicly due to copyright and licensing restrictions, so to run this code you must obtain the preprocessed dataset "ModelNet40_cloud.h5" yourself. We recommend using multiple GPUs for this experiment; we used 8 Tesla P40s.

To run the point cloud classification experiment, run

python main_pointcloud.py --batch_size 256 --num_pts 100
python main_pointcloud.py --batch_size 256 --num_pts 1000
python main_pointcloud.py --batch_size 256 --num_pts 5000

The hyperparameters above were only minimally tuned, yet they reproduce the results reported in the paper. Further tuning will likely improve the results.
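The three commands differ only in `--num_pts`. Assuming this flag sets how many points are drawn from each object's point cloud (check `main_pointcloud.py` for the actual behavior), the subsampling step could look like the sketch below. The helper `subsample_cloud` and the synthetic 10,000-point cloud are hypothetical; they do not come from the repository or the dataset.

```python
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def subsample_cloud(cloud: Sequence[Point], num_pts: int, rng: random.Random) -> List[Point]:
    """Hypothetical helper: draw num_pts distinct points from one object's cloud."""
    if num_pts > len(cloud):
        raise ValueError(f"cloud has only {len(cloud)} points, asked for {num_pts}")
    # Sampling without replacement; the order of the result does not matter
    # to a permutation-invariant model.
    return rng.sample(list(cloud), num_pts)


rng = random.Random(0)
cloud = [(rng.random(), rng.random(), rng.random()) for _ in range(10000)]  # synthetic stand-in
for num_pts in (100, 1000, 5000):
    points = subsample_cloud(cloud, num_pts, rng)
    print(num_pts, len(points))
```

Larger `num_pts` values give the model a denser view of each shape at a higher attention cost, which is one reason the 5000-point setting benefits from multiple GPUs.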

Reference

If you find this code useful, please consider citing our work:

@InProceedings{lee2019set,
    title={Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks},
    author={Lee, Juho and Lee, Yoonho and Kim, Jungtaek and Kosiorek, Adam and Choi, Seungjin and Teh, Yee Whye},
    booktitle={Proceedings of the 36th International Conference on Machine Learning},
    pages={3744--3753},
    year={2019}
}