Contrastive Learning of Structured World Models

Related tags

Deep Learningc-swm
Overview

Contrastive Learning of Structured World Models

This repository contains the official PyTorch implementation of:

Contrastive Learning of Structured World Models.
Thomas Kipf, Elise van der Pol, Max Welling.
http://arxiv.org/abs/1911.12247

C-SWM

Abstract: A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.

Requirements

  • Python 3.6 or 3.7
  • PyTorch version 1.2
  • OpenAI Gym version: 0.12.0 pip install gym==0.12.0
  • OpenAI Atari_py version: 0.1.4: pip install atari-py==0.1.4
  • Scikit-image version 0.15.0 pip install scikit-image==0.15.0
  • Matplotlib version 3.0.2 pip install matplotlib==3.0.2

Generate datasets

2D Shapes:

python data_gen/env.py --env_id ShapesTrain-v0 --fname data/shapes_train.h5 --num_episodes 1000 --seed 1
python data_gen/env.py --env_id ShapesEval-v0 --fname data/shapes_eval.h5 --num_episodes 10000 --seed 2

3D Cubes:

python data_gen/env.py --env_id CubesTrain-v0 --fname data/cubes_train.h5 --num_episodes 1000 --seed 3
python data_gen/env.py --env_id CubesEval-v0 --fname data/cubes_eval.h5 --num_episodes 10000 --seed 4

Atari Pong:

python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_eval.h5 --num_episodes 100 --atari --seed 2

Space Invaders:

python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_eval.h5 --num_episodes 100 --atari --seed 2

3-Body Gravitational Physics:

python data_gen/physics.py --num-episodes 5000 --fname data/balls_train.h5 --seed 1
python data_gen/physics.py --num-episodes 1000 --fname data/balls_eval.h5 --eval --seed 2

Run model training and evaluation

2D Shapes:

python train.py --dataset data/shapes_train.h5 --encoder small --name shapes
python eval.py --dataset data/shapes_eval.h5 --save-folder checkpoints/shapes --num-steps 1

3D Cubes:

python train.py --dataset data/cubes_train.h5 --encoder large --name cubes
python eval.py --dataset data/cubes_eval.h5 --save-folder checkpoints/cubes --num-steps 1

Atari Pong:

python train.py --dataset data/pong_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name pong
python eval.py --dataset data/pong_eval.h5 --save-folder checkpoints/pong --num-steps 1

Space Invaders:

python train.py --dataset data/spaceinvaders_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name spaceinvaders
python eval.py --dataset data/spaceinvaders_eval.h5 --save-folder checkpoints/spaceinvaders --num-steps 1

3-Body Gravitational Physics:

python train.py --dataset data/balls_train.h5 --encoder medium --embedding-dim 4 --num-objects 3 --ignore-action --name balls
python eval.py --dataset data/balls_eval.h5 --save-folder checkpoints/balls --num-steps 1

Cite

If you make use of this code in your own work, please cite our paper:

@article{kipf2019contrastive,
  title={Contrastive Learning of Structured World Models}, 
  author={Kipf, Thomas and van der Pol, Elise and Welling, Max}, 
  journal={arXiv preprint arXiv:1911.12247}, 
  year={2019} 
}
Owner
Thomas Kipf
Thomas Kipf
[ICCV 2021] A Simple Baseline for Semi-supervised Semantic Segmentation with Strong Data Augmentation

[ICCV 2021] A Simple Baseline for Semi-supervised Semantic Segmentation with Strong Data Augmentation

CodingMan 45 Dec 12, 2022
Using Streamlit to host a multi-page tool with model specs and classification metrics, while also accepting user input values for prediction.

Predicitng_viability Using Streamlit to host a multi-page tool with model specs and classification metrics, while also accepting user input values for

Gopalika Sharma 1 Nov 08, 2021
AI Flow is an open source framework that bridges big data and artificial intelligence.

Flink AI Flow Introduction Flink AI Flow is an open source framework that bridges big data and artificial intelligence. It manages the entire machine

144 Dec 30, 2022
[NeurIPS 2021] Better Safe Than Sorry: Preventing Delusive Adversaries with Adversarial Training

Better Safe Than Sorry: Preventing Delusive Adversaries with Adversarial Training Code for NeurIPS 2021 paper "Better Safe Than Sorry: Preventing Delu

