티스토리 뷰
[논문 100개 구현 2탄] Stable Diffusion(High-Resolution Image Synthesis with Latent Diffusion Models)
sikaro 2024. 11. 27. 17:27논문 구현에 앞서 확인해야 할 포인트
Read
1. 논문 제목(title)과 초록(abstract), 도표(figures) 읽기
2. 도입(introduction), 결론(conclusion), 도표(figures)를 읽고 필요없는 부분 생략
3. 수식은 처음 읽을 때는 과감하게 생략
4. 이해가 안되는 부분은 빼고 전체적으로 읽는다.
QnA
1. 저자가 뭘 해내고 싶어했는가?
2. 이 연구의 접근에서 중요한 요소는 무엇인가?
3. 당신(논문독자)는 스스로 이 논문을 이용할 수 있는가?
4. 당신이 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?
구현하기
1. 수식 이해하고 직접 연산하기
2. 코드 연습하기(오픈소스를 받아 직접 구현)
1. 논문 제목(title)과 초록(abstract), 도표(figures) 읽기
High-Resolution Image Synthesis with Latent Diffusion Models
본 논문의 제목은 3개의 부분으로 나누어진다.
고해상도, 이미지 합성, Diffusion 모델.
즉, 벌써 제목에서부터 볼 수 있듯이 고해상도의 이미지를 Latent Diffusion Models(LDMs)을 통해 합성하고자 했던 게 저자의 의도이다.
초록(abstract)에서 중요한 부분은 다음과 같다.
diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations.
기존 디퓨전 모델은 잘 작동하지만, 픽셀 공간에서 직접 작동하여 학습에 많은 계산 자원을 소모하고, 고해상도 이미지 생성에 있어 효율성이 떨어지는 문제점이 있다는 뜻.
In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner
결론적으로 이 문제 해결을 위해서 이미 사전 학습된 AutoEncoder의 잠재 공간(Latent Space)을 활용하고, Cross-Attension 레이어를 도입해 텍스트, 바운딩 박스와 같은 일반적인 조건부 입력을 처리하고, 고해상도 합성이 합성곱 방식(Convolutional manner)로 가능해지게 하였다.
즉, 계산 자원 소비를 줄이면서도 이미지 품질을 손상시키지 않도록 접근 방식을 만들었다고 할 수 있겠다.
2. 도입(introduction), 결론(conclusion), 도표(figures)를 읽고 필요없는 부분 생략
Introduction
- Diffusion Models(DMs)이 GAN 및 Autoregressive Models(ARMs) 대비 더 안정적이고, 다양한 데이터 분포를 잘 학습할 수 있음을 언급.
- 그러나 높은 계산 비용과 느린 학습 속도가 주요 한계.
- 이를 해결하기 위해 잠재 공간 학습과 효율적 조건부 생성 메커니즘 도입.
Conclusion
- Latent Diffusion Model(LDM)은 Diffusion Models의 학습 및 추론 효율성을 크게 개선.
- 다양한 조건부 이미지 생성에서 기존 SOTA 모델 대비 더 나은 성능을 달성.
- 텍스트-이미지 변환, 초해상도, 인페인팅 등에서 유연하고 강력한 성능을 보임.
Figures
- Figure 1: 다운샘플링 비율에 따른 품질 비교 (LDM이 픽셀 기반 모델 대비 우수).
- Figure 3: LDM의 크로스 어텐션 기반 조건부 생성 메커니즘.
3. 수식은 처음 읽을 때는 과감하게 생략
그래도 알고 가면 좋은 수식
cross-attenion Layer는 다음과 같이 정의된다.
Figure 3 부분에서 엔코더 결과에 곱해주는 Cross-attention Layer이다.
Diffusion 모델은 결국 노이즈 오염된 걸 복원하면서 만들게 되는데, zT가 노이즈가 오염된 이미지라고 생각하면 된다.
zT가 결국 Q, 그러니까 쿼리 값으로 들어가게 되고, 입력을 Input을 넣고 싶은 텍스트를 임베딩화하여 Key 값으로 가져가게 된다.
그렇게 두 개의 가중치를 dot product해서 계산을 한 다음에 Softmax 값으로 가중치 형태로 끌어내고
이것들을 다시 텍스트 이미지에 dot Product하는 전형적인 Cross-Attension 메커니즘을 사용하게 된다.
Loss로는 Perceptual Loss와 Patch based adversarial objective를 사용했다.
기존 손실 함수(예: L2 Loss, MSE)는 이미지의 픽셀 차이를 기반으로 계산되지만, 이 방법은 인간이 보기에는 유사하대도 실제로는 매우 다른 이미지를 만들 수 있다.
예를 들어, 두 이미지가 픽셀 단위로 거의 동일해 보이더라도, 실제로는 이미지의 구조나 세부 정보가 달라질 수 있다.
Perceptual Loss는 대신 사전 학습된 신경망(예: VGG)을 사용하여 두 이미지가 고차원적인 특성(예: 고수준의 시각적 패턴, 개체, 텍스처 등)을 기준으로 얼마나 유사한지를 평가한다.
- 일반적으로 VGG 네트워크와 같은 사전 훈련된 신경망의 중간 레이어 출력을 비교하여, 이미지 간의 구조적 차이를 평가합니다.
- 예를 들어, 이미지 A와 B의 특징 맵을 비교하여 이미지의 "느낌"이나 "구성"이 얼마나 유사한지 계산합니다.
이 논문에서는 Feature Map마다 거리 계산을 했다라고 생각하면 편하긴 하다.
Patch-based Adversarial Objective는 PatchGAN 또는 Patch-level GAN이라고도 불리며, 지역적인 이미지 정보를 평가하는 데 사용된다.
이 방식은 이미지 전체를 하나의 진짜/가짜로 구분하는 것이 아니라, 이미지를 작은 패치(patch) 단위로 나누어 각 패치가 진짜인지 가짜인지를 평가합니다.
Patch-based Adversarial Objective의 장점
- 지역적 특성에 집중하여 더 세밀한 텍스처와 미세한 디테일을 개선할 수 있습니다.
- 효율적으로 학습할 수 있으며, 전체 이미지보다 패치 단위로 학습하기 때문에 계산 자원을 절약할 수 있습니다.
기본적인 수식은 이미지 전체에서 패치화만 되어 있을 뿐, GAN하고 똑같다.
L1이나 L2 Loss 처럼 픽셀 단위 로스를 사용했을 때 나타날 수 있는 단점들(예: Blurriness)를 완화할 수 있다.
QnA
1. 저자가 뭘 해내고 싶어했는가?
저자는 고해상도 이미지 생성을 위한 효율적이고 계산 비용이 적은 모델을 제시하고자 했다.
이를 위해 Latent Diffusion Models를 도입하여 기존의 픽셀 공간에서의 비효율성을 해결하고, 다양한 조건부 생성 작업을 보다 효과적으로 수행할 수 있는 방법을 제시했다.
2. 이 연구의 접근에서 중요한 요소는 무엇인가?
- 잠재 공간(latent space)에서의 Diffusion Model 학습.
- Cross-attention layers를 사용하여 다양한 조건(condition)-텍스트을 유연하게 처리.
- 효율적 학습 및 추론을 위한 모델 설계로 계산 비용을 줄이고 성능을 향상.
- Perceptual Loss와 Patch based adversarial objective를 사용함으로서 자연스러운 느낌과 세밀한 디테일을 가능하도록 했다.
3. 당신(논문독자)는 스스로 이 논문을 이용할 수 있는가?
https://github.com/CompVis/latent-diffusion
4. 당신이 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?
- Diffusion Probabilistic Models - Diffusion Models의 기본 개념.
- GANs와 VQ-VAEs: Latent Diffusion Models 이전의 주요 이미지 생성 모델.
- Taming Transformers for High-Resolution Image Synthesis - LDM의 기초가 되는 VQGAN 접근.
구현하기
사실 Stable Diffusion을 구현한다는 건 쉬운 일이 아니다.
따라서 이번에는 오픈 소스 참고 코드를 보고, 어떻게 구조가 되어 있는지 살펴보기로 하자.
코드는 여기를 참고 했다.
2. 코드 연습하기(오픈소스를 받아 직접 구현)
총 4가지 부분이다.
- 잠재 공간(latent space)에서의 Diffusion Model 학습.
- Cross-attention layers를 사용하여 다양한 조건(condition)-텍스트을 유연하게 처리.
- 효율적 학습 및 추론을 위한 모델 설계로 계산 비용을 줄이고 성능을 향상.
- Perceptual Loss와 Patch based adversarial objective를 사용함으로서 자연스러운 느낌과 세밀한 디테일을 가능하도록 했다.
2번과 3번은 묶인다.
- 잠재 공간(latent space)에서의 Diffusion Model 학습.
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
설명:
class AutoencoderKL(pl.LightningModule):
def __init__(self, ddconfig, lossconfig, embed_dim, ckpt_path=None, ignore_keys=[], image_key="image", colorize_nlabels=None, monitor=None, ema_decay=None, learn_logvar=False):
super().__init__()
self.learn_logvar = learn_logvar # 로그 분산 값 학습 여부
self.image_key = image_key
self.encoder = Encoder(**ddconfig) # Encoder: 이미지를 잠재 공간으로 변환하는 네트워크
self.decoder = Decoder(**ddconfig) # Decoder: 잠재 공간에서 이미지를 복원하는 네트워크
self.loss = instantiate_from_config(lossconfig) # 손실 함수 설정
assert ddconfig["double_z"] # 잠재 공간에 두 배의 채널을 사용하도록 설정
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) # 잠재 공간에서의 압축 처리
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) # 복원 처리
self.embed_dim = embed_dim
if colorize_nlabels is not None: # 색상화 처리용 레이블
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None # EMA(Exponential Moving Average)를 사용할지 여부
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay) # EMA 모델 초기화
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) # 체크포인트에서 모델 불러오기
- __init__: 이 부분은 AutoencoderKL 클래스의 초기화 함수로, 모델의 주요 모듈을 설정합니다.
- **encoder**와 **decoder**는 이미지를 잠재 공간으로 인코딩하고, 다시 원래 이미지로 복원하는 역할을 합니다.
- **quant_conv**와 **post_quant_conv**는 잠재 공간에서 이미지를 처리하는 컨볼루션 레이어입니다.
- EMA는 학습 중에 Exponential Moving Average를 저장하여 더 안정적인 결과를 얻기 위한 기법입니다.
- **init_from_ckpt**는 사전 학습된 체크포인트를 불러와 모델을 초기화하는 기능입니다.
2. EMA 관리 (ema_scope)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters()) # 현재 파라미터를 EMA에 저장
self.model_ema.copy_to(self) # EMA 파라미터를 모델에 복사
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters()) # 학습 중 원래 파라미터로 복원
if context is not None:
print(f"{context}: Restored training weights")
- ema_scope: 이 함수는 EMA 모델의 가중치를 학습 파라미터에 복사하여 사용하고, 학습 중에는 EMA 가중치를 적용하여 모델의 안정성을 높입니다.
- **store**와 **copy_to**는 EMA 파라미터를 모델에 반영하고,
- **restore**는 학습이 끝난 후 원래 파라미터로 복원하는 역할을 합니다.
- context는 EMA 가중치를 변경할 때 어떤 작업이 수행되는지 알리기 위한 출력 메시지를 나타냅니다.
3. 인코딩과 디코딩 (encode, decode)
def encode(self, x):
h = self.encoder(x) # 이미지를 인코딩하여 잠재 공간 표현을 생성
moments = self.quant_conv(h) # 잠재 공간 표현을 후처리
posterior = DiagonalGaussianDistribution(moments) # 잠재 공간의 분포를 가우시안으로 설정
return posterior
def decode(self, z):
z = self.post_quant_conv(z) # 잠재 공간의 후처리
dec = self.decoder(z) # 잠재 공간에서 이미지를 복원
return dec
- encode: 입력 이미지를 인코더를 사용하여 잠재 공간(latent space)으로 변환합니다. 변환된 잠재 표현을 **quant_conv**로 후처리한 후, 가우시안 분포로 모델링하여 Posterior를 생성합니다.
- decode: 잠재 공간에서 이미지를 복원하는 함수입니다. **post_quant_conv**로 후처리하고, 디코더를 통해 원본 이미지를 복원합니다.
4. 학습 단계 (training_step)
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key) # 배치에서 입력 이미지를 얻음
reconstructions, posterior = self(inputs) # 이미지를 복원하고 posterior 계산
if optimizer_idx == 0:
# encoder + decoder + logvar 학습
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator 학습
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
return discloss
training_step: 이 함수는 한 배치의 학습을 처리합니다.
- 입력 데이터를 **get_input**을 통해 가져오고, 모델을 통해 이미지 복원을 합니다.
- 두 가지 손실 함수가 있습니다:
- aeloss: Autoencoder 손실을 계산하여 인코더와 디코더를 학습합니다.
- discloss: 판별자(Discriminator) 손실을 계산하여 Discriminator를 학습합니다.
- **self.log()**를 사용해 학습 중 손실 값을 로깅합니다.
5. 최적화 설정 (configure_optimizers)
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
- configure_optimizers: Autoencoder와 Discriminator 각각에 대해 Adam optimizer를 설정합니다.
- **opt_ae**는 Autoencoder의 파라미터들을 최적화합니다.
- **opt_disc**는 Discriminator의 파라미터들을 최적화합니다.
- 학습률 **lr**과 betas 값을 설정하여 모델 학습을 최적화합니다.
6. 이미지 로그 (log_images)
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
log["samples"] = self.decode(torch.randn_like(posterior.sample())) # 샘플 이미지 생성
log["reconstructions"] = xrec # 복원된 이미지 저장
log["inputs"] = x
return log
log_images: 이 함수는 입력 이미지와 복원된 이미지를 로그로 기록하는 역할을 합니다.
- xrec는 복원된 이미지이며, **samples**는 새로운 샘플 이미지를 생성한 것입니다.
- EMA 모델을 사용하면, EMA 가중치로 복원된 이미지도 기록할 수 있습니다.
- Cross-attention layers를 사용하여 다양한 조건(condition)-텍스트을 유연하게 처리.
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from ldm.modules.diffusionmodules.util import checkpoint
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
# CrossAttn precision handling
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.")
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention
}
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False):
super().__init__()
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
for d in range(depth)]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
1. Cross-Attention Layer
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
- Cross-Attention은 쿼리(query), 키(key), **값(value)**를 각각 다른 입력에서 가져와서 그들의 상호작용을 학습하는 어텐션 메커니즘입니다.
- query_dim, context_dim, dim_head, heads 등은 어텐션 헤드의 차원과 개수를 설정하며, 이들은 주로 텍스트와 이미지 간의 관계를 모델링할 때 사용됩니다.
- **to_q, to_k, to_v**는 각각 쿼리, 키, 값을 위한 선형 변환을 정의합니다.
- **to_out**은 최종 어텐션 결과를 다시 입력 차원으로 변환합니다.
forward 메서드
def forward(self, x, context=None, mask=None):
q = self.to_q(x) # 쿼리 변환
context = default(context, x)
k = self.to_k(context) # 키 변환
v = self.to_v(context) # 값 변환
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # 어텐션 유사도 계산
sim = sim.softmax(dim=-1) # 소프트맥스 적용
out = einsum('b i j, b j d -> b i d', sim, v) # 어텐션 값 계산
out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads)
return self.to_out(out)
- forward 함수에서는 주어진 쿼리(x), 컨텍스트(context), 키, 값을 선형 변환을 통해 어텐션 계산을 수행합니다.
- **einsum**을 이용해 쿼리와 키 간의 **유사도(similarity)**를 계산한 후, **소프트맥스(softmax)**를 통해 확률 분포로 변환합니다.
- 최종적으로 어텐션 값을 **값(v)**에 적용하여 출력을 생성합니다.
2. Memory-Efficient Cross-Attention
class MemoryEfficientCrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
- **MemoryEfficientCrossAttention**는 **CrossAttention**의 메모리 효율성을 개선한 버전입니다. 이 클래스는 xformers라는 라이브러리를 사용하여 메모리 효율적인 어텐션을 제공합니다.
- to_q, to_k, to_v는 쿼리, 키, 값을 선형 변환하는 역할을 하며, **to_out**은 최종 출력을 생성합니다.
forward 메서드
def forward(self, x, context=None, mask=None):
q = self.to_q(x) # 쿼리 변환
context = default(context, x)
k = self.to_k(context) # 키 변환
v = self.to_v(context) # 값 변환
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
- 메모리 효율적인 어텐션을 위해 **xformers.ops.memory_efficient_attention**을 사용하여 어텐션을 계산합니다.
- **q, k, v**를 헤드 차원에 맞게 변형하고, 어텐션 계산 후 출력 값을 적절히 재형성하여 최종 결과를 반환합니다.
3. Basic Transformer Block
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention,
"softmax-xformers": MemoryEfficientCrossAttention
}
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
super().__init__()
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
attn_cls = self.ATTENTION_MODES[attn_mode]
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
BasicTransformerBlock은 Transformer 모델에서 어텐션과 피드포워드 네트워크를 조합한 기본 블록입니다.
- 어텐션 모드는 softmax(기본 CrossAttention) 또는 xformers(메모리 효율적인 어텐션)을 선택합니다.
- attn1과 attn2는 두 개의 어텐션 레이어를 정의합니다. 첫 번째 어텐션은 self-attention 또는 cross-attention을 처리하고, 두 번째 어텐션은 cross-attention을 처리합니다.
- FeedForward는 입력을 받아 비선형 변환 후 출력을 생성합니다.
forward 메서드
def forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
- **forward**는 입력을 받아 어텐션과 피드포워드 네트워크를 통과시킨 후, 잔차 연결(residual connection)을 사용하여 출력합니다.
- 첫 번째 self-attention 또는 cross-attention, 두 번째 cross-attention, 마지막으로 피드포워드 네트워크를 거칩니다.
4. Spatial Transformer
class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, use_checkpoint=True):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) if not use_linear else nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) if not use_linear else zero_module(nn.Linear(in_channels, inner_dim))
Spatial Transformer는 이미지 데이터를 위한 Transformer입니다. 입력을 임베딩하여 Transformer를 적용한 후, 다시 이미지를 출력하는 구조입니다.
- **proj_in**은 이미지를 임베딩하는 역할을 하고, **proj_out**은 이미지를 복원하는 역할을 합니다.
- **transformer_blocks**는 여러 개의 Transformer 블록을 적용하여 이미지의 공간적 특성을 학습합니다.
class UNetModel(nn.Module):
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
num_heads=8
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_heads = num_heads
# 다운샘플링 블록
self.down = nn.ModuleList([
DownBlock(
ch,
use_conv=conv_resample,
use_attn=(i in attention_resolutions)
)
for i, ch in enumerate(channel_mult)
])
# 업샘플링 블록
self.up = nn.ModuleList([
UpBlock(
ch,
use_conv=conv_resample,
use_attn=(i in attention_resolutions)
)
for i, ch in enumerate(reversed(channel_mult))
])
def forward(self, x, timesteps, context=None):
# 시간 임베딩
t_emb = self.time_embed(timesteps)
# 다운샘플링 패스
h = x
hs = []
for block in self.down:
h = block(h, t_emb, context)
hs.append(h)
# 업샘플링 패스
for block in self.up:
h = block(h, hs.pop(), t_emb, context)
return h
- Perceptual Loss와 Patch based adversarial objective를 사용함으로서 자연스러운 느낌과 세밀한 디테일을 가능하도록 했다.
이 부분은 U-Net에 정의가 되어 있어야 하는데, 내가 못찾은 건지 빼먹은 건지 Loss 부분이 없다.
정확히는 ddpm에서 있긴 한데, L1 loss와 L2 loss로 되어 있어서 동작이 되는건지 알 수 없다.
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = 'train' if self.training else 'val'
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
이미지 패치를 가져오는 부분은 있다. 아마도 이 부분을 통해서 adversarial objetive를 사용하게 되는거 같다.
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2]
patches = []
if w > p_max and h > p_max:
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
w1.append(w-p_size)
h1.append(h-p_size)
# print(w1)
# print(h1)
for i in w1:
for j in h1:
patches.append(img[i:i+p_size, j:j+p_size,:])
else:
patches.append(img)
return patches
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
"""
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
will be splitted.
Args:
original_dataroot:
taget_dataroot:
p_size: size of small images
p_overlap: patch size in training is a good choice
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
"""
paths = get_image_paths(original_dataroot)
for img_path in paths:
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
#if original_dataroot == taget_dataroot:
#del img_path
마치며
이상으로 Stable Diffusion의 이해와 코드를 대략적으로 살펴보았다.
개조된 버전의 코드이긴 하지만, 전체적인 플로우는 같으므로 어느정도 이해는 되었을 거라 믿는다.
만약 아니라면, 그러므로 논문을 이해하는 것도 중요하지만 코드로 구현하는 게 더 중요하다는 걸 깨닫는 게 도움이 될 거 같다.