
🔉 Sound-guided Semantic Image Manipulation (CVPR 2022)

Official PyTorch Implementation

Teaser image

Sound-guided Semantic Image Manipulation
IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2022

Paper: CVPR 2022 Open Access
Project Page: https://kuai-lab.github.io/cvpr2022sound/
Seung Hyun Lee, Wonseok Roh, Wonmin Byeon, Sang Ho Yoon, Chanyoung Kim, Jinkyu Kim*, and Sangpil Kim*

Abstract: The recent success of generative models shows that leveraging the multi-modal embedding space can manipulate an image using text information. However, manipulating an image with sources other than text, such as sound, is not easy due to the dynamic characteristics of those sources. In particular, sound can convey vivid emotions and dynamic expressions of the real world. Here, we propose a framework that directly encodes sound into the multi-modal (image-text) embedding space and manipulates an image from that space. Our audio encoder is trained to produce a latent representation from an audio input, which is forced to be aligned with image and text representations in the multi-modal embedding space. We use a direct latent optimization method based on aligned embeddings for sound-guided image manipulation. We also show that our method can mix different modalities, i.e., text and audio, which enriches the variety of image modifications. Experiments on zero-shot audio classification and semantic-level image classification show that our proposed model outperforms other text- and sound-guided state-of-the-art methods.

💾 Installation

All of the methods described in the paper require CLIP; specific requirements for each method are described in its section. To install CLIP, please run the following commands:

```bash
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
pip install ftfy regex tqdm gdown
pip install git+https://github.com/openai/CLIP.git
```
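
If the installation succeeded, CLIP should load and encode a test prompt end to end. A minimal sanity check (the prompt is just an example; ViT-B/32 is the backbone also used in the StyleGAN3 demo in the Comments section below):

```python
import torch
import clip

# Load the ViT-B/32 CLIP backbone.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

with torch.no_grad():
    tokens = clip.tokenize(["an explosion in the mountains"]).to(device)
    text_features = model.encode_text(tokens)

print(text_features.shape)  # expected: torch.Size([1, 512])
```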

🔨 Method

Method image

1. CLIP-based Contrastive Latent Representation Learning.

Dataset Curation.

We create an audio-text pair dataset from the VGG-Sound dataset. We also used the AudioSet dataset, following the scripts below.

  1. Please download vggsound.csv from the link.
  2. Run download.py to download the audio files of the VGG-Sound dataset.
  3. Run curate.py to preprocess each audio file (WAV to mel-spectrogram).

```bash
cd soundclip
python3 download.py
python3 curate.py
```
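
curate.py converts each WAV file into a log-mel spectrogram. Its exact parameters are not restated here; as a reference, the sketch below follows the recipe used by the StyleGAN3 demo in the Comments section (128 mel bins, a fixed 864-frame window, dB values mapped to [0, 1]):

```python
import numpy as np
import librosa

def wav_to_mel(audio_path, n_mels=128, time_length=864, sr=44100):
    """Load a WAV file and return a normalized log-mel spectrogram,
    cropped or zero-padded to a fixed number of frames."""
    y, sr = librosa.load(audio_path, sr=sr)
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
    mel = librosa.power_to_db(mel, ref=np.max) / 80.0 + 1  # [-80, 0] dB -> [0, 1]

    if mel.shape[1] >= time_length:          # center-crop long clips
        j = (mel.shape[1] - time_length) // 2
        mel = mel[:, j:j + time_length]
    else:                                    # zero-pad short clips
        padded = np.zeros((n_mels, time_length))
        padded[:, :mel.shape[1]] = mel
        mel = padded
    return mel
```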

Training.

```bash
python3 train.py
```
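
train.py runs the contrastive stage, which aligns the audio encoder's output with the CLIP embeddings of the paired text labels. The repository's exact objective may differ in its details; a minimal sketch of a CLIP-style symmetric InfoNCE loss looks like this:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(audio_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE: pull each audio embedding toward its paired
    CLIP text embedding and away from the other pairs in the batch."""
    audio_emb = F.normalize(audio_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = audio_emb @ text_emb.t() / temperature           # (B, B) similarities
    labels = torch.arange(audio_emb.size(0), device=audio_emb.device)
    return (F.cross_entropy(logits, labels) +
            F.cross_entropy(logits.t(), labels)) / 2
```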

2. Sound-Guided Image Manipulation.

Direct Latent Code Optimization.

The code relies on the StyleCLIP PyTorch implementation.

```bash
python3 optimization/run_optimization.py --lambda_similarity 0.002 --lambda_identity 0.0 --truncation 0.7 --lr 0.1 --audio_path "./audiosample/explosion.wav" --ckpt ./pretrained_models/landscape.pt --stylegan_size 256
```
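
Conceptually, the optimization pulls the CLIP embedding of the generated image toward the audio embedding while keeping the latent code close to its starting point. The sketch below is a simplified stand-in, not the repository's run_optimization.py: `G`, `clip_image_embed`, and `audio_target` are hypothetical names, and the regularizer weight only mirrors the `--lambda_similarity` flag above.

```python
import torch
import torch.nn.functional as F

def manipulate(G, clip_image_embed, audio_target, w_init,
               steps=300, lr=0.1, lambda_similarity=0.002):
    """Hypothetical direct latent optimization: G is a pretrained StyleGAN
    generator, clip_image_embed maps images into CLIP space, and
    audio_target is a normalized embedding from the trained audio encoder."""
    w = w_init.detach().clone().requires_grad_(True)
    opt = torch.optim.Adam([w], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        image = G.synthesis(w)                                    # latent -> image
        img_emb = F.normalize(clip_image_embed(image), dim=-1)
        clip_loss = 1 - (img_emb * audio_target).sum(-1).mean()  # cosine distance
        reg = ((w - w_init) ** 2).mean()                          # stay near source
        (clip_loss + lambda_similarity * reg).backward()
        opt.step()
    return w.detach()
```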

⛳ Results

Zero-shot Audio Classification Accuracy.

| Model | Supervised Setting | Zero-Shot | ESC-50 | UrbanSound 8K |
|---|:---:|:---:|---:|---:|
| ResNet50 | ✅ | - | 66.8% | 71.3% |
| Ours (Without Self-Supervised) | - | - | 58.7% | 63.3% |
| ✨ Ours (Logistic Regression) | - | - | 72.2% | 66.8% |
| Wav2CLIP | - | ✅ | 41.4% | 40.4% |
| AudioCLIP | - | ✅ | 69.4% | 68.8% |
| Ours (Without Self-Supervised) | - | ✅ | 49.4% | 45.6% |
| ✨ Ours | - | ✅ | 57.8% | 45.7% |
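
For reference, zero-shot here means no classifier is fitted on the target dataset: an audio clip is scored against the CLIP text embeddings of the class names and assigned to the nearest one. A hypothetical sketch of that protocol (the prompt template and helper names are assumptions):

```python
import torch
import torch.nn.functional as F
import clip

@torch.no_grad()
def zero_shot_classify(audio_encoder, mel, class_names, clip_model, device):
    """Score one mel-spectrogram against CLIP text embeddings of the labels."""
    audio_emb = F.normalize(audio_encoder(mel.to(device)), dim=-1)        # (1, 512)
    tokens = clip.tokenize([f"a sound of {c}" for c in class_names]).to(device)
    text_emb = F.normalize(clip_model.encode_text(tokens).float(), dim=-1)
    return class_names[(audio_emb @ text_emb.t()).argmax(dim=-1).item()]
```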

Manipulation Results.

LSUN. LSUN image

FFHQ. FFHQ image

To see more diverse examples, please visit our project page!

Citation

```bibtex
@InProceedings{Lee_2022_CVPR,
    author    = {Lee, Seung Hyun and Roh, Wonseok and Byeon, Wonmin and Yoon, Sang Ho and Kim, Chanyoung and Kim, Jinkyu and Kim, Sangpil},
    title     = {Sound-Guided Semantic Image Manipulation},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {3377-3386}
}
```
Comments
  • Does part 1 of training (CLIP-based training) include image modality?

    Hi, thanks for the great work!

    The paper states that during part-1 training (i.e., the CLIP-based Contrastive Latent Representation Learning step) you consider the image, text, and audio modalities, but the code only uses the audio and text modalities for this training part.

    Is this an old version of the code, or did I misinterpret the training part of the paper? Thanks

    opened by sakshamsingh1 4
  • Computation requirement for training

    Hi,

    First of all, a great contribution to the field of image manipulation. Could you please share how many GPUs were used and how long it took to train the model?

    Thanks, Himangi

    opened by HimangiM 1
  • Add Gradio Demo for CVPR 2022 call for demos

    Hi, would you be interested in adding sound-guided-semantic-image-manipulation to Hugging Face as a Gradio web demo for the CVPR 2022 call for demos? The Hub offers free hosting, and it would make your work more accessible and visible to the rest of the ML community. Models, datasets, and Spaces (web demos) can be added to a user account or organization, similar to GitHub.

    More info on the CVPR call for demos: https://huggingface.co/CVPR

    Here are guides for adding a web demo to the org:

    How to add a Space: https://huggingface.co/blog/gradio-spaces

    Please let us know if you would be interested and if you have any questions, we can also help with the technical implementation.

    opened by AK391 0
  • About StyleGAN3

    The main code is borrowed from the link below.
    Link: https://github.com/ouhenio/StyleGAN3-CLIP-notebooks

    StyleGAN3 + our CLIP-based sound representation:

    import sys
    
    import io
    import os, time, glob
    import pickle
    import shutil
    import numpy as np
    from PIL import Image
    import torch
    import torch.nn.functional as F
    import requests
    import torchvision.transforms as transforms
    import torchvision.transforms.functional as TF
    import clip
    import unicodedata
    import re
    from tqdm import tqdm
    from torchvision.transforms import Compose, Resize, ToTensor, Normalize
    from einops import rearrange
    from collections import OrderedDict
    
    import timm
    import librosa
    import cv2
    
    def make_transform(translate, angle):
        m = np.eye(3)
        s = np.sin(angle/360.0*np.pi*2)
        c = np.cos(angle/360.0*np.pi*2)
        m[0][0] = c
        m[0][1] = s
        m[0][2] = translate[0]
        m[1][0] = -s
        m[1][1] = c
        m[1][2] = translate[1]
        return m
        
    class AudioEncoder(torch.nn.Module):
        def __init__(self):
            super(AudioEncoder, self).__init__()
            self.conv = torch.nn.Conv2d(1, 3, (3, 3))
            self.feature_extractor = timm.create_model("resnet18", num_classes=512, pretrained=True)
    
        def forward(self, x):
            x = self.conv(x)
            x = self.feature_extractor(x)
            return x
    
    def copyStateDict(state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict
    
    class CLIP(object):
      def __init__(self):
        clip_model = "ViT-B/32"
        self.model, _ = clip.load(clip_model)
        self.model = self.model.requires_grad_(False)
        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                              std=[0.26862954, 0.26130258, 0.27577711])
    
      @torch.no_grad()
      def embed_text(self, prompt):
          "Normalized clip text embedding."
          return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
    
      def embed_cutout(self, image):
          "Normalized clip image embedding."
          # return norm1(self.model.encode_image(self.normalize(image)))
          return norm1(self.model.encode_image(image))
    
    tf = Compose([
      Resize(224),
      lambda x: torch.clamp((x+1)/2,min=0,max=1),
      ])
    
    def norm1(prompt):
        "Normalize to the unit sphere."
        return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()
    
    def spherical_dist_loss(x, y):
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
    
    def prompts_dist_loss(x, targets, loss):
        if len(targets) == 1:  # Keeps consistent results vs previous method for single-objective guidance
          return loss(x, targets[0])
        distances = [loss(x, target) for target in targets]
        return torch.stack(distances, dim=-1).sum(dim=-1)  
    
    class MakeCutouts(torch.nn.Module):
        def __init__(self, cut_size, cutn, cut_pow=1.):
            super().__init__()
            self.cut_size = cut_size
            self.cutn = cutn
            self.cut_pow = cut_pow
    
        def forward(self, input):
            sideY, sideX = input.shape[2:4]
            max_size = min(sideX, sideY)
            min_size = min(sideX, sideY, self.cut_size)
            cutouts = []
            for _ in range(self.cutn):
                size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
                cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
            return torch.cat(cutouts)
    
    make_cutouts = MakeCutouts(224, 32, 0.5)
    
    def embed_image(image):
      n = image.shape[0]
      cutouts = make_cutouts(image)
      embeds = clip_model.embed_cutout(cutouts)
      embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
      return embeds
    
    def run(timestring):
      torch.manual_seed(seed)
    
      # Init
      # Sample 32 inits and choose the one closest to prompt
    
      with torch.no_grad():
        qs = []
        losses = []
        for _ in range(8):
          q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
          images = G.synthesis(q * w_stds + G.mapping.w_avg)
          embeds = embed_image(images.add(1).div(2))
          loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
          i = torch.argmin(loss)
          qs.append(q[i])
          losses.append(loss[i])
        qs = torch.stack(qs)
        losses = torch.stack(losses)
        i = torch.argmin(losses)
        q = qs[i].unsqueeze(0).requires_grad_()
    
      w_init = (q * w_stds + G.mapping.w_avg).detach().clone()
      # Sampling loop
      q_ema = q
      opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
      loop = tqdm(range(steps))
      for i in loop:
        opt.zero_grad()
        w = q * w_stds + G.mapping.w_avg
        image = G.synthesis(w , noise_mode='const')
        embed = embed_image(image.add(1).div(2))
        loss = 0.1 *  prompts_dist_loss(embed, targets, spherical_dist_loss).mean() + ((w - w_init) ** 2).mean()
        # loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
    
        q_ema = q_ema * 0.9 + q * 0.1
    
        final_code = q_ema * w_stds + G.mapping.w_avg
        final_code[:,6:,:] = w_init[:,6:,:]
        image = G.synthesis(final_code, noise_mode='const')
    
        if i % 10 == 9 or i % 10 == 0:
          # display(TF.to_pil_image(tf(image)[0]))
          print(f"Image {i}/{steps} | Current loss: {loss}")
          pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1).cpu())
          os.makedirs(f'samples/{timestring}', exist_ok=True)
          pil_image.save(f'samples/{timestring}/{i:04}.jpg')
    
    
    device = torch.device('cuda:0')
    print('Using device:', device, file=sys.stderr)
    
    model_url = "./pretrained_models/stylegan3-r-afhqv2-512x512.pkl"
    
    with open(model_url, 'rb') as fp:
      G = pickle.load(fp)['G_ema'].to(device)
    
    zs = torch.randn([100000, G.mapping.z_dim], device=device)
    w_stds = G.mapping(zs, None).std(0)
    
    m = make_transform([0,0], 0)
    m = np.linalg.inv(m)
    G.synthesis.input.transform.copy_(torch.from_numpy(m))
    # audio_paths = "./audio/sweet-kitty-meow.wav"
    #audio_paths = "./audio/dog-sad.wav"
    audio_paths = "./audio/cartoon-voice-laugh.wav"
    steps = 200
    seed = 14 + 22
    #seed = 22
    
    audio_paths = [frase.strip() for frase in audio_paths.split("|") if frase]
    
    clip_model = CLIP()
    audio_encoder = AudioEncoder()
    audio_encoder.load_state_dict(copyStateDict(torch.load("./pretrained_models/resnet18.pth", map_location=device)))
    audio_encoder = audio_encoder.to(device)
    audio_encoder.eval()
    
    targets = []
    n_mels = 128
    time_length = 864
    resize_resolution = 512
    
    for audio_path in audio_paths:
        y, sr = librosa.load(audio_path, sr=44100)
        audio_inputs = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
        audio_inputs = librosa.power_to_db(audio_inputs, ref=np.max) / 80.0 + 1
    
        zero = np.zeros((n_mels, time_length))
        h, w = audio_inputs.shape
        if w >= time_length:
            j = (w - time_length) // 2
            audio_inputs = audio_inputs[:,j:j+time_length]
        else:
            j = (time_length - w) // 2
            zero[:,:w] = audio_inputs[:,:w]
            audio_inputs = zero
        
        audio_inputs = cv2.resize(audio_inputs, (n_mels, resize_resolution))
        audio_inputs = np.array([audio_inputs])
        audio_inputs = torch.from_numpy(audio_inputs.reshape((1, 1, n_mels, resize_resolution))).float().to(device)
        with torch.no_grad():
            audio_embedding = audio_encoder(audio_inputs)
            audio_embedding = audio_embedding / audio_embedding.norm(dim=-1, keepdim=True)
        targets.append(audio_embedding)
    
    timestring = time.strftime('%Y%m%d%H%M%S')
    run(timestring)
    
    opened by lsh3163 0
Owner
CVLAB, Department of Artificial Intelligence, Korea University