Code for "Optimizing risk-based breast cancer screening policies with reinforcement learning"


Introduction

This repository was used to develop Tempo, as described in: Optimizing risk-based breast cancer screening policies with reinforcement learning.

Screening programs must balance the benefits of early detection against the costs of over-screening. Here, we introduce Tempo, a novel reinforcement learning-based framework for personalized screening, and demonstrate its efficacy in the context of breast cancer. We trained our risk-based screening policies on a large screening mammography dataset from Massachusetts General Hospital (MGH, USA) and validated them on held-out patients from MGH and on external datasets from Emory (USA), Karolinska (Sweden) and Chang Gung Memorial Hospital (CGMH, Taiwan). Across all test sets, we found that a Tempo policy combined with an image-based AI risk model, Mirai [1], was significantly more efficient than current regimes used in clinical practice in terms of simulated early detection per screen frequency. Moreover, we showed that the same Tempo policy can be easily adapted to a wide range of possible screening preferences, allowing clinicians to select their desired early detection to screening cost trade-off without training new policies. Finally, we demonstrated that Tempo policies based on AI-based risk models outperformed Tempo policies based on less accurate clinical risk models. Altogether, our results show that pairing AI-based risk models with agile AI-designed screening policies has the potential to improve screening programs, advancing early detection while reducing over-screening.

This code base is meant to provide exact implementation details for the development of Tempo.

Aside on Software Dependencies

This code assumes Python 3.6 and a Linux environment. The package requirements can be installed with pip:

pip install -r requirements.txt

Tempo-Mirai assumes access to Mirai risk assessments. Resources for using Mirai are shown here.

Method

(Figure: overview of the full Tempo framework)

Our full framework, named Tempo, is depicted above. We first train a risk progression neural network to predict future risk assessments given previous assessments. This model is then used to estimate patient risk at unobserved timepoints, enabling us to simulate risk-based screening policies. Next, we train our screening policy, which is implemented as a neural network, to maximize the reward (i.e., a combination of early detection and screening cost) on our retrospective training set. We train our screening policy to support all possible early detection vs. screening cost trade-offs using envelope Q-learning [2], an RL algorithm designed to balance multiple objectives. The inputs to our screening policy are the patient's risk assessment and the desired weighting between rewards (i.e., the screening preference). The output of the policy is a recommendation for when to return for the next screen, ranging from six months to three years in the future, in multiples of six months. Our reward balances two contrasting aspects: one reflecting the imaging cost, i.e., the average mammograms per year recommended by the policy, and one modeling the early detection benefit relative to the retrospective screening trajectory. Our early detection reward measures the time difference in months between each patient's recommended screening date, if it was after their last negative mammogram, and their actual diagnosis date. We evaluate screening policies by simulating their recommendations for held-out patients.
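To make the two reward terms concrete, below is a minimal sketch of how the screening-cost and early-detection quantities described above could be computed for one simulated trajectory. The function names, argument conventions and units are illustrative assumptions, not the repository's exact implementation.

```python
# Action space: recommended wait until the next screen, in months
# (six months to three years, in multiples of six months).
ACTIONS_MONTHS = [6, 12, 18, 24, 30, 36]

def screening_cost(recommended_gaps_months):
    """Average mammograms per year implied by a simulated trajectory's
    recommended follow-up gaps (hypothetical helper, for illustration)."""
    total_years = sum(recommended_gaps_months) / 12.0
    return len(recommended_gaps_months) / total_years

def early_detection_months(recommended_screen_month, last_negative_month, diagnosis_month):
    """Early-detection benefit in months: the gap between the recommended
    screening date and the actual diagnosis date, counted only when the
    recommendation falls after the patient's last negative mammogram."""
    if recommended_screen_month <= last_negative_month:
        return 0.0
    return max(0.0, diagnosis_month - recommended_screen_month)
```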

Training Risk Progression Models

We experimented with different learning rates, hidden sizes, numbers of layers and dropout, and chose the model that obtained the lowest validation KL divergence on the MGH validation set. Our final risk progression RNN had two layers, a hidden dimension size of 100, a dropout of 0.25, and was trained for 30 epochs with a learning rate of 1e-3 using the Adam optimizer.
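For reference, here is a minimal PyTorch sketch of a risk progression GRU with the final hyperparameters above (two layers, hidden size 100, dropout 0.25, Adam with learning rate 1e-3). The input dimensionality, output parameterization and class name are assumptions for illustration, not the repository's implementation.

```python
import torch
import torch.nn as nn

class RiskProgressionGRU(nn.Module):
    """Predicts the next risk assessment from the sequence of prior ones."""

    def __init__(self, risk_dim=5, hidden_dim=100, num_layers=2, dropout=0.25):
        super().__init__()
        self.gru = nn.GRU(risk_dim, hidden_dim, num_layers=num_layers,
                          dropout=dropout, batch_first=True)
        self.head = nn.Linear(hidden_dim, risk_dim)

    def forward(self, risk_seq, hidden=None):
        # risk_seq: (batch, time, risk_dim) of six-monthly risk assessments.
        out, hidden = self.gru(risk_seq, hidden)
        # Treating each assessment as a distribution is an assumption made so
        # that a KL-divergence objective (as used for model selection) applies.
        return torch.log_softmax(self.head(out), dim=-1), hidden

model = RiskProgressionGRU()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
kl_loss = nn.KLDivLoss(reduction="batchmean")  # predicted (log) vs. observed assessments
```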

To reproduce our grid search for our Mirai risk progression model, you can run:

python scripts/dispatcher.py --experiment_config_path configs/risk_progression/gru.json

Given a trained risk progression model, we can estimate unobserved risk assessments auto-regressively. At each time step, the model takes as input the previous risk assessment (using the previously predicted assessment if the real one is not available) and the prior hidden state, and predicts the risk assessment at the next time step.
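A minimal sketch of this auto-regressive rollout, reusing the hypothetical RiskProgressionGRU above; variable names and shapes are illustrative.

```python
import torch

@torch.no_grad()
def rollout_risk(model, observed, num_future_steps):
    """Auto-regressively estimate risk at unobserved six-month timepoints.

    observed: (1, time, risk_dim) tensor of real risk assessments.
    """
    log_probs, hidden = model(observed)          # encode the observed history
    prev = log_probs[:, -1:, :].exp()            # last predicted assessment
    predictions = []
    for _ in range(num_future_steps):
        log_probs, hidden = model(prev, hidden)  # feed the prediction back in
        prev = log_probs.exp()
        predictions.append(prev.squeeze())
    return predictions
```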

Training Tempo Personalized Screening Policies

We implemented our personalized screening policy as a multi-layer perceptron, which took as input a risk assessment and a weighting between rewards, and predicted the Q-value for each action (i.e., follow-up recommendation) across the rewards. This network was trained using envelope Q-learning [2]. We experimented with different numbers of layers, hidden dimension sizes, learning rates, dropouts, exploration epsilons, target network reset rates and weight decay rates.
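A minimal sketch of such a multi-objective Q-network, assuming a Mirai-style risk vector, two reward dimensions (early detection and screening cost) and the six follow-up actions; the layer sizes, dropout and class name are illustrative, not the tuned configuration from the grid search.

```python
import torch
import torch.nn as nn

NUM_ACTIONS = 6   # follow-up in 6-month steps, from 6 months to 3 years
NUM_REWARDS = 2   # early detection benefit and screening cost

class ScreeningQNetwork(nn.Module):
    """Predicts, for each action, a Q-value vector over the reward dimensions,
    conditioned on the risk assessment and the preference weighting."""

    def __init__(self, risk_dim=5, hidden_dim=128, dropout=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(risk_dim + NUM_REWARDS, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, NUM_ACTIONS * NUM_REWARDS),
        )

    def forward(self, risk, preference):
        # risk: (batch, risk_dim); preference: (batch, NUM_REWARDS), sums to 1.
        q = self.mlp(torch.cat([risk, preference], dim=-1))
        return q.view(-1, NUM_ACTIONS, NUM_REWARDS)
```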

To reproduce our grid search for our Tempo screening policies, you can run:

python scripts/dispatcher.py --experiment_config_path configs/screening/neural.json
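Once trained, the same network can be queried for any screening preference without retraining: under envelope Q-learning, the recommended follow-up for a preference weighting w is the action maximizing the scalarized value w·Q(s, a, w). A sketch, reusing the hypothetical ScreeningQNetwork above with illustrative inputs:

```python
import torch

def recommend_followup(q_network, risk, preference):
    """Return the recommended months until the next screen for one patient."""
    actions_months = [6, 12, 18, 24, 30, 36]
    with torch.no_grad():
        q = q_network(risk, preference)                     # (1, actions, rewards)
        scalarized = (q * preference.unsqueeze(1)).sum(-1)  # w . Q(s, a, w)
        best = scalarized.argmax(dim=-1).item()
    return actions_months[best]

# Example (illustrative values): weight early detection 0.7 vs. screening cost 0.3.
# risk = torch.tensor([[0.02, 0.03, 0.05, 0.07, 0.09]])
# pref = torch.tensor([[0.7, 0.3]])
# print(recommend_followup(ScreeningQNetwork(), risk, pref))
```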

Data availability

All datasets were used under license to the respective hospital system for the current study and are not publicly available. To access the MGH dataset, investigators should reach out to C.L. to apply for an IRB-approved research collaboration and obtain an appropriate Data Use Agreement. To access the Karolinska dataset, investigators should reach out to F.S. to apply for an approved research collaboration and sign a Data Use Agreement. To access the CGMH dataset, investigators should contact G.L. to apply for an IRB-approved research collaboration. To access the Emory dataset, investigators should reach out to H.T. to apply for an approved collaboration.

References

[1] Yala, Adam, et al. "Toward robust mammography-based models for breast cancer risk." Science Translational Medicine 13.578 (2021).

[2] Yang, Runzhe, Xingyuan Sun, and Karthik Narasimhan. "A generalized algorithm for multi-objective reinforcement learning and policy adaptation." arXiv preprint arXiv:1908.08342 (2019).

Citing Tempo

@article{yala2021optimizing,
  title={Optimizing risk-based breast cancer screening policies with reinforcement learning},
  author={Yala, Adam and Mikhael, Peter and Lehman, Constance and Lin, Gigin and Strand, Fredrik and Wang, Yung-Liang and Hughes, Kevin and Satuluru, Siddharth and Kim, Thomas and Banerjee, Imon and others},
  year={2021}
}