Official repo for "Solving Inverse Problems in Medical Imaging with Score-Based Generative Models"

Overview

This repo contains the JAX code for experiments in the paper Solving Inverse Problems in Medical Imaging with Score-Based Generative Models by Yang Song*, Liyue Shen*, Lei Xing, and Stefano Ermon (* = joint first authors).

We propose a general approach to solving linear inverse problems in medical imaging with score-based generative models. Our method is purely generative, so it does not require knowledge of the physical measurement process during training and can be quickly adapted to different imaging processes at test time without retraining the model. We demonstrate superior performance on sparse-view computed tomography (CT), magnetic resonance imaging (MRI), and metal artifact removal (MAR) in CT imaging.

Dependencies

See requirements.txt.

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval|tune>: Running mode: train or eval or tune
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be mostly self-explanatory. sampling.cs_solver specifies which sampling method is used to solve the inverse problem; it takes one of 4 possible values.

  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is "train", "eval", or "tune". When set to "train", it starts training a new model, or resumes training an existing model if its meta-checkpoints (used to resume a run after pre-emption in a cloud environment) exist in workdir/checkpoints-meta. When set to "eval", it computes the PSNR/SSIM metrics on a test dataset. When set to "tune", it automatically tunes the sampler's hyperparameters with Bayesian optimization.
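The "eval" mode reports PSNR and SSIM. As a reminder of what the first of these measures, here is a minimal NumPy sketch of PSNR. It is an illustration only, not the evaluation code used in this repo: the function name `psnr` and the `data_range` parameter are assumptions made for this example, and images are assumed to be scaled to [0, 1].

```python
import numpy as np


def psnr(reference, reconstruction, data_range=1.0):
    """Peak signal-to-noise ratio (in dB) between two images of equal shape.

    `data_range` is the assumed maximum possible intensity (1.0 for images
    scaled to [0, 1]); this helper is hypothetical, not part of the repo.
    """
    reference = np.asarray(reference, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    if reference.shape != reconstruction.shape:
        raise ValueError("images must have the same shape")
    # Mean squared error over all pixels.
    mse = np.mean((reference - reconstruction) ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded.
        return float("inf")
    # PSNR = 10 * log10(MAX^2 / MSE).
    return float(10.0 * np.log10(data_range ** 2 / mse))


# Toy example: a constant error of 0.1 gives MSE = 0.01, i.e. about 20 dB.
reference = np.zeros((4, 4))
reconstruction = np.full((4, 4), 0.1)
print(psnr(reference, reconstruction))
```

Higher values indicate a reconstruction closer to the reference. Reported numbers depend on the chosen `data_range`, so it should match the intensity scaling of the test data.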

Pretrained checkpoints

Checkpoints and test data are provided in this Google Drive folder. Please download the folder and move it into the same directory as this repo.

References

If you find the code useful for your research, please consider citing

@inproceedings{
  song2022solving,
  title={Solving Inverse Problems in Medical Imaging with Score-Based Generative Models},
  author={Yang Song and Liyue Shen and Lei Xing and Stefano Ermon},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=vaRCHVj0uGI}
}

and its prior work

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}
Comments
  • Missing source files ?

    Hi,

    I am very interested in your work, which I find excellent, and I would like to reproduce some of the results and extend it a bit. I am having a hard time just running the main.py script. The datasets.py file imports several modules, including ct2d and fastmri_knee_single, that I cannot find anywhere in the repository. Maybe they are missing, or their names have changed. I am investigating the latter case, but in the meantime, if it's just a matter of missing files, could you add them?

    Thanks a lot, Cheers

    opened by zaccharieramzi 4
  • multi node multi gpu

    Hi @yang-song, thanks for sharing the code. I wonder whether the code supports multi-node, multi-GPU training. I see that the code includes jax.device and jax.local_device, but how is the connection among the GPUs set up? Thanks again!

    opened by JiahaoYao 2
  • Fixing the requirements file

    When installing the requirements file (pip install -r requirements.txt), I found the following error:

    ERROR: Cannot install -r requirements.txt (line 6) and tensorboard==2.4.1 because these package versions have conflicting dependencies.
    
    The conflict is caused by:
        The user requested tensorboard==2.4.1
        tensorflow 2.5.0 depends on tensorboard~=2.5
    
    To fix this you could try to:
    1. loosen the range of package versions you've specified
    2. remove package versions to allow pip attempt to solve the dependency conflict
    
    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
    

    I am using Python 3.9.

    I suggest removing the strict version requirement on tensorboard, as it is installed by TensorFlow anyway.

    I also added scikit-image to the requirements file, as it is needed to compute the image quality metrics during evaluation. Similarly, I added piq, torch, and torchvision, as well as odl, which is needed for CT. I do not know which versions of these packages should be used.

    I pinned the tensorflow-probability version to make sure it matches that of tensorflow and that it can be used with tensorflow-gan.

    opened by zaccharieramzi 0
  • Module is not compatible

    Hi, I tried to install the modules listed in the requirements file, but the commit ID given for flax is invalid. So I installed the latest version of flax directly, but this version is not compatible with the versions of the other required packages. I tried downgrading flax, but it still didn't work. Could you tell me the exact version of flax? Thanks! Also, will there be a PyTorch version?

    opened by adahfbch 0
  • Host memory leak during training

    I'm experiencing issues with host memory usage continually increasing during training, until eventually my machine freezes up or the process is killed due to out-of-memory (I have 32GB available). Everything else about training seems to be working fine until it crashes (after around 5000 iterations), and GPU memory is also fine as usage of it is completely constant. I've tried several versions of jax/jaxlib/flax but there doesn't seem to be any change with this. I've attached the output of pip freeze in my virtualenv.

    Any clues what could be causing this? I searched for JAX memory leaks on Google/StackOverflow, but didn't find anything that seemed useful/related.

    pip-environment.txt .

    opened by cobalamin 2
  • AttributeError: module 'config_config' has no attribute 'get_config'

    Hello, I am using python 3.8, flax 0.5, and jax 0.3.14. Downgrading jax to 0.2.1 or upgrading to 0.5.2 produces similar errors. Have you encountered a similar problem after checking out 6082bb46b3a4f253eb98c08b267fa99?

    opened by tianzhijiaoziA 1
  • Possibility to run on CPU

    Hi,

    I have tried running the code on CPU (my setup is with Python 3.9, Ubuntu 16.04 on an 8-core machine), and I have had a segmentation fault:

    Fatal Python error: Segmentation fault
    
    Thread 0x00007f56467fc700 (most recent call first):
      File "/usr/lib/python3.9/threading.py", line 316 in wait
      File "/usr/lib/python3.9/threading.py", line 574 in wait
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/tqdm/_monitor.py", line 60 in run
      File "/usr/lib/python3.9/threading.py", line 954 in _bootstrap_inner
      File "/usr/lib/python3.9/threading.py", line 912 in _bootstrap
    
    Current thread 0x00007f5813035700 (most recent call first):
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1160 in execute_replicated
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 637 in xla_pmap_impl
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 607 in process_call
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 1624 in process
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 1552 in call_bind
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 1621 in bind
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/_src/api.py", line 1632 in f_pmapped
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183 in reraise_with_filtered_traceback
      File "/home/zaccharie/workspace/score_inverse_problems/score_inverse_problems/run_lib.py", line 391 in evaluate
      File "/home/zaccharie/workspace/score_inverse_problems/score_inverse_problems/main.py", line 60 in main
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/absl/app.py", line 251 in _run_main
      File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/absl/app.py", line 303 in run
      File "/home/zaccharie/workspace/score_inverse_problems/score_inverse_problems/main.py", line 68 in <module>
    [1]    5511 segmentation fault (core dumped)  python score_inverse_problems/main.py --config  --workdir=./ --mode eval
    

    Have you tried running the code on CPU, or is it a GPU-only code?

    opened by zaccharieramzi 2
Owner
Yang Song
Research Scientist at OpenAI.