티스토리 뷰
[논문 100개 구현 4탄] Mamba: Linear-Time Sequence Modeling with Selective State Spaces
sikaro 2024. 11. 30. 22:51Attention 구조는 시간복잡도 면에서 필연적으로 토큰의 수의 제곱의 시간이 필요하다.
이 구조를 타파하기 위해서 Linear Attention 등 많은 해결책을 제시했지만, 결국 토큰이 많아질 수록 모델의 아웃풋이 느려지는 건 해결하지 못했다.
Mamba 논문은 그 후 거의 7년만에 나온 대안책으로서, Selective State Space라는 상태 공간 모델링을 통해 RNN, LSTM과 같은 기존 시계열 구조를 쓰면서도, 상태 방정식과 게이트 메커니즘도 도입하여 아웃풋 시간이 토큰의 수와 선형적인 증가를 이룸으로서 문제를 해결했다고 주장한다.
2023년 12월에 나왔지만 최근까지도 LLM 분야에서 굉장히 핫한 이 논문을 오늘 구현해보고, 코드를 살펴보자.
논문 구현에 앞서 확인해야 할 포인트
Read
1. 논문 제목(title)과 초록(abstract), 도표(figures) 읽기
2. 도입(introduction), 결론(conclusion), 도표(figures)를 읽고 필요없는 부분 생략
3. 수식은 처음 읽을 때는 과감하게 생략
4. 이해가 안되는 부분은 빼고 전체적으로 읽는다.
QnA
1. 저자가 뭘 해내고 싶어했는가?
2. 이 연구의 접근에서 중요한 요소는 무엇인가?
3. 당신(논문독자)는 스스로 이 논문을 이용할 수 있는가?
4. 당신이 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?
구현하기
1. 수식 이해하고 직접 연산하기
2. 코드 연습하기(오픈소스를 받아 직접 구현)
Read
1. 논문 제목(title)과 초록(abstract), 도표(figures) 읽기
Linear-Time Sequence Modeling with Selective State Spaces
제목에서 중요한건 선형 시간이 필요한 시계열적 모델링이라는 부분과, 선택적 상태 공간 방정식을 의미하는 부분이다.
말 그대로 선택적 상태 공간 방정식을 통해 토큰 선형적인 시계열 모델을 만들었다는 뜻이고, 이는 Abstract에서 그 의도가 더 자세히 들어난다.
Abstract
파운데이션 모델은 Transformer 아키텍처와 핵심 모듈인 어텐션이 기반한다.
그러나 선형 어텐션, 게이트 컨볼루션, 순환 모델, 구조화된 상태 공간 모델 등 subquadratic-time 아키텍처가 개발되었음에도 언어 모달리티에서 어텐션만큼의 성능을 발휘하지 못했다.
우리는 이러한 모델들의 주요 약점이 Content-based reasoning(즉, 전에 대화하던 걸 잘 기억 못해서 나머지에 대한 대답 추론도 잘 못함)을 수행하지 못한다는 점이라고 판단, 이를 개선하기 위해 몇 가지 혁신을 제안합니다.
첫째, SSM의 매개변수를 입력에 따라 동적으로 작동하도록 설계(여기선 재귀적인 구조, 혹은 점화식)한다.
이를 통해 모델들이 이산 모달리티에서 나타내는 약점을 해결
시퀀스 길이에 따라 현재 토큰에 기반하여 정보를 선택적으로 전파하거나 잊어버리도록 한다.
둘째, 효율적인 컨볼루션의 사용은 막아진다. 그러나 Recurrent mode에서 하드웨어 친화적인 병렬 알고리즘이 설계된다.
어텐션이나 MLP 블록조차 없는 단순화된 End-to-End 신경망 아키텍처 Mamba를 설계하여 Tranformer 대비 5배 높은 추론 속도와 시퀀스 길에 대한 선형적 확장성을 가진다. 실제 데이터에서 최대 백만 길이의 시퀀스에서도 성능이 향상된다.
시퀀스 모델 백본으로서 언어, 오디오, 유전체학등 다양한 모달리티에서 State-of-the-art를 달성했다.
2. 도입(introduction), 결론(conclusion), 도표(figures)를 읽고 필요없는 부분 생략
Introdoction
- 딥러닝에서 기반 모델(Foundation Models은 일반적으로 Transformer 아키텍처를 사용하며, 이는 복잡한 데이터 간의 정보를 연결할 수 있는 self-attention 메커니즘을 특징으로 한다. 그러나 Transformer는 다음과 같은 한계를 가집니다:
- 유한한 윈도우 내에서만 정보를 모델링할 수 있음.
- 시퀀스 길이에 따라 계산 복잡도가 이차적(Quadratic)으로 증가.
- 이를 해결하기 위해 다양한 효율적인 대안(예: 선형 어텐션, SSM)이 제안되었으나, 대부분 중요한 모달리티(언어 등)에서 Transformer만큼 성능을 내지 못함.
- SSM(Structured State Space Models)은 RNN 및 CNN의 특성을 결합하여, 선형적으로 스케일링하고 긴 시퀀스에서의 의존성을 잘 모델링할 수 있는 새로운 아키텍처로 주목받음.
- 그러나 기존 SSM은 정보 선택(content-based reasoning)과 같은 특정 작업에서 한계를 보여줌.
- 본 연구는 선택 메커니즘(selection mechanism)을 통합한 Mamba를 제안하며, 이는 효율성과 성능 모두를 개선.
Conclusion
Mamba 아키텍처는:
- 입력 기반의 동적 매개변수를 통해 내용을 선택적으로 처리할 수 있음.
- Transformer 대비 5배 빠른 추론 속도를 가지며, 시퀀스 길이에 따라 선형적으로 스케일링 가능.
- 다양한 모달리티(언어, 오디오, 유전체학 등)에서 최첨단 성능을 달성.
- 실험 결과, Mamba는 같은 크기의 Transformer를 능가하며, 두 배 크기의 Transformer와도 동등한 성능을 보임.
- 모델 코드와 사전 학습된 체크포인트는 오픈소스로 제공됨.
계속해서 반복되는 거 말고, 이번에 나온 중요한 포인트는 정보 선택이다.
한마디로 정보 선택 메커니즘을 어떻게 개선하였는지를 중점적으로 보면 된다. 그게 이 논문의 핵심이자 주장일 테니.
Figure
구조화된 SSM(Structured SSM)은 입력 x의 각 채널(예: D=5)을 독립적으로 더 높은 차원의 잠재 상태 h (예: N=4)를 통해 출력 y로 매핑합니다. 기존 SSM은 시간 불변성을 요구하는 계산 경로를 활용하여 이 큰 유효 상태(즉, D×N, 배치 크기 B 및 시퀀스 길이 L의 곱)를 구체적으로 계산하지 않고 처리했습니다. 이를 위해 Δ,A,B,C 매개변수는 시간에 따라 일정하게 유지되었습니다.
우리의 선택 메커니즘(selection mechanism)은 입력 의존적 동작을 다시 추가했으며, 이를 위해 GPU 메모리 계층의 더 효율적인 수준에서만 확장된 상태를 계산하도록 하는 하드웨어 친화적인 알고리즘이 필요합니다.
도대체 이게 뭔 소리냐?
연구자들이 의례 그렇지만, 쉬운 말도 어렵게 하는 재주가 있는거 같다.
쉽게 말하자면 기존 SSM 방정식의 상태공간은 시간에 따라 변하지 않는, 즉 매개변수 t가 없는 모델이었다.
예를 들어 D=5, N=4, 배치 크기 100, 시퀀스 길이 1000이라면 전체 상태는 5*4*100*1000 = 2백만개의 값으로 고정되어서 계산된다는 것이다.
하지만 Mamba 모델은 잠재 상태 공간의 수를 시간(t)에 따라 동적으로 계산한다.
현재 상태 h(t−1)와 입력 x(t)를 기반으로 상태가 업데이트되기 때문이다. 그래서 무조건 GPU에서 처리해야 할 필요가 없어진다.
그러므로 다음과 같은 두 가지의 변화가 가능해졌다.
- 유동적인 상태 업데이트:
- GPU가 시간에 따라 매번 새로 계산해야 할 상태, 즉, 유동적으로 바뀌는 h(t)를 메모리 효율적으로 배치해야 합니다.
- 메모리 계층 최적화:
- 핵심 데이터는 GPU의 빠른 메모리(SRAM)에서 처리하고,
- 덜 자주 쓰이는 데이터는 GPU의 대용량 메모리(HBM)로 옮겨 효율성을 극대화합니다.
더 깊은 이해를 위해서 아래를 살펴보자.
기존 SSM의 본질
SSM은 입력 신호 x(t)를 받아 은닉 상태 h(t)를 거쳐 출력 신호 y(t)를 계산하는 과정이다.
기본적으로 두 단계로 나뉘는데,
- 은닉 상태 업데이트 h′(t)=Ah(t)+Bx(t)
- 출력 계산 y(t)=Ch(t)
- 이때 A,B,C는 기본적으로 고정된 매개변수이다.(위에서 봤던 D*N)
그러나 델타 t 단위로 이산화는 할 수 있다. 매 시간 단계에서 연속 데이터를 샘플링해서, 고정된 값으로 유지한다.
이게 Zero-Order Hold(영차 홀드)라는 방식이다.
공과 대학을 나온 분들은 미분 방정식을 생각하면 위의 공식의 답이 벌써 보인다.
ZOH에 대해 더 설명해보자.
컴퓨터는 연속적인 데이터를 직접 처리할 수 없다. 디지털 방식의 제어기는 출력도 디지털이기 때문이다.
따라서 연속적인 데이터를 적절한 규칙을 사용해 이산적인 단계로 변환해야 한다. 다음 샘플링까지 입력값을 일정하게 유지시키게 되는 것이다.
그러므로 여기서 이산화는 연속 시스템을 시간 단계별 계산으로 변환해서 모델이 실제로 동작할 수 있도록 많이 쪼개서 선을 만든다고 생각하면 된다.
1초 동안의 온도 변화를 0.1초마다 측정한다고 가정하여 온도의 연속 변화를 10개의 이산적인 데이터로 표현하게 된다.
이렇게 이산화된 시스템에서 마르코프 연쇄 법칙을 사용하면 가 A^kB 형태로 표현되고, 이를 이용해 y=x∗K 형태로 계산된다.
이는 Convolution 연산과도 같다. 왜냐하면 K가 커널로 표현되고, x에 Projection 해가면서 output을 만들면 되기 때문이다.
정말 쉽지 않은가?
여기까지는 기본 SSM에 대한 설명이었다.
이제 본격적으로 Selective를 어떻게 도입하는지 알아보자.
Selective SSM(S6)
Selective SSM은 위에서 설명한 SSM에서 한단계 더 나아가 파라미터 A,B,C가 입력에 따라 달라진다.
즉, 입력 x에 의존적이도록 설정되는 것이다.
또 한가지 중요한 특징이 있는데, 위에서 보았던 합성곱(Convolution)을 사용할 수 없게 된다(K를 못쓴다)는 것이다.
Figure 2: (왼쪽) 표준 버전의 Copying 작업은 입력과 출력 요소 사이의 간격이 일정하며, 선형 순환 신경망(linear recurrences)이나 전역 컨볼루션(global convolutions)과 같은 시간 불변 모델로 쉽게 해결할 수 있습니다. (오른쪽 상단) Selective Copying 작업은 입력과 출력 사이의 간격이 무작위로 배치되어 있으며, 내용에 따라 입력을 선택적으로 기억하거나 무시할 수 있는 시간 가변(time-varying) 모델이 필요합니다. (오른쪽 하단) Induction Heads 작업은 문맥에 기반하여 답을 검색하는 연상 회상(associative recall)의 예로, 이는 LLM에서 중요한 능력입니다.
설명 진짜 너무하죠?
쉽게 말할 것도 굳이 굳이 어렵게 말하는 것이다. 아무리 봐도 일부러 이러는 거다.
그러나 이는 맥락적인 게 빠져서 그렇다.
이 Figure의 앞에서는 다음과 같이 설명하고 있다.
우리는 순차 모델링(sequence modeling)의 근본적인 문제 중 하나가 맥락(context)을 작은 상태로 압축하는 것이라고 주장합니다. 이 관점에서 보면, 주요 순차 모델들이 갖는 장단점을 이해할 수 있습니다.
- 어텐션(Attention)은 맥락을 전혀 압축하지 않기 때문에 효과적이지만 비효율적입니다.
- 예를 들어, 자동 회귀 추론(autoregressive inference)은 전체 맥락(즉, KV 캐시)을 명시적으로 저장해야 합니다.
- 이는 선형 시간의 느린 추론과 이차 시간의 느린 학습을 유발합니다.
- 순환 모델(Recurrent Models)은 효율적입니다.
- 상태(state)가 유한하기 때문에 상수 시간 추론과 선형 시간 학습이 가능합니다.
- 하지만 맥락을 얼마나 잘 압축하는지에 따라 효과성이 제한됩니다.
이 원리를 이해하기 위해 두 가지 예제 과제를 살펴봅니다. (그림 2 참고)
- 선택적 복사 과제(Selective Copying Task)
- Copying Task(Arjovsky et al., 2016)를 변형한 것으로, 기억해야 할 토큰의 위치를 변경합니다.
- 관련 토큰(색칠된 부분)은 기억하고, 관련 없는 토큰(흰색 부분)은 걸러내야 하므로 내용 기반(content-aware) 추론이 필요합니다.
- 유도 헤드 과제(Induction Heads Task)
- 대규모 언어 모델(LLM)의 인-컨텍스트 학습(in-context learning) 능력을 설명하는 메커니즘(Olsson et al., 2022)입니다.
- 올바른 맥락에서 올바른 출력을 생성해야 하므로 맥락 인식(context-aware) 추론이 필요합니다.
그럼 이제 다시 Figure를 살펴봅시다.
Figure 2: (왼쪽) 표준 버전의 Copying 작업은 입력과 출력 요소 사이의 간격이 일정하며, 선형 순환 신경망(linear recurrences)이나 전역 컨볼루션(global convolutions)과 같은 시간 불변 모델로 쉽게 해결할 수 있습니다. -> 입력과 출력의 간격이 일정하므로 선형 순환 신경망, 혹은 컨볼루션의 형태(K)로 표현해서 x와 곱이 가능하다.
(오른쪽 상단) Selective Copying 작업은 입력과 출력 사이의 간격이 무작위로 배치되어 있으며, 내용에 따라 입력을 선택적으로 기억하거나 무시할 수 있는 시간 가변(time-varying) 모델이 필요합니다.
-> 이 부분이 핵심인데, Selective를 하게 되면 간격이 바뀌었기 때문에 시간에 따른 연산을 전부 해줘야 결과값이 나온다.
중간의 값으로 검은색이 다음 색인 파란 색을 예측하려면 단어들에는 Sequencial한 특징이 있어야 한다. 하지만 '해리'라는 단어에서 '포터'라는 다음 단어가 Sequencial한 특징이 있을리가 만무하다. 그렇기에 시간에 대한 의존도가 없으므로 불규칙적이게 나오는 것이다.
(오른쪽 하단) Induction Heads 작업은 문맥에 기반하여 답을 검색하는 연상 회상(associative recall)의 예로, 이는 LLM에서 중요한 능력입니다. -> 내용 기반 추론의 핵심2
-> 그래서 Induction Heads 작업으로 내용 추론을 하게 하며 문맥에 기반한 다음 연상을 유도한다. Induction head는 이전 토큰 쌍 간의 관계를 이용해 다음 토큰을 예측한다.
그리고 이렇게 보고 3.2의 설명을 다시 보면 다음과 같습니다.
3.2 선택을 통한 SSM 개선 (Improving SSMs with Selection)
● 선택 메커니즘의 도입
모델에 선택 메커니즘을 도입하는 한 가지 방법은 순차적으로 상호작용하는 매개변수(예: RNN의 순환 동역학, CNN의 합성곱 커널)를 입력에 따라 달라지게 하는 것입니다.
- 알고리즘 1과 2는 이 메커니즘을 설명합니다.
- 주요 차이점은 여러 매개변수(Δ, 𝐵, 𝐶)를 입력 의존적 함수로 만들고, 텐서의 모양(shape)에 길이 차원(𝐿)과 배치(B)을 추가하는 것입니다.
- 결과적으로, 모델은 시간 불변(time-invariant)에서 시간 가변(time-varying)으로 바뀝니다.
● 세부 설계
- 𝑠𝐵(𝑥) = Linear𝑁 (𝑥)
- 𝑠𝐶(𝑥) = Linear𝑁 (𝑥)
- 𝑠Δ(𝑥) = Broadcast𝐷(Linear1(𝑥))
- 𝜏Δ = softplus
여기서 Linear𝑑는 차원 𝑑로의 매개변수화된 투영을 의미합니다.
𝑠Δ와 𝜏Δ의 선택은 RNN의 게이팅 메커니즘과 연결되어 있습니다(3.5절 참고).
입력 의존적 함수로 만들면 입력이 시간에 따라 달라지므로 시간 가변으로 변한다.
이렇게 되면 문장이 긴 경우에는 더 많은 정보를 유지하기 위해 상태를 확장한다.
문장이 짧고 중요하지 않은 경우에는 상태를 합축해 효율적으로 처리할 수 있다.
즉, 가중치의 동적화를 시간 매개변수(혹은 입력)으로 한다고 생각하면 편하다.
그러나 입력 의존적 함수이기 때문에,
- 병렬화가 어렵다:
합성곱이나 고정된 연산처럼 모든 시간 스텝에서 동일한 연산을 할 수 없기 때문에, 병렬 처리가 비효율적입니다. - 메모리와 계산 비용 증가:
각 시간 스텝마다 동적 파라미터를 계산해야 하므로, 메모리 사용량과 계산 비용이 증가합니다.
이걸 타파하기 위해서 앞에 있떤 GPU 어쩌구를 해놓은 것이다.
이 부분은 Sram에 복사해서 업데이트하는게 아니라 재귀적인 방법으로 자체 업데이트해서 해결한다고 한다.
이제 마지막 아키텍처 부분으로 넘어가보자.
그림 3: (아키텍처)
우리의 단순화된 블록 디자인은 대부분의 SSM 아키텍처의 기반이 되는 H3 블록과 현대 신경망에서 널리 사용되는 MLP 블록을 결합한 것입니다.
이 두 블록을 교차(interleave) 배치하는 대신, Mamba 블록을 균일하게 반복합니다.
- H3 블록과 비교했을 때, Mamba는 첫 번째 곱셈 게이트를 활성화 함수로 대체합니다.
- MLP 블록과 비교했을 때, Mamba는 **SSM(상태공간 모델)**을 주 브랜치에 추가합니다.
활성화 함수(𝜎)로는 SiLU / Swish를 사용합니다(Hendrycks and Gimpel 2016; Ramachandran, Zoph, and Quoc V Le 2017).
한마디로 H3 블록과 MLP 합쳐서 하나의 Mamba 블록 만들고, 그걸 반복해서 쌓는 구조로 만들었다는 뜻이다.
이는 Resnet에서 같은 블록을 쌓는 것처럼 이해하면 된다. 구현할 때도 이런 식으로 Mamba 블록을 만들고, for을 통해서 몇번이고 원한는 대로 쌓게 될 것이다.
파라미터는 다음과 같다.
이 아키텍처는 모델 차원 D을 제어 가능한 확장 계수 E로 확장합니다.
각 블록의 대부분의 파라미터는 선형 투영에 집중되어 있습니다:
- 입력 투영: 2ED^2
- 출력 투영: ED^2
이 블록을 반복하면서 표준 정규화 및 잔차 연결(residual connections)과 결합하여 Mamba 아키텍처를 형성합니다.
실험에서는 항상 E=2로 설정하고, 두 개의 블록 스택을 사용하여 Transformer의 MHA(Multi-Head Attention) 및 MLP 블록의 총 파라미터 수 12D^2를 맞춥니다.
- SwiGLU 변형이 되도록 설계
- 선택적으로 LayerNorm(Ba, Kiros, and Hinton 2016)을 사용
여기까지 봤으니, 이제 이 논문에 대한 핵심적인 이해는 다 됐다.
그럼 이제 QnA로 넘어가보자.
QnA
1. 저자가 뭘 해내고 싶어했는가?
주요 목표:
- Transformer의 비효율성 해결:
Transformer는 긴 시퀀스 처리에 있어 이차적 시간 복잡도와 제한된 컨텍스트 창을 가지며, 이로 인해 계산 및 메모리 측면에서 비효율적입니다. - 선택적 정보 처리 능력 향상:
기존 SSM은 입력에 따라 선택적으로 데이터를 처리하지 못하고, 모든 데이터를 동일하게 처리하는 LTI(Linear Time Invariance) 모델로 제한됩니다.
이 연구에서는 **선택 메커니즘(selection mechanism)**을 도입하여 입력에 따라 중요한 정보를 선택하고 불필요한 정보를 필터링할 수 있는 능력을 강화했습니다. - 효율적인 하드웨어 구현:
입력 의존적 특성으로 인해 기존 SSM이 사용하던 효율적인 합성곱 연산을 사용할 수 없게 되지만, 저자들은 하드웨어 친화적 알고리즘을 설계하여 이 문제를 해결했습니다.
2. 이 연구의 접근에서 중요한 요소는 무엇인가?
이 연구의 중요한 요소는 다음과 같습니다:
(1) 선택 메커니즘 (Selection Mechanism)
- 입력에 따라 모델의 파라미터를 동적으로 조정하여 불필요한 정보를 필터링하고 필요한 정보를 유지합니다(2312.00752v2 (2)).
- 예시: Selective Copying Task와 Induction Heads Task에서 모델이 중요한 입력을 선택적으로 기억하거나 무시하는 능력을 보여줍니다.
(2) 하드웨어 친화적 알고리즘 (Hardware-aware Algorithm)
- 입력 의존적 선택 메커니즘은 **시간 불변 연산(LTI)**을 사용하지 않기 때문에 **재귀적 연산(recurrent computation)**을 활용하여 GPU 메모리 계층을 최적화했습니다(2312.00752v2 (2)).
- 이로 인해 연속된 상태를 효율적으로 계산하고, 입력 시퀀스 길이에 선형으로 확장할 수 있습니다.
(3) Mamba 아키텍처
- 기존 SSM 아키텍처와 Transformer의 MLP 블록을 결합하여 단순화된 Mamba 블록을 설계했습니다(2312.00752v2 (2)).
- 선택적 상태공간 모델(S6)을 통합하여 빠른 추론 속도와 낮은 메모리 요구사항을 달성했습니다.
3. 당신(논문독자)는 스스로 이 논문을 이용할 수 있는가?
(1) 긴 시퀀스 처리 문제 해결
- 문제: 기존의 Transformer 모델은 긴 시퀀스를 처리할 때 메모리 사용량과 연산 비용이 급격히 증가합니다.
- 이용 방안:
- Mamba 아키텍처와 선택적 상태공간 모델(Selective SSM*을 적용하면 긴 시퀀스에서도 효율적인 메모리 사용과 빠른 계산이 가능합니다.
- 특히, 자연어 처리(NLP), 유전체 데이터 분석, 비디오 데이터 처리와 같이 대규모 시퀀스 데이터를 다루는 분야에 적합합니다.
(2) 실시간 시스템에서의 활용
- 문제: 실시간 시스템(예: 온라인 번역, 음성 인식, 자율주행)은 짧은 응답 시간과 빠른 추론이 요구됩니다.
- 이용 방안:
- 선택적 SSM은 상수 시간 추론을 가능하게 하므로, 실시간 시스템에서 지연 시간(latency)을 줄이는 데 사용할 수 있습니다.
- 예를 들어, 실시간 음성 비서나 IoT 디바이스에 통합하여 빠른 응답성을 보장할 수 있습니다.
(3) 메모리 제한이 있는 환경에서의 적용
- 문제: 모바일 기기나 임베디드 시스템과 같은 환경은 제한된 메모리와 연산 능력을 가지고 있습니다.
- 이용 방안:
- SSM 기반 모델은 기존 Transformer보다 메모리 사용량이 낮고 계산이 효율적이기 때문에, 모바일 디바이스나 엣지 컴퓨팅 환경에서 사용하기에 적합합니다.
- 예시: 모바일 NLP 애플리케이션, 임베디드 음성 인식 시스템.
(4) 데이터 필터링 및 선택적 정보 처리
- 문제: 많은 데이터 세트는 불필요한 노이즈나 관련 없는 정보가 포함되어 있어 모델 성능을 저하시킬 수 있습니다.
- 이용 방안:
- 선택 메커니즘(selection mechanism)을 통해 중요한 정보만을 선택하고 노이즈를 필터링함으로써, 데이터 전처리 및 정보 추출을 효율적으로 수행할 수 있습니다.
- 예시:
- 소셜 미디어 텍스트에서 중요한 이벤트나 키워드만을 추출.
- 의료 데이터에서 특정 패턴이나 이상치만 선택적으로 분석.
(5) 강화 학습에서의 상태 재설정
- 문제: 강화 학습(Reinforcement Learning)에서 에피소드 경계를 넘어갈 때 이전 상태가 영향을 미쳐 학습이 어려워질 수 있습니다.
- 이용 방안:
- 선택적 SSM은 에피소드 경계에서 상태를 재설정할 수 있어, 독립적인 에피소드 학습이 가능합니다.
- 이는 에이전트의 학습 안정성과 수렴 속도를 개선할 수 있습니다.
(6) 다양한 딥러닝 아키텍처와의 통합
- 문제: 기존 모델(예: RNN, CNN, Transformer)은 특정 구조적 한계로 인해 일부 문제에서 최적의 성능을 발휘하지 못합니다.
- 이용 방안:
- 선택적 SSM은 다양한 아키텍처에 통합될 수 있으며, RNN의 게이팅 메커니즘이나 CNN의 지역적 필터링과 같은 기존 구조를 확장하여 더 강력한 시퀀스 모델을 설계할 수 있습니다.
- 예시:
- 하이브리드 모델을 설계하여 RNN과 SSM을 결합하거나, CNN에서 전역 정보를 선택적으로 활용하는 모델을 구축.
4. 당신이 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?
(1) 구조화된 상태공간 모델(Structured SSM)
- Gu et al. (2022): SSM의 초기 구조와 선형 회귀(linear recurrence) 및 **글로벌 합성곱(global convolution)**에 대해 설명합니다.
(2) 선택 메커니즘의 기초
- Funahashi and Nakamura (1993): RNN의 게이팅 메커니즘과 연속 시간 시스템의 이산화에 대한 연구입니다.
(3) 하드웨어 최적화
- Dao et al. (2023): GPU 메모리 계층을 활용한 효율적인 연산 알고리즘 설계에 대한 내용을 다룹니다.
이 논문의 핵심은 그래서 어디에 쓰는데? 라는 질문에 정말 많이 답할 수 있는 데 있다.
기존의 Attention 메커니즘이 문제가 되는 곳은 거의 다 쓸 수 있는 것이다.
그렇기에 이 논문이 장안의 화제가 되는 것이고, 이걸 연구하려는 사람들이 많은 것이다.
이걸 이해하고 넘어가면 좋다.
구현하기
전체 코드는 여기서 볼 수 있다.
맘바 블록만 만들 수 있으면 ResNet 과 다를 바가 없다.
천천히 보자.
0, pytest로 흐름 추적하기
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/softplus.py
/kaggle/working/mamba/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mlp.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py
/kaggle/working/mamba/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__pycache__/config_mamba.cpython-310.py
/opt/conda/lib/python3.10/lib-dynload/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/layernorm_gated.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mha.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/block.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/k_activations.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/layer_norm.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__pycache__/hf.cpython-310.py
/kaggle/working/mamba/tests/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/tensor_parallel.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__pycache__/tensor_parallel.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_combined.cpython-310.py
/kaggle/lib/mamba.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_chunk_state.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_scan.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/hf.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/__pycache__/selective_scan_interface.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__pycache__/distributed_utils.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__init__.py
/opt/conda/lib/python3.10/lib-dynload/mamba.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/__pycache__/__init__.cpython-310.py
/kaggle/working/mamba/tests/test_generation.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_bmm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mamba_simple.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/config_mamba.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/k_activations.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mha.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__pycache__/mixer_seq_simple.cpython-310.py
/root/.local/lib/python3.10/site-packages/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_state_passing.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mlp.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mamba2.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_chunk_scan.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm_gated.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__pycache__/generation.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/softplus.cpython-310.py
/opt/conda/lib/python3.10/mamba.py
/kaggle/lib/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/distributed_utils.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/generation.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/block.py
/opt/conda/lib/python3.10/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_bmm.cpython-310.py
/kaggle/working/mamba/tests/__pycache__/test_generation.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/__init__.cpython-310.py
/kaggle/working/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/selective_state_update.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__pycache__/__init__.cpython-310.py
/kaggle/working/mamba/tests/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_state_passing.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/selective_state_update.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/softplus.py
/kaggle/working/mamba/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mlp.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py
/kaggle/working/mamba/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__pycache__/config_mamba.cpython-310.py
/opt/conda/lib/python3.10/lib-dynload/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/layernorm_gated.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mha.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/block.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/k_activations.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/layer_norm.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__pycache__/hf.cpython-310.py
/kaggle/working/mamba/tests/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/tensor_parallel.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__pycache__/tensor_parallel.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_combined.cpython-310.py
/kaggle/lib/mamba.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_chunk_state.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_scan.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/hf.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/__pycache__/selective_scan_interface.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__pycache__/distributed_utils.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__init__.py
/opt/conda/lib/python3.10/lib-dynload/mamba.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/__pycache__/__init__.cpython-310.py
/kaggle/working/mamba/tests/test_generation.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_bmm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mamba_simple.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/config_mamba.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/k_activations.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mha.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__pycache__/mixer_seq_simple.cpython-310.py
/root/.local/lib/python3.10/site-packages/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_state_passing.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mlp.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mamba2.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_chunk_scan.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__init__.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm_gated.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__pycache__/generation.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/softplus.cpython-310.py
/opt/conda/lib/python3.10/mamba.py
/kaggle/lib/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/distributed_utils.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/generation.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/block.py
/opt/conda/lib/python3.10/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/ssd_bmm.cpython-310.py
/kaggle/working/mamba/tests/__pycache__/test_generation.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/__init__.cpython-310.py
/kaggle/working/mamba_ssm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/__pycache__/selective_state_update.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/distributed/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/utils/__pycache__/__init__.cpython-310.py
/kaggle/working/mamba/tests/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_state_passing.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/selective_state_update.py
대략적으로 test_generation을 돌렸을 때 파일이 로드된 순서의 흐름은 다음과 같다.
config를 불러오는 건 test_generation에 박혀 있으므로 제외하면 순서가 이렇다.
/kaggle/working/mamba/tests/test_generation.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/__init__.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_bmm.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mamba_simple.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/config_mamba.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/k_activations.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/__pycache__/mha.cpython-310.py
/opt/conda/lib/python3.10/site-packages/mamba_ssm/models/__pycache__/mixer_seq_simple.cpython-310.py
layer_norm을 통과하고, ssd_bmm을 통과한다.
selective_scan_interface를 통과하여 GPU 사용 여부를 체크하고, mixer_seq_simple.py에서 모델을 불러온다.
mixer_seq_simple.py에서는 mamba_simple이 필요하고, config_mamba를 부르면서, triton의 k_activations를 부른다.
그리고 mha.py도 부르면서 다시 mixer_seq_simple.py를 부르게 된다.
정말 대략적인 순서기에 따라가면서 한번 정리해보자.
1. MambaLMHeadModel(config, device=device,dtype=dtype)
def test_generation():
batch = 3
seqlen = 20
device = "cuda"
dtype = torch.float16
config = MambaConfig(
d_model=1024,
n_layer=4,
vocab_size=50277,
ssm_cfg=dict(layer="Mamba2"),
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=16,
)
torch.manual_seed(2357)
model = MambaLMHeadModel(config, device=device, dtype=dtype)
x = torch.randint(0, 1000, (batch, seqlen), device=device, dtype=torch.long)
out_ref = model(x).logits
prompt_len = seqlen // 2
out = model.generate(
input_ids = x[:, :prompt_len], max_length=seqlen, output_scores=True, return_dict_in_generate=True,
cg=True, # Can turn off CUDA graph for easier debugging
# instead of sampling, we take output tokens from x, to get logits for testing
# For actual generation, don't pass in teacher_outputs
teacher_outputs=x,
)
out_scores = torch.stack(out.scores, dim=1)
print(f"Max diff: {(out_scores - out_ref[:, prompt_len - 1: -1]).abs().max()}")
assert torch.allclose(out_scores, out_ref[:, prompt_len - 1: -1], rtol=1e-3, atol=1e-2)
테스트 코드를 보면 가장 처음에 MambaConfig를 불러온다.
이는 config_mamba.py에 있는데, 그냥 @dataclass로 정의해준 거라서 pass하도록 하겠다.
그 다음에는 MambaLMHeadModel로 모델을 불러온다.
이 객체는 mixer_seq_simple에 존재한다.
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
config: MambaConfig,
initializer_cfg=None,
device=None,
dtype=None,
) -> None:
self.config = config
d_model = config.d_model
n_layer = config.n_layer
d_intermediate = config.d_intermediate
vocab_size = config.vocab_size
ssm_cfg = config.ssm_cfg
attn_layer_idx = config.attn_layer_idx
attn_cfg = config.attn_cfg
rms_norm = config.rms_norm
residual_in_fp32 = config.residual_in_fp32
fused_add_norm = config.fused_add_norm
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
d_intermediate=d_intermediate,
vocab_size=vocab_size,
ssm_cfg=ssm_cfg,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
rms_norm=rms_norm,
initializer_cfg=initializer_cfg,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
if self.config.tie_embeddings:
self.lm_head.weight = self.backbone.embedding.weight
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config_data = load_config_hf(pretrained_model_name)
config = MambaConfig(**config_data)
model = cls(config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
return model
def save_pretrained(self, save_directory):
"""
Minimal implementation of save_pretrained for MambaLMHeadModel.
Save the model and its configuration file to a directory.
"""
# Ensure save_directory exists
os.makedirs(save_directory, exist_ok=True)
# Save the model's state_dict
model_path = os.path.join(save_directory, 'pytorch_model.bin')
torch.save(self.state_dict(), model_path)
# Save the configuration of the model
config_path = os.path.join(save_directory, 'config.json')
with open(config_path, 'w') as f:
json.dump(self.config.__dict__, f, indent=4)
코드가 길지만 따로 따로 보면 이렇다.
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
config: MambaConfig,
initializer_cfg=None,
device=None,
dtype=None,
) -> None:
self.config = config
d_model = config.d_model
n_layer = config.n_layer
d_intermediate = config.d_intermediate
vocab_size = config.vocab_size
ssm_cfg = config.ssm_cfg
attn_layer_idx = config.attn_layer_idx
attn_cfg = config.attn_cfg
rms_norm = config.rms_norm
residual_in_fp32 = config.residual_in_fp32
fused_add_norm = config.fused_add_norm
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
설정에서 주요 코드 변수를 가져와서 클래스 안의 매개변수로 바꿔준다.
self.backbone = MixerModel(...)
백본은 MixerModel이 된다.
사실 이 MixerModel이 핵심이 되기에 잠깐 짚고 넘어가보자.
2. MixerModel
class MixerModel(nn.Module):
def __init__(
self,
d_model: int,
n_layer: int,
d_intermediate: int,
vocab_size: int,
ssm_cfg=None,
attn_layer_idx=None,
attn_cfg=None,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
initializer_cfg=None,
fused_add_norm=False,
residual_in_fp32=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
# We change the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Add, we do:
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
# the main branch (output of MLP / Mixer). The model definition is unchanged.
# This is for performance reason: we can fuse add + layer_norm.
self.fused_add_norm = fused_add_norm
if self.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
self.layers = nn.ModuleList(
[
create_block(
d_model,
d_intermediate=d_intermediate,
ssm_cfg=ssm_cfg,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
norm_epsilon=norm_epsilon,
rms_norm=rms_norm,
residual_in_fp32=residual_in_fp32,
fused_add_norm=fused_add_norm,
layer_idx=i,
**factory_kwargs,
)
for i in range(n_layer)
]
)
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
d_model, eps=norm_epsilon, **factory_kwargs
)
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
)
)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return {
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
for i, layer in enumerate(self.layers)
}
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params, **mixer_kwargs
)
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = layer_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm_f, RMSNorm)
)
return hidden_states
입력 매개변수는 이렇다.
입력 매개변수:
- d_model: 모델의 은닉 차원 크기.
- n_layer: 계층 수.
- d_intermediate: MLP 층의 중간 노드 차원 크기.
- vocab_size: 어휘 크기.
- ssm_cfg, attn_layer_idx, attn_cfg: 특수 모듈 설정(예: 주의 메커니즘 관련).
- norm_epsilon: LayerNorm에 사용되는 작은 상수(정규화 안정성).
- rms_norm: RMSNorm 사용 여부.
- initializer_cfg: 가중치 초기화 설정.
- fused_add_norm: Add와 LayerNorm을 결합하여 계산 최적화 여부.
- residual_in_fp32: 잔차(residual) 계산을 FP32로 강제할지 여부.
- device, dtype: 장치 및 데이터 유형.
그 다음에는 입력 토큰을 은닉 벡터로 매핑한다.
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
그리고 Residual의 순서를 변경한다.
# Add(잔차 연결) -> LN -> Attn / MLP 순서로 변경
이유: 성능 최적화. Add와 LayerNorm을 결합하여 계산 속도를 높임.
self.fused_add_norm = fused_add_norm
if self.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
이때 Triton 라이브러리의 커스텀 LayerNorm/RMSNorm을 사용한다.
self.layers = nn.ModuleList(
[
create_block(
d_model,
d_intermediate=d_intermediate,
ssm_cfg=ssm_cfg,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
norm_epsilon=norm_epsilon,
rms_norm=rms_norm,
residual_in_fp32=residual_in_fp32,
fused_add_norm=fused_add_norm,
layer_idx=i,
**factory_kwargs,
)
for i in range(n_layer)
]
)
본격적인 Layer Block을 생성한다.
create Block 함수로 모델을 생성하는데, n_layer만큼 생성한다.
이 Create Block 함수는 Block 클래스로 생성하므로, 여기서도 잠깐 보고 가보자.
3. Block 클래스
class Block(nn.Module):
def __init__(
self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
"""
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA/MLP -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Add -> LN -> Mixer, returning both
the hidden_states (output of the mixer) and the residual.
This is purely for performance reasons, as we can fuse add and LayerNorm.
The residual needs to be provided (except for the very first block).
"""
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.norm = norm_cls(dim)
self.mixer = mixer_cls(dim)
if mlp_cls is not nn.Identity:
self.norm2 = norm_cls(dim)
self.mlp = mlp_cls(dim)
else:
self.mlp = None
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
뼈대가 되는 블록 클래스이다.
전통적인 Transformer 블록은 일반적으로 Layer Normalization, -> Multihead Attention/MLP > Add(잔차 연결)의 순서로 구성된다.
그러나 Mamba는 Add(잔차 연결) → LN(정규화) → Mixer(핵심 연산) 순으로 구성된다.
이 구조는 성능 최적화를 위한 변경이라고 한다.
def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Mixer(LN(residual))
"""
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
is_rms_norm=isinstance(self.norm, RMSNorm)
)
처음에는 fused_add_norm이 처리된다.
fused_add_norm은 LayerNorm과 잔차 연결을 어떻게 처리할지 결정하는 설정이다.
이 설정에 따라 LayerNorm과 잔차 연결을 하나의 연산으로 결합할지 말지가 결정된다.
Hidden_state에 residual을 더하고 LayerNorm을 적용하는게 아니라, LayerNorm을 먼저 적용한 후 잔차 연결을 수행하게 된다.
if self.mlp is not None: # MLP가 정의된 경우에만 실행
if not self.fused_add_norm: # fused_add_norm이 False일 경우
residual = hidden_states + residual # 잔차 연결 수행
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) # Residual에 대해 LayerNorm 적용
if self.residual_in_fp32: # residual이 FP32로 처리되어야 할 경우
residual = residual.to(torch.float32) # residual을 FP32로 변환
else: # fused_add_norm이 True일 경우
hidden_states, residual = layer_norm_fn( # LayerNorm과 residual을 결합하는 최적화된 함수 사용
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm2.eps,
is_rms_norm=isinstance(self.norm2, RMSNorm) # RMSNorm일 경우 별도 처리
)
hidden_states = self.mlp(hidden_states) # MLP를 hidden_states에 적용
return hidden_states, residual # 최종적으로 hidden_states와 residual 반환
그 다음은 MLP(Multi-Layer Perceptron)이 있을 때 그걸 처리하는 블록이다.
self.mlp가 None이 아니고 정의되어 있다면 해당 블록에서 MLP를 적용한다.
fused_add_norm이 False라면, 먼저 잔차 연결을 수행하고, hidden_states와 residual을 더한 값을 LayerNorm에 통과시킨다.
그리고 self.residual_in_fp32가 True라면 residual을 FP32로 변환하게 된다.
반대로 fused_add_norm이 True라면, LayerNorm과 잔차 연결을 동시에 수행하는 함수인 layer_norm_fn을 수행한다.
이후 나온 hidden_states에 mlp를 적용하고, 반환값으로 hidden_states와 residual을 반환한다.
hidden_states는 MLP를 적용한 후의 출력이고, resudual은 self.mlp 처리 전에 업데이트된 잔차가 된다.
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
이 부분은 추론시의 캐시 메로리 할당 과정이다.
이를 통해 추론 속도를 높일 수 있는 최적화가 된다.
4. Gated MLP
사실 block을 불러오기 전에 이게 먼저 실행된다.
if d_intermediate == 0:
mlp_cls = nn.Identity
else:
mlp_cls = partial(
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
)
d_intermediate 변수가 block에 들어가는 mlp_cls도 같이 정의해주는 것이다.
partial 함수는 기존 함수를 부분적으로 호출하여 새로운 함수를 만드는 기능을 한다.
즉, 원래의 함수에 매개변수를 미리 고정시키고, 나머지 매개변수만 나중에 전달할 수 있는 함수이다.
GatedMLP 클래스의 인스턴스를 생성하는데, hidden_features와 out_features를 미리 정의하고 mlp_cls로 객체를 생성한다.
이 mlp_cls는 d_model과 d_intermediate,factory_kwargs를 매개변수로 받아 GatedMLP를 호출할 수 있는 새로운 함수가 된다.
GatedMLP 클래스를 보면, 구체적으로는 다음과 같이 풀이된다.
nn.Module을 상속받아서 PyTorch의 신경말 모델로 정의된다.
class GatedMLP(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
activation=F.silu,
bias=False,
multiple_of=128,
device=None,
dtype=None,
):
- in_features: 입력 특성의 차원 수 (입력 벡터의 크기).
- hidden_features: 은닉층의 크기. 지정하지 않으면 기본값으로 in_features의 8/3배로 설정됩니다.
- out_features: 출력 특성의 차원 수 (출력 벡터의 크기). 지정하지 않으면 in_features와 동일하게 설정됩니다.
- activation: 활성화 함수로 기본값은 F.silu (Sigmoid Linear Unit).
- bias: 선형 변환에서 편향을 사용할지 여부 (기본값은 False).
- multiple_of: hidden_features의 크기가 이 수로 나누어 떨어지도록 조정합니다.
- device, dtype: 모델을 실행할 장치와 데이터 타입을 설정합니다. 이는 모델을 GPU에서 실행할 때 유용합니다.
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features if out_features is not None else in_features
hidden_features = (
hidden_features if hidden_features is not None else int(8 * in_features / 3)
)
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
- factory_kwargs는 device와 dtype 설정을 인수로 받기 위한 딕셔너리입니다.
- out_features가 지정되지 않으면 in_features와 동일하게 설정하고, hidden_features는 기본값을 8 * in_features / 3으로 설정합니다.
- hidden_features는 multiple_of의 배수로 설정됩니다. 이렇게 함으로써 특정 값으로 정렬된 은닉층 크기를 유지합니다.
선형 변환 계층
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
- self.fc1: 입력 차원에서 2 * hidden_features 차원으로 변환하는 선형 계층입니다.
- self.activation: 활성화 함수로, 기본값은 SILU (Sigmoid Linear Unit).
- self.fc2: hidden_features 차원에서 out_features 차원으로 변환하는 선형 계층입니다.
Forward 메서드
def forward(self, x):
y = self.fc1(x)
y, gate = y.chunk(2, dim=-1)
y = y * self.activation(gate)
y = self.fc2(y)
return y
- y = self.fc1(x): 입력 x를 첫 번째 선형 계층(fc1)을 통해 처리합니다. 이때 출력 차원은 2 * hidden_features입니다.
- y, gate = y.chunk(2, dim=-1): 출력 y를 두 부분으로 나눕니다. y와 gate는 2 * hidden_features에서 각각 hidden_features씩 나눠지며, dim=-1로 마지막 차원을 기준으로 분리됩니다.
- y = y * self.activation(gate): gate를 활성화 함수에 통과시켜 y와 곱합니다. 이 방식은 Gate Mechanism을 통해 입력을 조절하는 방식입니다.
- y = self.fc2(y): 최종적으로, y를 두 번째 선형 계층(fc2)을 통과시켜 out_features 차원으로 변환합니다.
- return y: 최종 출력 y를 반환합니다.
게이트 메커니즘이 들어가는데, 출력 y를 두 부분으로 나누고 하나는 게이트로, 다른 하나는 출력으로 사용한다.
chunk(2, dim=-1)은 출력 텐서를 두 개의 같은 크기 부분으로 나누는 연산으로, 2*hidden_features를 나누면 hidden_features가 각 부분의 크기가 된다.
y는 나중에 출력으로 사용될 값이며, gate는 입력값을 조절하는 역할이 된다. 이 값이 y*self.activation(gate)로 들어가서 게이트 메커니즘이 된다.
예시 계산
hidden_features = (
hidden_features if hidden_features is not None else int(8 * in_features / 3)
)
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
- hidden_features가 None인 경우: hidden_features가 주어지지 않으면, hidden_features는 8 * in_features / 3으로 설정됩니다. 이는 입력 차원 수에 비례하여 숨겨진 차원을 결정하는 방식입니다.
- multiple_of로 맞추기: hidden_features는 multiple_of (기본값: 128)의 배수로 맞춰져야 합니다. 그래서 (hidden_features + multiple_of - 1) // multiple_of * multiple_of 계산을 통해, hidden_features가 multiple_of의 배수가 되도록 올림 처리를 합니다.
in_features = 300, hidden_features = None, multiple_of = 128
- 첫 번째 단계: hidden_features가 None이므로 hidden_features = int(8 * in_features / 3)로 계산됩니다.
hidden_features=38×300=800
2. 두 번째 단계: hidden_features = 800이지만, 이것이 multiple_of = 128의 배수가 아니므로, multiple_of로 맞추기 위해 계산이 필요합니다.
hidden_features=(128800+128−1)×128=(128927)×128=928
그래서, hidden_features = 928이 됩니다. 즉, hidden_features는 128의 배수로 올림 처리되었습니다.
MLP 동작
hidden_features 값이 결정되면, 이 값은 fc1 (첫 번째 선형 계층)의 차원 수와 연결됩니다.
예시 1의 경우 (in_features = 300, hidden_features = 928):
- **fc1**의 입력 차원은 in_features = 300, 출력 차원은 2 * hidden_features = 2 * 928 = 1856입니다.
- **fc1**을 통해 입력 x (크기 300)의 출력을 얻은 후, 이 출력은 두 부분으로 나누어집니다: 하나는 y, 하나는 gate입니다. (각각 크기 928씩)
- **y와 gate**는 y, gate = y.chunk(2, dim=-1)에서 나누어지고, y는 gate와 활성화 함수를 거쳐 곱해집니다.
y=y×activation(gate)
이때 gate가 0에 가까우면 y 값은 거의 0으로 억제되고, gate가 1에 가까우면 y 값은 그대로 유지됩니다.
알아둬야 할 게, activation(gate) == 무조건 0~1 사이의 값을 출력한다.
즉, y값이 점진적으로 억제되거나 활성화 될 수 있는 기반을 마련한다.
그렇기에 이게 게이트 메커니즘이 되는 것이다.
5. 다시 믹서 모델로
이제 거의 다 왔다.
이렇게 정의해주고 나면, 믹서 모델에서 forward를 통해 은닉층을 계산하여 반환한다.
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params, **mixer_kwargs
)
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = layer_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm_f, RMSNorm)
)
return hidden_states
layer_norm_fn으로 합쳐주거나 residual을 해주는 건 덤이다.
residual이 None으로 시작하지만, layer의 for 루프 내에서 갱신된다.
fused_add_norm이 False일 경우, residual이 None이 아니면 hidden_states와 residual이 더해져서 갱신된다.
그렇게 은닉층을 반환하면, 다시 가장 처음의 MambaLMHeadModel로 가게 된다.
6. SSM 레이어는 어디에 있나?
ssm 레이어는 create_block 함수에 있다.
이는 ssm_cfg 설정에 따라 바뀐다.
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
def create_block(
d_model,
d_intermediate,
ssm_cfg=None,
attn_layer_idx=None,
attn_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
device=None,
dtype=None,
):
if ssm_cfg is None:
ssm_cfg = {}
if attn_layer_idx is None:
attn_layer_idx = []
if attn_cfg is None:
attn_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
if layer_idx not in attn_layer_idx:
# Create a copy of the config to modify
ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
ssm_layer = ssm_cfg.pop("layer", "Mamba1")
if ssm_layer not in ["Mamba1", "Mamba2"]:
raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
mixer_cls = partial(
Mamba2 if ssm_layer == "Mamba2" else Mamba,
layer_idx=layer_idx,
**ssm_cfg,
**factory_kwargs
)
else:
mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
if d_intermediate == 0:
mlp_cls = nn.Identity
else:
mlp_cls = partial(
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
)
block = Block(
d_model,
mixer_cls,
mlp_cls,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
block.layer_idx = layer_idx
return block
ssm_cfg**는 SSM (Structured State Matrix) 레이어의 설정을 담고 있는 딕셔너리다.
만약 ssm_cfg가 None이라면, 빈 딕셔너리로 초기화된다.
if layer_idx not in attn_layer_idx:
ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
ssm_layer = ssm_cfg.pop("layer", "Mamba1")
if ssm_layer not in ["Mamba1", "Mamba2"]:
raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
mixer_cls = partial(
Mamba2 if ssm_layer == "Mamba2" else Mamba,
layer_idx=layer_idx,
**ssm_cfg,
**factory_kwargs
)
**attn_layer_idx**는 어텐션 레이어의 인덱스를 담고 있으며, 이 인덱스에 포함되지 않는 경우 ssm_cfg에서 **layer**라는 키를 찾아 해당 레이어의 이름(Mamba1 또는 Mamba2)을 결정합니다.
- Mamba1 또는 Mamba2는 SSM 레이어 종류를 나타냅니다.
- 이건 Mamba에서 파생된 Mamba2 모델이 있는데, 그 코드 때문에 들어간 것이다. 나중에 또 따로 볼 것이다.
- ssm_cfg에서 layer를 꺼내서 Mamba1 또는 Mamba2 중 하나를 선택하고, 그에 맞는 클래스를 mixer_cls로 설정합니다.
- partial(Mamba2 if ssm_layer == "Mamba2" else Mamba, ...)에서 Mamba 또는 Mamba2 클래스를 동적으로 선택합니다.
else:
mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
만약 layer_idx가 attn_layer_idx에 포함된다면, MHA (Multi-Head Attention) 레이어가 선택됩니다.
- 이 경우, **mixer_cls**는 MHA 레이어를 사용하는 partial 함수로 설정됩니다.
- attn_cfg에 정의된 설정을 사용하여 MHA 레이어를 구성합니다.
결론적으로, mixer_cls가 Mamba,Mamba2 혹은 MHA로 들어갈 수 있게 되어 있다.
그렇다면, 대망의 핵심인 마지막 Mamba_simple의 코드를 보자.
7. 대망의 Mamba_simple
# Copyright (c) 2023, Tri Dao, Albert Gu.
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
class Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
**factory_kwargs,
)
self.activation = "silu"
self.act = nn.SiLU()
self.x_proj = nn.Linear(
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
# We do matmul and transpose BLH -> HBL at the same time
xz = rearrange(
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
if self.in_proj.bias is not None:
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# In the backward pass we write dx and dz next to each other to avoid torch.cat
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
out = mamba_inner_fn(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else:
x, z = xz.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Don't add dt_bias here
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# SSM step
if selective_state_update is None:
# Discretize A and B
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
y = y + self.D.to(dtype) * x
y = y * self.act(z) # (B D)
else:
y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
# ssm_dtype = torch.float32
ssm_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
)
return conv_state, ssm_state
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
batch_shape = (batch_size,)
conv_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_state,
device=self.dt_proj.weight.device,
dtype=self.dt_proj.weight.dtype,
# dtype=torch.float32,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
mamba 클래스 초기화
이 코드는 Mamba라는 신경망 모듈을 정의한 코드입니다. Mamba는 SSM (Structured State Matrix) 및 **가법적인 상태 변환 (state update)**을 위한 특별한 연산을 포함한 신경망 블록입니다. 이 모듈은 주로 시퀀스 모델링에 사용되며, 다양한 초기화 방식과 빠른 경로를 통해 효율적인 계산을 지원합니다. 아래에서 코드의 주요 부분을 하나씩 설명하겠습니다.
1. 필요한 라이브러리 임포트
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
- torch 및 torch.nn: PyTorch에서 제공하는 텐서 연산 및 신경망 모듈입니다.
- einops: 텐서의 형태를 변경하거나 반복할 수 있는 라이브러리입니다.
- 그 외 selective_scan_fn, mamba_inner_fn 등은 SSM 및 트라이튼 (Triton) 최적화된 연산을 위한 함수들입니다.
2. Mamba 클래스의 초기화
class Mamba(nn.Module):
def __init__(self, ...):
super().__init__()
# Hyperparameters
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
- d_model: 모델의 차원
- d_state: 상태 벡터의 차원
- d_conv: 합성곱 필터 크기
- expand: 모델 차원을 확장하는 비율
- dt_rank: 시간 차원(rank) 초기화 방법
- use_fast_path: 빠른 경로 사용 여부
3. in_proj, conv1d, activation, x_proj, dt_proj 등
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, ...)
self.activation = "silu"
self.act = nn.SiLU()
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
- in_proj: 입력을 처리하는 선형 변환.
- conv1d: 1D 합성곱 레이어, 상태 변환을 위한 핵심 요소입니다.
- activation: 활성화 함수로 SiLU (Sigmoid Linear Unit)을 사용합니다.
- x_proj: 상태 벡터의 변환.
- dt_proj: 시간에 따른 변환을 위한 선형 레이어.
4. 초기화 단계
# Initialize special dt projection to preserve variance at initialization
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
- **dt_proj**는 모델의 시간 변화에 따른 파라미터를 조정하는 역할을 합니다. 초기화 시 constant 또는 random 방식으로 값을 설정합니다.
5. forward 함수
이건 더 자세히 설명한다.
이 함수는 hidden_states를 입력받고, 이를 처리하여 출력값을 계산합니다.
forward 함수의 구조
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (B, L, D) # Batch, Sequence Length, Model Dimension
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
1. 입력 매개변수
- hidden_states: (B, L, D) 형태로 주어지며, B는 배치 크기(batch size), L은 시퀀스 길이(sequence length), D는 모델의 차원(model dimension)을 의미합니다.
- inference_params: 추론 중에 필요한 파라미터를 담고 있습니다. 이를 통해 상태 업데이트 및 캐시를 관리합니다.
2. 배치 크기, 시퀀스 길이, 차원 추출
batch, seqlen, dim = hidden_states.shape
- 입력 데이터의 차원에서 배치 크기(batch), 시퀀스 길이(seqlen), 모델 차원(dim)을 추출합니다.
3. inference_params 존재 시 캐시에서 상태 가져오기
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
- inference_params가 제공되면, 캐시에서 conv_state와 ssm_state를 불러옵니다.
- 만약 시퀀스의 오프셋(seqlen_offset)이 0보다 크면, 이는 시퀀스가 이미 일부 처리되었음을 나타내며, 상태를 업데이트하고 새로운 출력을 바로 반환합니다.
4. 입력 데이터를 처리하는 부분
xz = rearrange(
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
if self.in_proj.bias is not None:
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
- hidden_states를 in_proj의 가중치와 곱하고, rearrange를 사용하여 텐서의 차원을 변경합니다.
- hidden_states의 차원 (B, L, D)를 rearrange하여 (D, B * L)로 펼친 후, in_proj의 가중치와 곱합니다.
- 결과적으로 xz는 (B, D, L)의 형태로 변환됩니다.
- in_proj.bias가 있을 경우, 이를 추가하여 xz에 더합니다.
5. 상태 및 파라미터 준비
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
- A_log는 로그 상태 변수로, 이를 이용해 A를 계산합니다. A는 상태 변환에 사용되는 매개변수로, 모델의 동작을 제어합니다.
6. 빠른 경로 사용 시 (use_fast_path)
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:
out = mamba_inner_fn(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
- 빠른 경로(use_fast_path)가 활성화되어 있고, causal_conv1d_fn이 사용 가능하다면, mamba_inner_fn 함수를 호출하여 최적화된 연산을 사용합니다.
- mamba_inner_fn은 내부적으로 상태 변환과 합성곱 연산을 결합하여 효율적으로 계산합니다.
- 이 함수는 입력 데이터를 처리하기 위해 xz, conv1d 가중치, x_proj, dt_proj, out_proj 가중치 등을 사용합니다.
7. 빠른 경로 사용이 불가능할 때
else:
x, z = xz.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
- **xz**는 두 부분으로 나뉘며, 각각 x와 z로 분리됩니다.
- x는 합성곱 연산에 사용되는 데이터이고, z는 상태 업데이트에 필요한 데이터입니다.
- 합성곱 상태(conv_state)가 있다면 이를 업데이트합니다. 여기서 F.pad를 사용하여 x를 패딩하고, 상태를 갱신합니다.
- causal_conv1d_fn이 없으면 기본적인 합성곱 연산을 수행하고, causal_conv1d_fn이 있으면 이를 사용하여 합성곱 연산을 합니다.
8. 상태 변환
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
- x에 대해 x_proj를 적용하여 변환된 값(x_dbl)을 얻고, 이를 다시 dt, B, C로 나눕니다.
- dt는 시간 관련 변수, B와 C는 상태 변환에 필요한 벡터들입니다.
- dt는 dt_proj 가중치를 사용하여 변환됩니다.
9. 상태 스캔 및 업데이트
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
- selective_scan_fn 함수는 SSM에 기반한 상태 업데이트를 수행합니다. 이는 시간적인 상태 변화를 반영하여 최종 출력을 계산합니다.
- 이 함수는 x, dt, A, B, C, D, z 등 여러 인자를 받아 상태 변환 및 출력 계산을 처리합니다.
- ssm_state가 제공되면 마지막 상태를 반환합니다.
10. 최종 출력
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
- 최종적으로 상태가 존재하면 이를 ssm_state에 복사하고, y를 out_proj에 전달하여 최종 출력을 계산합니다.
- y는 (B, L, D) 형태로 재구성되어 출력됩니다.
6. 빠른 경로 및 상태 업데이트
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:
out = mamba_inner_fn(xz, ...)
- 빠른 경로를 사용할 경우, mamba_inner_fn 함수를 사용하여 상태 변환 및 합성곱 연산을 더 효율적으로 처리합니다.
7. step 함수
def step(self, hidden_states, conv_state, ssm_state):
...
# Process for a single token at a time (decoding step)
xz = self.in_proj(hidden_states.squeeze(1))
...
# Update convolution state
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
conv_state[:, :, -1] = x
...
return out.unsqueeze(1), conv_state, ssm_state
- step 함수는 단일 시퀀스 토큰에 대해 수행됩니다. 디코딩 단계에서 이전 상태를 이용하여 새로운 출력을 생성합니다.
- conv_state와 ssm_state는 모델의 상태를 나타내며, 이들은 순차적으로 업데이트됩니다.
8. 상태 캐시 및 초기화
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
conv_state = torch.zeros(batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype)
ssm_state = torch.zeros(batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype)
return conv_state, ssm_state
- allocate_inference_cache 함수는 추론 과정에서 필요한 상태 변수 (conv_state, ssm_state)를 초기화하고 반환합니다.
9. SSM 상태 업데이트
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
...
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
return conv_state, ssm_state
- _get_states_from_cache 함수는 캐시에서 상태 변수를 가져옵니다. 만약 상태를 초기화해야 할 필요가 있다면, 이 변수들은 0으로 설정됩니다.
결론
이 코드에서 Mamba는 고속 경로 및 상태 업데이트를 통해 시퀀스 모델링에 최적화된 연산을 수행합니다. selective_scan_fn 및 SSM과 관련된 연산을 포함하여, 디코딩 및 인퍼런스에서 매우 효율적으로 작동하도록 설계되었습니다.
마치며
Transformer에 비해서 모듈은 적지만 코드의 실행 순서를 이해하는 데 있어서 조금 시간이 걸리긴 했다.
그래도 Mamba의 의의와 구현에 대한 것은 확실히 알게 된 것 같다.
여기에서 파생된 Mamba2의 코드도 있어서 헷갈릴 법 한데, 그래도 이 정도면 코드가 해석하기에 난해하지도 않고, 좋았다.
나중에 Mamba2 논문도 같이 보면 좋을거 같다.
4편을 마친다.