티스토리 뷰

반응형

지난 시간에는 Transformer 구조의 한계를 초월하기 위해서, 상태 공간 모델을 사용한 Mamba를 다루었다.

본격적으로 Jamba 구조를 다루기 전에, 그렇다면 과연 Mamba가 나오기 전에 어떤 방법들로 Transformer의 시간을 줄이기 위해 노력했는지를 보려고 한다.

사실 관련 논문들이 너무나도 많지만, 대표적으로 가장 유명한 걸 꼽으라 하면 FlashAttention이지 않을까 싶다.

사용하기도 편하고, 설명도 직관적이다. 그렇기에 정말 괜찮은 구현이라고 할 수 있다.

언제나 오는 템플릿을 보고 한 번 시작해보자.

논문 구현에 앞서 확인해야 할 포인트

Read

1. 논문 제목(title)과 초록(abstract), 도표(figures) 읽기

2. 도입(introduction), 결론(conclusion), 도표(figures)를 읽고 필요없는 부분 생략

3. 수식은 처음 읽을 때는 과감하게 생략

4. 이해가 안되는 부분은 빼고 전체적으로 읽는다.

QnA

1. 저자가 뭘 해내고 싶어했는가?

2. 이 연구의 접근에서 중요한 요소는 무엇인가?

3. 당신(논문독자)는 스스로 이 논문을 이용할 수 있는가?

4. 당신이 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?

구현하기

1. 수식 이해하고 직접 연산하기

2. 코드 연습하기(오픈소스를 받아 직접 구현)

Read

1단계: 제목, 초록, 도표 읽기

제목: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"

제목에서 보다시피 저자는 Memory-Efficient Exact Attention을 IO-Awareness로 구현하려고 의도 했다는 걸 알 수 있다.

IO-awareness(입출력 인식)는 GPU 메모리 계층 간의 데이터 읽기 및 쓰기 작업을 최적화하여 계산 작업의 성능을 개선하는 개념. 구체적으로, 이는 GPU의 빠른 온칩 메모리(SRAM)와 상대적으로 느린 고대역폭 메모리(HBM) 사이의 데이터 이동을 최소화하는 데 초점을 두는 방식이다.

초록: FlashAttention은 긴 시퀀스를 처리할 때 Transformer의 자기-어텐션에서 발생하는 시간 및 메모리 복잡도의 문제를 해결하기 위해 개발되었습니다. 기존의 근사 방법들이 모델 품질을 희생하는 반면, FlashAttention은 GPU 고대역폭 메모리(HBM)와 온칩 SRAM 간의 메모리 읽기/쓰기 최적화를 통해 IO를 고려합니다. 타일링과 재계산 기술을 활용하여 메모리 접근 횟수를 줄이고, BERT와 GPT-2를 포함한 다양한 Transformer 모델의 학습 속도를 높이며, 더 긴 문맥 길이에서 모델 품질을 향상시킵니다.

기존의 논문들이 단순히 소프트웨어적으로 양자화나 부동소수점에서 어떻게 해볼까를 고려했던 방법론들과 달리, 하드웨어적인 접근을 통해 병목 현상을 해결하려고 했다는 점이 큰 포인트다. 이 논문을 기점으로 굉장히 많은 하드웨어적 개선 방안 논문들이 나왔던 것으로 필자는 기억한다.

도표 주요 내용:

  • 도표 1: 메모리 효율성을 위한 타일링 전략을 설명하며, 어텐션 계산에서 7.6배의 속도 향상을 보여줌.

FlashAttention에서 IO-awareness는 다음과 같은 방식으로 구현됩니다:

  1. 타일링: 입력 데이터를 작은 블록으로 나누어 HBM에서 SRAM으로 전송한 후, 해당 블록을 사용해 계산을 수행합니다. 이를 통해 큰 데이터셋 전체를 반복적으로 메모리에서 불러오는 것을 방지합니다.
  2. 재계산: 계산 중간 결과를 저장하지 않고, 필요할 때 빠르게 재계산하는 방식을 사용하여 메모리 사용량을 줄입니다.
  3. 메모리 계층 최적화: GPU의 각 메모리 계층의 속도와 크기에 맞게 데이터를 배치하고 처리하여 불필요한 데이터 이동을 방지합니다.

