This paper addresses societal concerns about large-scale text-to-image diffusion models generating potentially harmful or copyrighted content. Existing models rely heavily on internet-crawled data, in which problematic concepts persist because filtering is incomplete. While previous approaches partially alleviate the issue, they often rely on text-specified concepts, which makes it difficult to capture nuanced concepts accurately and to align model knowledge with human understanding. In response, we propose a framework named Human Feedback Inversion (HFI), in which human feedback on model-generated images is condensed into textual tokens that guide the mitigation or removal of problematic images. The proposed framework can be built upon existing techniques for the same purpose, enhancing their alignment with human judgment. In addition, we simplify the training objective with a self-distillation-based technique, providing a strong baseline for concept removal. Our experimental results demonstrate that our framework significantly reduces objectionable content generation while preserving image quality, contributing to the ethical deployment of AI in the public sphere.
We propose a method called SDD to prevent text-to-image diffusion models from generating problematic content. SDD self-distills the diffusion model so that its noise estimate conditioned on the concepts targeted for removal matches its unconditional noise estimate.
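Read from the training code on this page, the SDD update can be summarized as follows. The notation is ours and chosen for illustration: $\epsilon_\theta$ is the U-Net being trained, $\epsilon_{\theta^-}$ its EMA copy, $c_0$ the empty-prompt embedding, $c_p$ the embedding of one target concept, $c_s$ the embedding of all target concepts joined by commas, $s_g$ the guidance scale, $m$ the EMA momentum, and $\operatorname{sg}(\cdot)$ a stop-gradient.

$$
\begin{aligned}
\tilde{\epsilon}_t &= \epsilon_{\theta^-}(z_t, t, c_0) + s_g \big( \epsilon_{\theta^-}(z_t, t, c_p) - \epsilon_{\theta^-}(z_t, t, c_0) \big), \\
\mathcal{L}(\theta) &= \mathbb{E}_{z_t, t} \big\| \operatorname{sg}\big( \epsilon_\theta(z_t, t, c_0) \big) - \epsilon_\theta(z_t, t, c_s) \big\|_2^2, \\
\theta^- &\leftarrow m\, \theta^- + (1 - m)\, \theta .
\end{aligned}
$$

Here $z_t$ is obtained by running the EMA model's sampler from $z_T \sim \mathcal{N}(0, I)$ with the guided estimate $\tilde{\epsilon}_t$ for a uniformly random number of steps, and the code averages the squared error over elements rather than summing it.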

Method | % Nude ("body") | FID (COCO-30k) | LPIPS (COCO-30k) | CLIP (COCO-30k) |
---|---|---|---|---|
SD | 74.18 | 21.348 | N/A | 0.2771 |
SD+NEG | 20.44 | 14.278 | 0.1954 | 0.2706 |
SLD-medium | 70.02 | 17.201 | 0.1015 | 0.2689 |
SLD-max | 4.30 | 13.634 | 0.1574 | 0.2709 |
SEGA | 72.04 | - | - | - |
ESD-u-3 | 43.30 | - | - | - |
ESD-x-3 | 14.32 | 13.808 | 0.1587 | 0.2690 |
SDD (ours) | 1.68 | 15.423 | 0.1797 | 0.2673 |

Method | % Nude ("body") | % Harm (I2P) | FID (COCO-30k) | LPIPS (COCO-30k) | CLIP (COCO-30k) |
---|---|---|---|---|---|
SD | 74.18 | 24.42 | 21.348 | N/A | 0.2771 |
SD+NEG | 63.78 | 9.51 | 18.021 | 0.1925 | 0.2659 |
SLD-medium | 74.16 | 7.42 | 14.794 | 0.4216 | 0.2720 |
SLD-max | 56.78 | 5.19 | 21.729 | 0.4377 | 0.2572 |
SEGA | 74.10 | 16.84 | - | - | - |
ESD-x-3 | 47.38 | 13.04 | 16.411 | 0.2036 | 0.2631 |
SDD (ours) | 12.62 | 5.03 | 15.142 | 0.2443 | 0.2560 |
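As a reading aid, the relative reduction in flagged images that SDD achieves over the SD baseline can be computed directly from the two tables. The sketch below only reuses numbers reported there; `relative_reduction` is a hypothetical helper for illustration, not part of any released code.

```python
# Rates copied from the tables above (% of generated images flagged).
SD_NUDE = 74.18  # SD baseline, "body" prompt
SD_HARM = 24.42  # SD baseline, I2P prompts


def relative_reduction(baseline: float, method: float) -> float:
    """Fraction of the baseline rate removed by a method (1.0 means fully removed)."""
    return (baseline - method) / baseline


single = relative_reduction(SD_NUDE, 1.68)       # first table, SDD on "body"
multi_nude = relative_reduction(SD_NUDE, 12.62)  # second table, SDD on "body"
multi_harm = relative_reduction(SD_HARM, 5.03)   # second table, SDD on I2P

print(f"{single:.1%} {multi_nude:.1%} {multi_harm:.1%}")
```

This gives roughly 97.7%, 83.0%, and 79.4%. These ratios say nothing about image quality, which the FID, LPIPS, and CLIP columns track separately.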
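
```python
from copy import deepcopy
from typing import List

import torch
from diffusers import DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer


def run_sdd(
    unet: UNet2DConditionModel, scheduler: DDIMScheduler, text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer, optimizer: torch.optim.Optimizer,
    concepts: List[str], n_iters: int = 1500, m: float = 0.999, s_g: float = 3.0,
) -> UNet2DConditionModel:
    # Assumes scheduler.set_timesteps(...) has been called and optimizer wraps unet.parameters().
    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        ids = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                        truncation=True, return_tensors="pt").input_ids
        return text_encoder(ids.to(text_encoder.device)).last_hidden_state

    unet_ema = deepcopy(unet).requires_grad_(False)  # EMA copy, updated only by averaging
    c_0, c_s = encode(""), encode(", ".join(concepts))
    for it in range(n_iters):
        c_p = encode(concepts[it % len(concepts)])  # Iterate over concepts
        until = torch.randint(0, len(scheduler.timesteps) - 1, (1,)).item()
        z_t = torch.randn((1, 4, 64, 64), device=unet.device, dtype=unet.dtype)  # Initial Gaussian noise z_T
        with torch.no_grad():
            for i, t in enumerate(scheduler.timesteps):
                if i == until:  # Stop before stepping, so z_t is the latent at timestep t
                    break
                e_0 = unet_ema(z_t, t, encoder_hidden_states=c_0).sample
                e_p = unet_ema(z_t, t, encoder_hidden_states=c_p).sample
                e_tilde = e_0 + s_g * (e_p - e_0)  # Sample latents z_t from the EMA model with CFG
                z_t = scheduler.step(e_tilde, t, z_t).prev_sample
            e_0 = unet(z_t, t, encoder_hidden_states=c_0).sample  # Stop-gradient target
        e_s = unet(z_t, t, encoder_hidden_states=c_s).sample
        loss = ((e_0 - e_s) ** 2).mean()  # Mean squared L2 distance between the two noise estimates
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        with torch.no_grad():
            for p, q in zip(unet_ema.parameters(), unet.parameters()):
                p.mul_(m).add_(q, alpha=1 - m)  # In-place EMA update
    return unet_ema
```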
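With the default momentum m = 0.999, the EMA model has an effective horizon of roughly 1/(1 - m) = 1000 updates, so 1500 iterations let it absorb only part of the trained model's drift. The pure-Python sketch below illustrates this on a hypothetical one-parameter toy (EMA initialized at 0.0, trained parameter held fixed at 1.0). It also writes the average back by index, because rebinding a loop variable (as in `p = m * p + (1 - m) * q`) leaves the stored parameter unchanged.

```python
def ema_update(ema: list[float], student: list[float], m: float) -> None:
    """In-place EMA update of `ema` toward `student` with momentum m."""
    for k, q in enumerate(student):
        ema[k] = m * ema[k] + (1 - m) * q  # write back by index, not by rebinding a name


def fraction_tracked(m: float, n_iters: int) -> float:
    """Closed form: fraction of a fixed offset the EMA has absorbed after n_iters updates."""
    return 1.0 - m ** n_iters


# Hypothetical toy: EMA starts at 0.0 while the trained parameter stays at 1.0.
ema, student = [0.0], [1.0]
for _ in range(1500):
    ema_update(ema, student, m=0.999)

print(round(ema[0], 4), round(fraction_tracked(0.999, 1500), 4))
```

Under these defaults about 77.7% of the change reaches the returned EMA model; a smaller m makes it follow the trained U-Net more closely at the cost of less smoothing.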

```bibtex
@inproceedings{kim2024safeguard,
  title={Safeguard Text-to-Image Diffusion Models with Human Feedback Inversion},
  author={Kim, Sanghyun and Jung, Seohyeon and Kim, Balhae and Choi, Moonseok and Shin, Jinwoo and Lee, Juho},
  booktitle={European Conference on Computer Vision},
  year={2024},
  organization={Springer}
}
```