Lue Tao 29 Sep 20, 2022
Structure Information is the Key: Self-Attention RoI Feature Extractor in 3D Object Detection

Structure Information is the Key: Self-Attention RoI Feature Extractor in 3D Object Detection abstract:Unlike 2D object detection where all RoI featur

DK. Zhang 2 Oct 07, 2022
MultiSiam: Self-supervised Multi-instance Siamese Representation Learning for Autonomous Driving

MultiSiam: Self-supervised Multi-instance Siamese Representation Learning for Autonomous Driving Code will be available soon. Motivation Architecture

Kai Chen 24 Apr 19, 2022
Foreground-Action Consistency Network for Weakly Supervised Temporal Action Localization

FAC-Net Foreground-Action Consistency Network for Weakly Supervised Temporal Action Localization Linjiang Huang (CUHK), Liang Wang (CASIA), Hongsheng

21 Nov 22, 2022
Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators

Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators This is our Pytorch implementation for t

RUCAIBox 12 Jul 22, 2022
[ICLR 2021] Heteroskedastic and Imbalanced Deep Learning with Adaptive Regularization

Heteroskedastic and Imbalanced Deep Learning with Adaptive Regularization Kaidi Cao, Yining Chen, Junwei Lu, Nikos Arechiga, Adrien Gaidon, Tengyu Ma

Kaidi Cao 29 Oct 20, 2022
Pytorch implementation of SimSiam Architecture

SimSiam-pytorch A simple pytorch implementation of Exploring Simple Siamese Representation Learning which is developed by Facebook AI Research (FAIR)

Saeed Shurrab 1 Oct 20, 2021
Cryptocurrency Prediction with Artificial Intelligence (Deep Learning via LSTM Neural Networks)

Cryptocurrency Prediction with Artificial Intelligence (Deep Learning via LSTM Neural Networks)- Emirhan BULUT

Emirhan BULUT 102 Nov 18, 2022
Development of IP code based on VIPs and AADM

Sparse Implicit Processes In this repository we include the two different versions of the SIP code developed for the article Sparse Implicit Processes

1 Aug 22, 2022
Multiview 3D object detection on MultiviewC dataset through moft3d.

Multiview Orthographic Feature Transformation for 3D Object Detection Multiview 3D object detection on MultiviewC dataset through moft3d. Introduction

Jiahao Ma 20 Dec 21, 2022
PyTorch implementation of D2C: Diffuison-Decoding Models for Few-shot Conditional Generation.

D2C: Diffuison-Decoding Models for Few-shot Conditional Generation Project | Paper PyTorch implementation of D2C: Diffuison-Decoding Models for Few-sh

Jiaming Song 90 Dec 27, 2022
Video Corpus Moment Retrieval with Contrastive Learning (SIGIR 2021)

Video Corpus Moment Retrieval with Contrastive Learning PyTorch implementation for the paper "Video Corpus Moment Retrieval with Contrastive Learning"

ZHANG HAO 42 Dec 29, 2022
Repositório criado para abrigar os notebooks com a listas de exercícios propostos pelo professor Gustavo Guanabara do canal Curso em Vídeo do YouTube durante o Curso de Python 3

Curso em Vídeo - Exercícios de Python 3 Sobre o repositório Este repositório contém os notebooks com a listas de exercícios propostos pelo professor G

João Pedro Pereira 9 Oct 15, 2022
Autonomous Perception: 3D Object Detection with Complex-YOLO

Autonomous Perception: 3D Object Detection with Complex-YOLO LiDAR object detect

Thomas Dunlap 2 Feb 18, 2022
Official repository for "Exploiting Session Information in BERT-based Session-aware Sequential Recommendation", SIGIR 2022 short.

Session-aware BERT4Rec Official repository for "Exploiting Session Information in BERT-based Session-aware Sequential Recommendation", SIGIR 2022 shor

Jamie J. Seol 22 Dec 13, 2022
Torchserve server using a YoloV5 model running on docker with GPU and static batch inference to perform production ready inference.

Yolov5 running on TorchServe (GPU compatible) ! This is a dockerfile to run TorchServe for Yolo v5 object detection model. (TorchServe (PyTorch librar

82 Nov 29, 2022
Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting

Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting This is the origin Pytorch implementation of Informer in the followin

Haoyi 3.1k Dec 29, 2022