GPU에는 SRAM과 HBM이 있다. 원래 모델을 Train 할때는, HBM에서 데이터셋의 메모리를 아예 전체 구성한 후, 계산 및 저장을 반복하면서 모델이 Train되게 된다.

하지만 SRAM은 위에서 보는 바와 같이 19 TB/s이고, 이는 1.5 TB/s 인 HBM보다 약 12.7배 빠르다.

따라서 데이터셋을 아주 잘게 나눈 후, 해당 SRAM에서 연산한 결과만 HBM에 적재함으로서 효율을 높인 게 Flash Attention 메커니즘이라고 할 수 있다.

그 결과로 무려 7.6배의 속도 향상을 보여준다.

2단계: 도입부, 결론, 주요 도표 읽기

도입부: 자기-어텐션에서 발생하는 이차 복잡도의 문제와 기존 근사 어텐션 방법들의 한계를 설명합니다. FlashAttention은 메모리 병목 현상을 최소화하는 IO 인식 접근 방식을 통해 어텐션 계산을 재구성한 것이 주요 기여입니다.

결론:

  • FlashAttention은 모델 학습 속도를 가속화하고, 긴 시퀀스 작업에서 Transformer 성능을 향상시킵니다.
  • Path-X 및 Path-256 벤치마크에서 64K 시퀀스까지 최초로 평균 이상의 성능을 달성한 Transformer를 가능하게 했습니다.
  • 오픈소스 구현으로 더 넓은 응용을 지원합니다.

도표 요약:

  • FlashAttention, 표준 어텐션, 희소 방법을 비교한 실행 시간 및 메모리 사용량, FlashAttention의 우수성을 입증.
  • 실세계 작업(BERT, GPT-2)에서 중요한 속도 향상과 품질 개선을 보여줌.

3단계: 수식 분석

핵심 기술:

  1. 타일링: 입력(Q, K, V 매트릭스)을 더 작은 블록으로 나누어 SRAM에 로드하여 HBM 접근을 줄임.
  2. 재계산: 최소한의 중간 결과만 저장하여 역전파 시 어텐션을 다시 계산, 이차 메모리 사용량을 방지.
  3. IO 복잡도: HBM 접근을 줄이는 수학적 증명, 이를 통해 정확한 어텐션을 유지하며 실행 속도 가속화.

그와 관련하여 단계별 과정을 살펴보면 다음과 같다.

FlashAttention 알고리즘 단계별 과정

1. 입력 및 초기 설정

알고리즘은 다음과 같은 입력을 받습니다:

  • 행렬 Q, K, V ∈ ℝ^(N×d)가 HBM에 저장됨
  • 온칩 SRAM의 크기는 M
  • N은 토큰의 수, d는 임베딩 차원 수

블록 크기 설정

  • Bc = ⌊M/4d⌋ = 한번에 처리할 수 있는 데이터 크기 제한. K, V 블록 크기를 Bc * d로 설정할 때 필요한 메모리 기반
  • Br = min(⌊M/4d⌋, d) = 만약 M/4d가 d보다 작다면, 최적화된 크기가 아니라 오히려 공간이 남으므로 d로 설정

초기화

  • O = (0)N×d ∈ ℝ^(N×d)  = 어텐션 연산의 최종 결과. (QK^T)*V
  • ℓ = (0)N ∈ ℝ^N = 각 행별 어텐션 확률 합계 값. 정규화과정을 효율적으로 하는 데 사용된다.
  • m = (-∞)N ∈ ℝ^N = 안전성을 위한 행별 최대값. 점수 행렬에서 최대값을 저장
    위 모든 값들은 HBM에 초기화됩니다.

2. 행렬 분할 과정

입력 행렬 분할

  • Q를 Tr = ⌈N/Br⌉ 개의 블록으로 분할: Q1, ..., QTr (각 크기 Br × d) = 4d * N / M. 즉, 임베딩 * 토큰 수를 / SRAM 크기로 나눔
  • K, V를 Tc = ⌈N/Bc⌉ 개의 블록으로 분할: K1, ..., KTc 및 V1, ..., VTc (각 크기 Bc × d)
  • => 이것도 마찬가지. 하지만 Bc *d는 고정.

