Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline


The overview of our ZITS. First, the TSR model restores structures at low resolution. Then a simple CNN-based upsampler upsamples the edge and line maps. Finally, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.
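
A minimal sketch of this three-stage flow is shown below; the function and variable names are illustrative placeholders, not this repository's actual API.

    # Hypothetical sketch of the ZITS inference flow (placeholder names, not the repo's API).
    import torch.nn.functional as F

    def zits_inpaint(image, mask, tsr, ssu, ftr):
        """image/mask: [B, C, H, W] tensors; tsr/ssu/ftr: callables standing in for the three models."""
        # 1) TSR restores edge/line structures at low resolution (e.g. 256x256).
        img_256 = F.interpolate(image, size=(256, 256), mode="bilinear", align_corners=False)
        msk_256 = F.interpolate(mask, size=(256, 256), mode="nearest")
        edge_256, line_256 = tsr(img_256, msk_256)
        # 2) The simple CNN-based upsampler (SSU) lifts the edge/line maps to the target resolution.
        edge_hr, line_hr = ssu(edge_256, line_256)
        # 3) FTR encodes the upsampled sketch space (added through ZeroRA) and restores the textures.
        return ftr(image, mask, edge_hr, line_hr)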

TO DO

We have updated the weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference codes.
  • Releasing pre-trained models.
  • Releasing training codes.

Preparation

  1. Preparing the environment:

    As there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions and using torch 1.9.0 for multi-GPU training.

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provides irregular and segmentation masks (download) with different masking rates. You should define the mask file list before training, as in MST.
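
    For example, a minimal sketch of building such a file list (the mask directory below is an illustrative assumption; the list file name matches the one used in the training commands):

    # Hypothetical sketch: collect the downloaded mask images into a flat text file, one path per line.
    import glob, os

    mask_dir = "./masks/irregular"                       # assumed location of the downloaded masks
    out_list = "./data_list/irregular_mask_list.txt"
    paths = sorted(glob.glob(os.path.join(mask_dir, "*.png")))
    os.makedirs(os.path.dirname(out_list), exist_ok=True)
    with open(out_list, "w") as f:
        f.write("\n".join(os.path.abspath(p) for p in paths))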

  3. Download the pretrained masked wireframe detection model to the './ckpt' folder: LSM-HAWP (MST ICCV2021, retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    As MST trains the LSM-HAWP in PyTorch 1.3.1, which causes problems (link) when tested in PyTorch 1.9, we recommend inferring the lines (wireframes) with torch==1.3.1. If the line detection is not based on torch 1.3.1, the performance may drop a little.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    Then extract wireframes with the following command:

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
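
    A quick sanity check of the extraction (this sketch assumes one output file per image, named after the image stem; adjust it to the actual output layout):

    # Hypothetical check: compare image stems against the generated wireframe files.
    import os

    img_dir, wf_dir = "./images", "./wireframes"         # assumed --input_path / --output_path from above
    imgs = {os.path.splitext(f)[0] for f in os.listdir(img_dir)}
    wfs = {os.path.splitext(f)[0] for f in os.listdir(wf_dir)}
    missing = sorted(imgs - wfs)
    print(f"{len(missing)} images without wireframes:", missing[:10])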
    
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
    

Eval


Download pretrained models on Places2 here.

Batch Test

For batch testing, you need to complete steps 3 and 4 above first.

Put the pretrained models into the './ckpt' folder. Then modify the config file according to your image, mask, and wireframe paths, for example:
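
Below is a minimal sketch of adjusting the relevant config entries programmatically; TEST_FLIST and TEST_MASK_FLIST follow the released config files, while the wireframe key name is an assumption you should check against your copy of the config.

    # Hypothetical sketch: point the test-related entries of a config copy at your own data.
    import yaml  # pip install pyyaml

    with open("./config_list/config_ZITS_places2.yml") as f:
        cfg = yaml.safe_load(f)

    cfg["TEST_FLIST"] = "./data_list/my_test_images.txt"   # list of test image paths
    cfg["TEST_MASK_FLIST"] = "./my_masks/"                  # mask file list or folder
    cfg["test_line_path"] = "./wireframes/"                 # wireframe .pkl dir from step 4 (key name assumed)

    with open("./config_list/config_ZITS_places2_custom.yml", "w") as f:
        yaml.safe_dump(cfg, f)
    # Pass the new file via --config_file when running FTR_inference.py.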

Test on 256x256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512x512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'

Single Image Test

Note: For single image testing, the 'wireframes_inference_env' environment from step 4 is recommended for better line detection. This code only supports square images (otherwise they will be center cropped).

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
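
If your image and mask are not square, a small pre-processing step like the following sketch (using PIL; the file names are placeholders) can center-crop and resize them first:

    # Hypothetical pre-processing: center-crop image and mask to a square and resize (e.g. to 512x512).
    from PIL import Image

    def to_square(path, out_path, size=512, resample=Image.BICUBIC):
        im = Image.open(path)
        s = min(im.size)                                   # side of the largest centered square
        left, top = (im.width - s) // 2, (im.height - s) // 2
        im.crop((left, top, left + s, top + s)).resize((size, size), resample).save(out_path)

    to_square("./image.png", "./image_sq.png", 512)
    to_square("./mask.png", "./mask_sq.png", 512, resample=Image.NEAREST)  # keep the mask binary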

Training


⚠️ Warning: The training code has not been fully tested after refactoring.

Training TSR

python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Train SSU

We recommend using the pretrained SSU. You can also train your own SSU by referring to https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results


Acknowledgments

Cite

If you find our program helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}
Comments
  • lsm_hawp_inference.py_result_bad

    I tried to use lsm_hawp_inference.py to generate the .pkl files for my dataset (Places365). I used the best_lsm_hawp.pth which you provided, but the result is really bad. I tried reducing the threshold from 0.8 to 0.5, but it still gives a bad result.

    Do you have a best_places365_lsm_hawp.pth? Or how do we train our own HAWP?

    The image is a sample from the training set (14001.jpg).

    opened by bobo0303 13
  • Is there a demo code to make inference on custom image and mask

    Hi, I tried single_image_test.py, but it is hard-coded for Places365 Standard. Is there any simpler demo code to show the results based on a pair of inputs, such as an image and its corresponding mask?

    opened by yijingru 6
  • ValueError: axes don't match array in inpainting_metrics.py

    Hi, the error above occurs at this line of the file. If the inputs at that point are the original images of the Indoor dataset, which are not 256x256, the code runs normally after processing the images with the method given in https://stackoverflow.com/questions/37747021/create-numpy-array-of-images, but I am not sure whether this affects the results. Have you run into this problem? Could you give me some advice? Looking forward to your reply, thanks.

    opened by Ellohiye 4
  • Single image test

    Hi, your work is great, but I have some questions when testing the source code. For the following configuration: python single_image_test.py --path <ckpt_path> --config_file <config_path> --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./ which weights should --path use, and which config_file? My own settings were: python single_image_test.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0' --img_path ./test_i/img1.png --mask_path ./test_i/mask1.png --save_path ./test_i/ but the following error occurred:

    Traceback (most recent call last):
      File "single_image_test.py", line 322, in <module>
        model = ZITS(config, 0, 0, True)
      File "D:\pythonProject\7_4\inpaint\ZITS_inpainting-main\src\FTR_trainer.py", line 296, in __init__
        min_sigma=min_sigma, max_sigma=max_sigma)
      File "D:\pythonProject\7_4\inpaint\ZITS_inpainting-main\datasets\dataset_FTR.py", line 178, in __init__
        f = open(flist, 'r')
    FileNotFoundError: [Errno 2] No such file or directory: '/home/wmlce/places365_standard/places2_all/test_sub_list.txt'

    Does single-image testing also need the same setup as the dataset? I hope the testing steps can be made more detailed. Looking forward to your reply, thank you very much.

    opened by CodeMadUser 4
  • Bad results

    I am getting some very poor results. I am using the single_image script and resizing images to 512x512.

    Can some of the images + masks from the shown results be shared? That way I could verify whether I did something wrong.

    opened by mhashas 4
  • Could you pre-upload a sample .pth for direct debugging?

    config_ZITS_places2.yml

    transformer_ckpt_path: './ckpt/best_transformer_places2.pth'
    gen_weights_path0: './ckpt/lama_places2/InpaintingModel_gen.pth' # Not required at the time of eval
    dis_weights_path0: './ckpt/lama_places2/InpaintingModel_dis.pth' # Not required at the time of eval
    structure_upsample_path: './ckpt/StructureUpsampling.pth'

    D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py

    data = torch.load(config.structure_upsample_path, map_location='cpu')
    

    Exception occurred: AttributeError: 'NoneType' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

    During handling of the above exception, another exception occurred:

    File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py", line 165, in init data = torch.load(config.structure_upsample_path, map_location='cpu') File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py", line 427, in init super().init(*args, gpu=gpu, name='InpaintingModel', rank=rank, test=test, **kwargs) File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\FTR_trainer.py", line 256, in init self.inpaint_model = DefaultInpaintingTrainingModule(config, gpu=gpu, rank=rank, test=test, **kwargs).to(gpu) File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\single_image_test.py", line 323, in model = ZITS(config, 0, 0, True)


    PS D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main> & 'D:\pm\python\python38\python.exe' 'c:\Users\Administrator.vscode\extensions\ms-python.python-2022.4.1\pythonFiles\lib\python\debugpy\launcher' '40191' '--' 'd:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\single_image_test.py' '--path=D:\pm\python\lama\LaMa_models\lama-places\lama-fourier\models' '--config_file=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\config_list\config_ZITS_places2.yml' '--GPU_ids=-1' '--img_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs\y\i1.png' '--mask_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs\mask\i1.png' '--save_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs' Backend TkAgg is interactive backend. Turning interactive mode on. BaseInpaintingTrainingModule init called Loading InpaintingModel StructureUpsampling...

    opened by time888 4
  • ERROR: Could not find a version that satisfies the requirement torch==1.3.1

    Hi, you recommend running wireframe inference with torch 1.3.1 in the README, but pip cannot find that version:

    ERROR: Could not find a version that satisfies the requirement torch==1.3.1 (from versions: 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0)
    ERROR: No matching distribution found for torch==1.3.1
    

    How did you install the old version?

    opened by naoki7090624 4
  • Question about loss and activation function

    Hi, I have questions about the activation function and the loss.

    1. Why do you calculate the loss before the activation function? According to your code, the cross-entropy loss is calculated before the sigmoid function. In a typical CNN, I think the loss is calculated after the activation. Could you tell me why?

    2. Why do you use only cross-entropy loss? According to your code, only the cross-entropy loss is calculated in the TSR. I wonder if you could use other losses (L1 loss, feature matching loss) after upsampling, since there are convolution layers after the transformer blocks.

    opened by naoki7090624 4
  • Access to the pre-trained model

    Loved the paper! The results compared to LaMa are amazing. Could I have access to the lightest pre-trained model? (I am benchmarking on mobile devices.)

    Best regards, Roi

    opened by roimulia2 4
  • Pretrained Indoor Model

    Hi, can you upload the pretrained Indoor model whose results you report in your paper? Also, can you share the trained models of the comparative methods you show results for in your paper?

    Thank you.

    opened by toshi2k2 3
  • wireframe model is irrelevant

    Hi,

    I've been playing quite a bit with your model due to the amazing results. Something I've noticed is that the wireframe model is irrelevant. If I return an all-zeros tensor of the same shape as the actual output of lines_tensor in wf_inference_test, I get the same final outputs. Is there a bug somewhere?

    To replicate:

    return torch.zeros_like(lines_tensor.detach()) in wf_inference_test

    Update:

    It seems that the edges are also useless.

      batch["line_256"] = torch.zeros_like(batch["mask_256"])
      batch["line"] = torch.zeros_like(batch["mask_512"]) 
      batch["edge"] = torch.zeros_like(batch["mask_512"])
    

    Making this change gives me the same results.

    Let me know if I'm doing something wrong.

    opened by mhashas 3
  • the path in config

    # modify the image path
    # (the original images?)
    TRAIN_FLIST: ./data_list/sp_large_train_list.txt
    VAL_FLIST: ./data_list/sp_large_val_list.txt
    TEST_FLIST: ./data_list/sp_large_val_list.txt

    # set the GT images folder for metrics computation
    # (the original validation images?)
    GT_Val_FOLDER: './datasets/inpaint_data/val_images/'

    # modify the mask path
    # (randomly generated masks?)
    TRAIN_MASK_FLIST: [ './data_list/mask_large_train_list.txt', './data_list/mask_large_train_list.txt' ]
    # (the real object masks for object removal?)
    TEST_MASK_FLIST: ./datasets/inpaint_data/val_SH_binary_masks/

    Could you tell me whether my understanding of these paths is right?

    opened by ErisGe 1