Official implementation of the NIPS'17 paper PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

Overview

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning

The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems.

First version at NeurIPS 2017

This repo first contains a PyTorch implementation of PredRNN (2017) [paper], a recurrent network with a pair of memory cells that follow nearly independent transition dynamics and are finally combined into unified representations of the complex environment.
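For reference, the spatiotemporal LSTM (ST-LSTM) unit that holds this pair of memories can be written roughly as below. This is our transcription of the 2017 paper's formulation, not something extracted from this repo's code, so treat the paper as authoritative. Here $*$ is convolution, $\odot$ the Hadamard product, $\sigma$ the sigmoid, $C$ the temporal memory, $M$ the spatiotemporal memory, and $X_t$ the layer input (the frame for the first layer, $H^{l-1}_t$ for higher layers). The bottom layer reads $M^{0}_t = M^{L}_{t-1}$ from the top layer $L$ of the previous step.

```latex
\begin{aligned}
g_t   &= \tanh(W_{xg} * X_t + W_{hg} * H^l_{t-1} + b_g) \\
i_t   &= \sigma(W_{xi} * X_t + W_{hi} * H^l_{t-1} + b_i) \\
f_t   &= \sigma(W_{xf} * X_t + W_{hf} * H^l_{t-1} + b_f) \\
C^l_t &= f_t \odot C^l_{t-1} + i_t \odot g_t \\
g'_t  &= \tanh(W'_{xg} * X_t + W_{mg} * M^{l-1}_t + b'_g) \\
i'_t  &= \sigma(W'_{xi} * X_t + W_{mi} * M^{l-1}_t + b'_i) \\
f'_t  &= \sigma(W'_{xf} * X_t + W_{mf} * M^{l-1}_t + b'_f) \\
M^l_t &= f'_t \odot M^{l-1}_t + i'_t \odot g'_t \\
o_t   &= \sigma(W_{xo} * X_t + W_{ho} * H^l_{t-1} + W_{co} * C^l_t + W_{mo} * M^l_t + b_o) \\
H^l_t &= o_t \odot \tanh\left(W_{1 \times 1} * [C^l_t, M^l_t]\right)
\end{aligned}
```

The first four lines update $C$ along the time axis; the next four update $M$ from the layer below, which is what lets the two memories evolve nearly independently before the output gate and the $1 \times 1$ convolution merge them.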

Concretely, besides the original LSTM memory cell, this network features a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the visual dynamics learned at different levels of the RNN to communicate.
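The zigzag flow is easiest to see as an ordering over (time step, layer) cells. The sketch below is an illustration under our own assumptions (0-based indices, hypothetical helper names `zigzag_order` and `memory_source`), not code from this repo: within a step the spatiotemporal memory moves bottom-up through the layers, and the top layer at step t-1 hands it down to the bottom layer at step t.

```python
from typing import List, Optional, Tuple

Cell = Tuple[int, int]  # (time step t, layer index l), both 0-based


def zigzag_order(num_steps: int, num_layers: int) -> List[Cell]:
    """Cells in the order the spatiotemporal memory M passes through them."""
    return [(t, l) for t in range(num_steps) for l in range(num_layers)]


def memory_source(t: int, l: int, num_layers: int) -> Optional[Cell]:
    """Cell whose M output feeds cell (t, l); None means the initial memory."""
    if l > 0:
        return (t, l - 1)               # bottom-up: layer l-1 -> layer l at the same step
    if t > 0:
        return (t - 1, num_layers - 1)  # top-down jump: top layer at t-1 -> bottom layer at t
    return None                         # very first cell starts from an initial (e.g. zero) memory


if __name__ == "__main__":
    for cell in zigzag_order(num_steps=2, num_layers=3):
        print(cell, "<-", memory_source(*cell, num_layers=3))
```

Every cell's source is the cell visited just before it, so M forms a single chain that snakes through all layers and all time steps, in contrast to the temporal memory C, which stays within one layer.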

New in PredRNN-V2 (2021)

This repo also includes the implementation of PredRNN-V2 (2021) [paper], which improves PredRNN in two respects.

1. Memory Decoupling

We find that the pair of memory cells in PredRNN contains undesirable, redundant features, and we therefore introduce a memory decoupling loss that encourages them to learn modular structures of visual dynamics.
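As we read the PredRNN-V2 paper, the decoupling loss penalizes the per-channel cosine similarity between the increments of the two memories (written here as delta C and delta M) at each time step and layer. The sketch below uses plain Python lists instead of PyTorch tensors; the helper names, the channel-major layout, and summing over channels are assumptions for illustration, and the repo's training code may normalize or weight the loss differently.

```python
import math
from typing import Sequence


def abs_cosine(u: Sequence[float], v: Sequence[float], eps: float = 1e-8) -> float:
    """|<u, v>| / (||u|| * ||v||), with eps guarding against division by zero."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return abs(dot) / (norm_u * norm_v + eps)


def decoupling_loss(delta_c: Sequence[Sequence[float]],
                    delta_m: Sequence[Sequence[float]]) -> float:
    """Sum over channels of |cos| between the increments of C and M.

    delta_c[k] and delta_m[k] hold channel k of each increment, flattened over
    height and width (a hypothetical layout chosen for this sketch).
    """
    if len(delta_c) != len(delta_m):
        raise ValueError("delta_c and delta_m must have the same number of channels")
    return sum(abs_cosine(c, m) for c, m in zip(delta_c, delta_m))


if __name__ == "__main__":
    delta_c = [[1.0, 0.0], [1.0, 1.0]]
    delta_m = [[0.0, 1.0], [2.0, 2.0]]
    # Channel 0 is orthogonal (contributes 0), channel 1 is parallel (contributes ~1).
    print(decoupling_loss(delta_c, delta_m))
```

Minimizing this quantity pushes the two increments toward orthogonality channel by channel, which is one concrete way to discourage the two memories from storing the same features.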

[Figure: memory decoupling]

2. Reverse Scheduled Sampling

Reverse scheduled sampling is a new curriculum learning strategy for seq-to-seq RNNs. As opposed to scheduled sampling, it gradually changes the training process of the PredRNN encoder from using the previously generated frame to using the previous ground truth. Benefits: (1) it makes training converge faster by reducing the gap between how the encoder and the forecaster are trained; (2) it forces the model to learn more from the long-term input context.
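A minimal sketch of the encoder side of reverse scheduled sampling: the probability of feeding the ground-truth previous frame rises over training, the reverse of standard scheduled sampling. The parameter names, default values, and the exponential shape are assumptions for illustration, not values read from this repo's training scripts.

```python
import math
import random
from typing import List


def encoder_true_frame_prob(step: int,
                            start_step: int = 25_000,
                            end_step: int = 50_000,
                            alpha: float = 5_000.0,
                            p_start: float = 0.5,
                            p_end: float = 1.0) -> float:
    """Probability of feeding the ground-truth previous frame to the encoder.

    Hypothetical schedule: p_start before start_step, rising exponentially
    toward p_end between start_step and end_step, and p_end afterwards.
    """
    if step < start_step:
        return p_start
    if step < end_step:
        return p_end - (p_end - p_start) * math.exp(-(step - start_step) / alpha)
    return p_end


def sample_encoder_mask(step: int, input_len: int, rng: random.Random) -> List[bool]:
    """True -> use the ground-truth frame; False -> use the model's own prediction.

    The first encoder input is always a real frame, so only the remaining
    input_len - 1 steps are sampled.
    """
    p = encoder_true_frame_prob(step)
    return [rng.random() < p for _ in range(input_len - 1)]


if __name__ == "__main__":
    rng = random.Random(0)
    for step in (0, 30_000, 60_000):
        print(step, round(encoder_true_frame_prob(step), 3), sample_encoder_mask(step, 10, rng))
```

Late in training every encoder step sees ground truth, matching the test-time setting where all input frames are observed, while the forecaster side would still use an ordinary (decreasing) scheduled-sampling rate.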

[Figure: reverse scheduled sampling]

Evaluation with LPIPS

LPIPS is more consistent with human perceptual judgments; lower is better.

Model      | Moving MNIST | KTH action
PredRNN    | 0.109        | 0.204
PredRNN-V2 | 0.071        | 0.139
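To put the LPIPS numbers in relative terms, the short calculation below computes how much lower PredRNN-V2's score is than PredRNN's on each dataset, using only the values from the table; it works out to roughly 35% on Moving MNIST and 32% on KTH action.

```python
# LPIPS scores from the evaluation table (lower is better).
LPIPS = {
    "Moving MNIST": {"PredRNN": 0.109, "PredRNN-V2": 0.071},
    "KTH action": {"PredRNN": 0.204, "PredRNN-V2": 0.139},
}


def relative_reduction(before: float, after: float) -> float:
    """Fraction by which `after` is lower than `before`."""
    return (before - after) / before


if __name__ == "__main__":
    for dataset, scores in LPIPS.items():
        r = relative_reduction(scores["PredRNN"], scores["PredRNN-V2"])
        print(f"{dataset}: {r:.1%} lower LPIPS")
```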

Prediction examples

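[Figure: Moving MNIST prediction examples]

[Figure: KTH action prediction examples]

[Figure: radar prediction examples]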

Get Started

  1. Install Python 3.7, PyTorch 1.3, and OpenCV 3.4.
  2. Download the data. This repo contains code for two datasets: the Moving MNIST dataset and the KTH action dataset.
  3. Train the model with the bash scripts below. Trained models are saved to the folder given by --save_dir, and generated future frames to the folder given by --gen_frm_dir.
  4. You can get pretrained models from here.
cd mnist_script/
sh predrnn_mnist_train.sh
sh predrnn_v2_mnist_train.sh

cd kth_script/
sh predrnn_kth_train.sh
sh predrnn_v2_kth_train.sh

Citation

If you find this repo useful, please cite the following papers.

@inproceedings{wang2017predrnn,
  title={{PredRNN}: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal {LSTM}s},
  author={Wang, Yunbo and Long, Mingsheng and Wang, Jianmin and Gao, Zhifeng and Yu, Philip S},
  booktitle={Advances in Neural Information Processing Systems},
  pages={879--888},
  year={2017}
}

@misc{wang2021predrnn,
  title={{PredRNN}: A Recurrent Neural Network for Spatiotemporal Predictive Learning},
  author={Wang, Yunbo and Wu, Haixu and Zhang, Jianjin and Gao, Zhifeng and Wang, Jianmin and Yu, Philip S and Long, Mingsheng},
  year={2021},
  eprint={2103.09504},
  archivePrefix={arXiv},
}
Owner
THUML: Machine Learning Group @ THSS
Machine Learning Group, School of Software, Tsinghua University