Official PyTorch code base for "UNeXt: MLP-based Rapid Medical Image Segmentation Network", MICCAI 2022

Overview

UNeXt

Paper | Project

Introduction

UNet and its latest extensions, such as TransUNet, have been the leading medical image segmentation methods in recent years. However, these networks cannot be effectively adopted for rapid image segmentation in point-of-care applications because they are parameter-heavy, computationally complex, and slow to use. To this end, we propose UNeXt, a convolutional multilayer perceptron (MLP) based network for image segmentation. We design UNeXt with an early convolutional stage and an MLP stage in the latent space. We propose a tokenized MLP block in which we efficiently tokenize and project the convolutional features and use MLPs to model the representation. To further boost performance, we shift the channels of the inputs while feeding them into the MLPs so as to focus on learning local dependencies. Using tokenized MLPs in the latent space reduces the number of parameters and the computational complexity while yielding a better representation to help segmentation. The network also has skip connections between the various levels of the encoder and decoder. We test UNeXt on multiple medical image segmentation datasets and show that we reduce the number of parameters by 72x, decrease the computational complexity by 68x, and improve the inference speed by 10x while also obtaining better segmentation performance than state-of-the-art medical image segmentation architectures.

Using the code:

The code is stable with Python 3.6.13 and CUDA >= 10.1.

  • Clone this repository:
git clone https://github.com/jeya-maria-jose/UNeXt-pytorch
cd UNeXt-pytorch

To install all the dependencies using conda:

conda env create -f environment.yml
conda activate unext

If you prefer pip, install the following versions:

timm==0.3.2
mmcv-full==1.2.7
torch==1.7.1
torchvision==0.8.2
opencv-python==4.5.1.48

Datasets

  1. ISIC 2018 - Link
  2. BUSI - Link

Data Format

Arrange the files in the following structure (this example assumes 2 classes):

inputs
└── <dataset name>
    ├── images
    │   ├── 001.png
    │   ├── 002.png
    │   ├── 003.png
    │   ├── ...
    │
    └── masks
        ├── 0
        │   ├── 001.png
        │   ├── 002.png
        │   ├── 003.png
        │   ├── ...
        │
        └── 1
            ├── 001.png
            ├── 002.png
            ├── 003.png
            ├── ...

For binary segmentation problems, just use folder 0.
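
Because each mask is looked up by image ID inside a numbered class folder, a mask that is missing or misnamed may only show up once training starts. The following is a minimal sketch, not part of the repository, that walks the layout described above and lists every image without a matching mask. The function name `find_missing_masks` and the placeholder names `inputs` and `busi` are assumptions for illustration only.

```python
from pathlib import Path


def find_missing_masks(root, dataset, num_classes=1, img_ext=".png", mask_ext=".png"):
    """Return (image_id, class_index) pairs whose mask file is absent.

    Assumes the layout shown above:
    <root>/<dataset>/images/<id><img_ext> and
    <root>/<dataset>/masks/<class index>/<id><mask_ext>.
    """
    base = Path(root) / dataset
    # Image IDs are the file names without their extension, e.g. "001".
    image_ids = sorted(p.stem for p in (base / "images").glob(f"*{img_ext}"))
    missing = []
    for image_id in image_ids:
        for cls in range(num_classes):
            mask_path = base / "masks" / str(cls) / f"{image_id}{mask_ext}"
            if not mask_path.is_file():
                missing.append((image_id, cls))
    return missing


if __name__ == "__main__":
    # "inputs" and "busi" are placeholder names; substitute your own paths.
    for image_id, cls in find_missing_masks("inputs", "busi", num_classes=1):
        print(f"missing mask for image {image_id} in class folder {cls}")
```

For a binary problem, `num_classes=1` checks only folder 0; for the two-class example above, pass `num_classes=2` so that folders 0 and 1 are both checked.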

Training and Validation

  1. Train the model.
python train.py --dataset <dataset name> --arch UNext --name <exp name> --img_ext .png --mask_ext .png --lr 0.0001 --epochs 500 --input_w 512 --input_h 512 --b 8
  2. Evaluate.
python val.py --name <exp name>

Acknowledgements:

This code base uses certain code blocks and helper functions from UNet++, Segformer, and AS-MLP. Naming credits to Poojan.

Citation:

@article{valanarasu2022unext,
  title={UNeXt: MLP-based Rapid Medical Image Segmentation Network},
  author={Valanarasu, Jeya Maria Jose and Patel, Vishal M},
  journal={arXiv preprint arXiv:2203.04967},
  year={2022}
}

Comments

  • 'NoneType' object is not subscriptable

    How do I solve this problem?

    Traceback (most recent call last):
      File "D:/Code/UNeXt/train.py", line 371, in <module>
        main()
      File "D:/Code/UNeXt/train.py", line 326, in main
        train_log = train(config, train_loader, model, criterion, optimizer)
      File "D:/Code/UNeXt/train.py", line 119, in train
        for input, target, _ in train_loader:
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
        data = self._next_data()
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 561, in _next_data
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "D:\Code\UNeXt\dataset.py", line 63, in __getitem__
        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
    TypeError: 'NoneType' object is not subscriptable
      0%|          | 0/21 [00:00<?, ?it/s]

    opened by Eveinn 18
  • Parameter settings on the BUSI dataset

    Can you share your parameter settings for the breast ultrasound image (BUSI) dataset? I set epochs=400, batch size=8, lr=0.0001, momentum=0.9, optimizer=Adam, scheduler=CosineAnnealingLR, and channels=[16, 32, 128, 160, 256], but got Dice: 75.74 and IoU: 61.99 (the paper reports 66.95). Thanks!

    opened by xiaohancl 6
  • Not as fast as the paper says

    Thank you for sharing. In my experiments, I found the speed to be quite slow. What might be causing this? Why do the OverlapPatchEmbed and shiftMLP modules both use convolutions in the MLP stage? The convolution kernel sizes are 3 and 7, which slows things down.

    opened by JOP-Lee 3
  • torch.add in the UNeXt module?

    Hello, I saw in the source code that you used torch.add() to combine the downsampled features with the upsampled features. Is there a specific reason for that operation? Is it because torch.add() needs fewer operations than torch.cat() during the forward pass? As I remember, in UNet the downsampled and upsampled features are concatenated rather than added. Thanks.

    opened by Capchenxi 2
  • Performance on multi-class segmentation

    Hello, thanks for the excellent idea of the shifted MLP module. I have some questions about segmentation tasks other than medical segmentation. I noticed that the code you provide mainly focuses on 2-class segmentation. Have you tried the UNeXt structure on a multi-class task such as Cityscapes? I tried it on my own task, a 12-class segmentation task for parking lot scenes, but it performs poorly. So I wonder whether UNeXt has been tested on Cityscapes or another multi-class dataset, and how it performs. Thanks.

    opened by Capchenxi 2
  • About Experiments

    Thank you for sharing. In Section 3 of the paper (Experiments and Results), I noticed the sentence "We perform a 80-20 random split thrice across the dataset and report the mean and variance." I wonder whether random_state is set to 41 in every experiment. If so, does that mean the training and test sets are the same in all three experiments? Thank you very much.

    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    opened by AirZWH 1
  • self.img_ids = img_ids

    self.img_ids = img_ids
    self.img_dir = img_dir
    self.mask_dir = mask_dir
    self.img_ext = '.png'
    self.mask_ext = '.png'
    self.num_classes = 2
    self.transform = transform

    How should the right-hand sides of the first three assignments (img_ids, img_dir, mask_dir) be written? The first one does not appear to be a list. Can you give an example?

    opened by antoniaaaaaaaaaaaa 1
  • Poor training results

    My training results on the two datasets are far below those reported in the original paper. Is there any solution? If you have solved this, could you tell me how? I hope to get a reply, or the modified code by email if that is convenient. Thank you very much for your help!

    opened by Y-Miou 0
  • Shifting the axis of the channels of conv features before tokenizing

    Thanks for your great work.

    I am confused about one point: is the first sentence of the Shifted MLP paragraph in Section 2, "In shifted MLP, we first shift the axis of the channels of conv features before tokenizing", implemented in the released source code?

    I could not find the operation that shifts the axis of the channels of the conv features in your source code. Could you please point me to it?

    I look forward to your reply!

    opened by Azhihong 0
  • Dataset classification problem

    I would like to ask the author: if I have 4 categories, should the mask for label 1 contain 255 only where label 1 is present and 0 for all other classes, the mask for label 2 contain 255 only where label 2 is present and 0 for all other classes, and so on?

    opened by heixinbaicaishang 0
  • Incorrect usage of Normalize in transforms

    I found a bug:

    You use the albumentations.augmentations transforms function to normalize the input images, but it uses values that are not the actual mean and standard deviation of the datasets. To be correct, you would have to calculate the mean and standard deviation of each dataset beforehand and pass them in. Not the end of the world, but correcting this would probably help stability and overall performance.

    opened by david-stojanovski 0
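
The question above about `random_state=41` turns on how a seeded split behaves. As a minimal sketch, using only Python's standard `random` module as a stand-in for scikit-learn's `train_test_split` (so the IDs selected here will not match what `train.py` produces), a fixed seed reproduces the same 80-20 split on every call, while changing the seed draws a new split. The helper `split_ids` and the generated IDs are hypothetical, and nothing here describes how the three splits in the paper were actually produced.

```python
import random


def split_ids(img_ids, test_size=0.2, seed=41):
    """Shuffle a copy of img_ids with a seeded generator and cut off a validation slice."""
    ids = list(img_ids)
    random.Random(seed).shuffle(ids)
    n_val = int(round(len(ids) * test_size))
    # The first n_val shuffled IDs form the validation set, the rest the training set.
    return ids[n_val:], ids[:n_val]


# Hypothetical image IDs "001" to "100".
img_ids = [f"{i:03d}" for i in range(1, 101)]

# Same seed: the split is reproduced exactly on every call.
train_a, val_a = split_ids(img_ids, seed=41)
train_b, val_b = split_ids(img_ids, seed=41)
print(val_a == val_b)

# Three 80-20 splits of the same IDs, each drawn with its own seed.
splits = [split_ids(img_ids, seed=s) for s in (41, 42, 43)]
print([len(val) for _, val in splits])
```

Reporting a mean and variance over several splits is only informative if the seed (or another source of randomness) changes between runs.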
Owner
Jeya Maria Jose
PhD Student at Johns Hopkins University.