Towards Safe Self-Distillation of Internet-Scale Text-to-Image Diffusion Models

KAIST, Republic of Korea
ICML 2023 - Challenges in Deployable Generative AI Workshop

Artist concept removal applied to images generated from the same initial latent code.

Abstract

Large-scale image generation models, whose impressive quality is made possible by the vast amount of data available on the Internet, raise social concerns that they may generate harmful or copyrighted content. These biases and harmful behaviors arise throughout the entire training process and are hard to remove completely, which has become a significant hurdle to the safe deployment of such models. In this paper, we propose a method called SDD to prevent problematic content generation in text-to-image diffusion models. We self-distill the diffusion model so that the noise estimate conditioned on the target concept to be removed matches the unconditional one. Compared to previous methods, our approach eliminates a much greater proportion of harmful content from the generated images without degrading overall image quality. Furthermore, it allows the removal of multiple concepts at once, whereas previous works are limited to removing a single concept at a time.

Method

Concept figure

We propose SDD, a method that prevents problematic content generation in text-to-image diffusion models. We self-distill the diffusion model so that the noise estimate conditioned on the target concept to be removed matches the unconditional estimate.
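In symbols (notation ours, matching the pseudo-code below): with student parameters θ, the removal-concept embedding c_s, the null-prompt embedding ∅, and sg(·) denoting stop-gradient, SDD minimizes

```latex
\mathcal{L}_{\mathrm{SDD}}(\theta)
  = \mathbb{E}_{z_t,\, t}
    \left\| \epsilon_\theta(z_t, t, c_s)
    - \mathrm{sg}\!\left(\epsilon_\theta(z_t, t, \varnothing)\right) \right\|_2^2
```

where z_t is sampled with classifier-free guidance from an EMA copy of the student, and that EMA copy is returned as the final safe model.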

Quantitative Results

NSFW Removal

               Prompt "body"   COCO-30k
Method         % Nude          FID      LPIPS    CLIP
SD             74.18           21.348   N/A      0.2771
SD+NEG         20.44           14.278   0.1954   0.2706
SLD-medium     70.02           17.201   0.1015   0.2689
SLD-max         4.30           13.634   0.1574   0.2709
SEGA           72.04           -        -        -
ESD-u-3        43.30           -        -        -
ESD-x-3        14.32           13.808   0.1587   0.2690
SDD (ours)      1.68           15.423   0.1797   0.2673

I2P Multi-Concept Removal

               Prompt "body"   I2P       COCO-30k
Method         % Nude          % Harm    FID      LPIPS    CLIP
SD             74.18           24.42     21.348   N/A      0.2771
SD+NEG         63.78            9.51     18.021   0.1925   0.2659
SLD-medium     74.16            7.42     14.794   0.4216   0.2720
SLD-max        56.78            5.19     21.729   0.4377   0.2572
SEGA           74.10           16.84     -        -        -
ESD-x-3        47.38           13.04     16.411   0.2036   0.2631
SDD (ours)     12.62            5.03     15.142   0.2443   0.2560

Artist Concept Removal

Resources

Poster

Pseudo-code

from copy import deepcopy
from typing import List

import torch

def run_sdd(
  unet: UNet2DConditionModel, scheduler: DDIMScheduler, text_encoder: CLIPTextModel,
  concepts: List[str], n_iters: int = 1500, m: float = 0.999, s_g: float = 3.0,
) -> UNet2DConditionModel:
  unet_ema = deepcopy(unet)  # Frozen EMA copy serves as the teacher
  optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)  # Learning rate illustrative
  c_0 = text_encoder("")                   # Null-prompt embedding
  c_s = text_encoder(", ".join(concepts))  # All removal concepts joined
  for it in range(n_iters):
    c_p = text_encoder(concepts[it % len(concepts)])  # Cycle over concepts
    until = int(torch.randint(0, scheduler.total_steps - 1, (1,)))
    z_t = torch.randn(1, 4, 64, 64)  # Initial Gaussian noise z_T
    with torch.no_grad():
      for i, t in enumerate(scheduler.timesteps):
        e_0, e_p = unet_ema(z_t, t, c_0), unet_ema(z_t, t, c_p)
        e_tilde = e_0 + s_g * (e_p - e_0) # Sample latents z_t from the EMA model
        z_t = scheduler(z_t, e_tilde, t)  # for T - t steps according to CFG
        if i == until:
          break
    e_0, e_s = unet(z_t, t, c_0), unet(z_t, t, c_s)
    loss = ((e_0.detach() - e_s) ** 2).mean() # L2-norm between two noise estimates
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():
      for p, q in zip(unet_ema.parameters(), unet.parameters()):
        p.copy_(m * p + (1 - m) * q) # In-place EMA update
  return unet_ema
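The two numerical ingredients of the loop above — the classifier-free-guidance mix and the EMA parameter update — can be sketched in plain Python. The function names `cfg_mix` and `ema_update` are ours, and the lists of floats stand in for tensors, purely for illustration:

```python
from typing import List

def cfg_mix(e_uncond: List[float], e_cond: List[float], s_g: float) -> List[float]:
    """Classifier-free guidance: move the estimate from unconditional
    toward (and, for s_g > 1, past) the conditional one."""
    return [u + s_g * (c - u) for u, c in zip(e_uncond, e_cond)]

def ema_update(ema: List[float], new: List[float], m: float) -> List[float]:
    """Exponential moving average with decay m: the teacher tracks
    the student slowly, stabilizing the distillation target."""
    return [m * p + (1.0 - m) * q for p, q in zip(ema, new)]

# Sanity checks: with s_g = 1 the mix reduces to the conditional estimate,
# and with m = 1 the EMA parameters never move.
print(cfg_mix([0.0, 1.0], [1.0, 1.0], 1.0))      # → [1.0, 1.0]
print(ema_update([1.0, 2.0], [0.0, 0.0], 1.0))   # → [1.0, 2.0]
```

Note the in-place `p.copy_(...)` in the pseudo-code: a plain Python assignment inside the loop would rebind the local name without touching the EMA model's parameters.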
    

BibTeX

@misc{kim2023safe,
  title={Towards Safe Self-Distillation of Internet-Scale Text-to-Image Diffusion Models}, 
  author={Sanghyun Kim and Seohyeon Jung and Balhae Kim and Moonseok Choi and Jinwoo Shin and Juho Lee},
  year={2023},
  eprint={2307.05977},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  booktitle={ICML 2023 Workshop on Challenges in Deployable Generative AI},
}