출력 및 보조 변수 분할

  • O를 Tr 개의 블록으로 분할: O1, ..., OTr (각 크기 Br × d)
  • ℓ를 Tr 개의 블록으로 분할: ℓ1, ..., ℓTr (각 크기 Br)
  • m을 Tr 개의 블록으로 분할: m1, ..., mTr (각 크기 Br)

3. 주요 계산 과정

외부 루프 (1 ≤ j ≤ Tc)에서:

  1. Kj, Vj를 HBM에서 온칩 SRAM으로 로드

내부 루프 (1 ≤ i ≤ Tr)에서:

  1. Qi, Oi, ℓi, mi를 HBM에서 온칩 SRAM으로 로드
  2. 온칩에서 다음 계산 수행:
    • Sij = QiKj^T ∈ ℝ^(Br×Bc) 계산
    • m̃ij = rowmax(Sij) ∈ ℝ^Br 계산
    • P̃ij = exp(Sij - m̃ij) ∈ ℝ^(Br×Bc) 계산 (요소별 연산)
    • ℓ̃ij = rowsum(P̃ij) ∈ ℝ^Br 계산
  3. 업데이트 계산:
    • mi^new = max(mi, m̃ij) ∈ ℝ^Br
    • ℓi^new = exp(mi-mi^new)ℓi + exp(m̃ij-mi^new)ℓ̃ij ∈ ℝ^Br
  4. 결과 저장:
    • Oi ← (diag(ℓi^new))^(-1)(diag(ℓi)exp(mi-mi^new)Oi + exp(m̃ij-mi^new)P̃ijVj)를 HBM에 저장
    • ℓi ← ℓi^new, mi ← mi^new를 HBM에 저장

이 부분은 코드에서 더 자세히 보도록 하자.

4. 최종 출력

모든 루프가 완료된 후, 최종 행렬 O를 반환합니다.

주요 특징

  1. 메모리 효율성
    • HBM과 SRAM 간의 데이터 이동을 최적화
    • 블록 단위 처리를 통한 메모리 사용 최소화
  2. 수치적 안정성
    • 지수 함수 계산 시 최대값을 빼서 오버플로우 방지
    • 적절한 정규화를 통한 안정적인 계산
  3. 병렬 처리
    • 블록 단위 처리를 통한 효율적인 병렬화 가능
    • GPU 아키텍처에 최적화된 구조

가 된다.

 

질문과 답변 (QnA)

Q1: 저자가 달성하고자 한 목표는 무엇인가요?

Transformer의 자기-어텐션에서 긴 시퀀스를 처리할 때 발생하는 시간 및 메모리 복잡성을 줄이고, 메모리 사용을 최적화하며 계산 속도를 높이는 IO 인식 알고리즘을 제안했습니다.

Q2: 이 연구의 주요 접근 요소는 무엇인가요?

  1. IO 인식: 타일링과 재계산을 활용하여 GPU 메모리 읽기/쓰기 최소화.
  2. 알고리즘 최적화: 정확한 어텐션과 블록 희소 어텐션 기술을 결합하여 더 빠르고 메모리 효율적인 연산 구현.
  3. 실증적 검증: Transformer 벤치마크에서 중요한 속도 향상과 메모리 절약을 입증.

Q3: 이 논문의 방법을 활용할 수 있나요?

네, 오픈소스 구현과 실용적인 배포 초점으로 인해 FlashAttention을 Transformer 기반 작업에 통합하기 용이합니다. 특히 긴 시퀀스를 처리해야 하는 작업에서 유용합니다.

Q4: 참고하고 싶은 다른 레퍼런스는 무엇인가요?

  1. 희소 및 저랭크 어텐션에 대한 기술(Linformer, Performer 등).
  2. 계산 시스템에서 IO 복잡도 분석.
  3. 긴 문서 분류 및 대규모 문맥 작업에서 Transformer 모델의 응용.

 

구현하기

1. 수식 이해하고 직접 연산하기

 

별로 길지 않으니 직접 봐보자. 원래는 update 함수 같은 걸 써야 하지만 직관적인 이해를 위해서 단일로 만들었다.

인덱스별로 블록을 쪼개서 SRAM으로 옮긴다는 개념만 알면 된다.

import torch
import math

class FlashAttention(torch.nn.Module):
    def __init__(self, M=64):  # M은 SRAM 크기
        super().__init__()
        self.M = M
        
    def forward(self, Q, K, V):
        """
        FlashAttention 구현
        Args:
            Q, K, V: ℝ^(N×d) 크기의 입력 행렬들 (HBM에 저장됨)
        Returns:
            O: ℝ^(N×d) 크기의 출력 행렬
        """
        N, d = Q.shape
        
        # 1. 블록 크기 설정
        Bc = math.floor(self.M / (4 * d))
        Br = min(math.floor(self.M / (4 * d)), d)
        
        # 2. O, l, m 초기화 (HBM에)
        O = torch.zeros_like(Q)  # (0)_{N×d}
        l = torch.zeros(N, device=Q.device)  # (0)_N
        m = torch.full((N,), float('-inf'), device=Q.device)  # (-∞)_N
        
        # 3. 행렬들을 블록으로 분할
        Tr = math.ceil(N / Br)  # 블록 수
        Tc = math.ceil(N / Bc)
        
        # 4. O, l, m을 Tr개의 블록으로 분할
        for j in range(Tc):  # for 1 ≤ j ≤ Tc do
            # 6. Kj, Vj를 HBM에서 SRAM으로 로드
            start_c = j * Bc
            end_c = min((j + 1) * Bc, N)
            Kj = K[start_c:end_c]
            Vj = V[start_c:end_c]
            
            for i in range(Tr):  # for 1 ≤ i ≤ Tr do
                # 8. Qi, Oi, li, mi를 HBM에서 SRAM으로 로드
                start_r = i * Br
                end_r = min((i + 1) * Br, N)
                Qi = Q[start_r:end_r]
                Oi = O[start_r:end_r]
                li = l[start_r:end_r]
                mi = m[start_r:end_r]
                
                # 9. Sij 계산 (온칩)
                Sij = torch.matmul(Qi, Kj.T)  # ∈ ℝ^(Br×Bc)
                
                # 10. m̃ij, P̃ij, l̃ij 계산 (온칩)
                mij_tilde = torch.max(Sij, dim=1)[0]  # rowmax(Sij) ∈ ℝ^Br
                Pij_tilde = torch.exp(Sij - mij_tilde.unsqueeze(1))  # exp(Sij - m̃ij) ∈ ℝ^(Br×Bc)
                lij_tilde = torch.sum(Pij_tilde, dim=1)  # rowsum(P̃ij) ∈ ℝ^Br
                
                # 11. mi^new와 li^new 계산 (온칩)
                mi_new = torch.maximum(mi, mij_tilde)  # max(mi, m̃ij) ∈ ℝ^Br
                
                # exp(mi - mi^new)li + exp(m̃ij - mi^new)l̃ij ∈ ℝ^Br
                li_new = (torch.exp(mi - mi_new) * li + 
                         torch.exp(mij_tilde - mi_new) * lij_tilde)
                
                # 12. Oi 업데이트
                # diag(li^new)^(-1)(diag(li)exp(mi-mi^new)Oi + exp(m̃ij-mi^new)P̃ijVj)
                Oi_new = (1.0 / li_new.unsqueeze(1)) * (
                    li.unsqueeze(1) * torch.exp(mi - mi_new).unsqueeze(1) * Oi +
                    torch.exp(mij_tilde - mi_new).unsqueeze(1) * torch.matmul(Pij_tilde, Vj)
                )
                
                # 13. HBM에 결과 저장
                O[start_r:end_r] = Oi_new
                l[start_r:end_r] = li_new
                m[start_r:end_r] = mi_new
                
        # 16. O 반환
        return O

# 사용 예시
def example_usage():
    # SRAM 크기 M 설정
    M = 64
    flash_attn = FlashAttention(M=M)
    
    # 샘플 입력 생성
    N, d = 512, 64  # 시퀀스 길이와 차원
    Q = torch.randn(N, d)
    K = torch.randn(N, d)
    V = torch.randn(N, d)
    
    # FlashAttention 실행
    O = flash_attn(Q, K, V)
    return O

 

이상으로 Flash-attention에 대해서 알아보았다.

이 이후에 Flash-attention2, 3도 존재하지만 개념적인 건 읽어보면 대략적으로 알 수 있으므로 한 번 읽어보는 것도 추천한다.

 

그럼 또 다른 논문으로 찾아오겠다.

반응형