포스팅 개요
본 포스팅은 거대 언어 모델(LLM)의 추론(inference) 과정에서 발생하는 심각한 메모리 병목 현상을 해결하기 위한 핵심 최적화 기법인 Grouped-Query Attention(GQA)을 소개합니다. GQA는 구글 리서치에서 제안한 기술로 기존의 표준 Attention 방식인 Multi-Head Attention(MHA)의 높은 성능은 유지하면서, 추론 속도를 극대화한 Multi-Query Attention(MQA)의 장점을 결합한 아키텍처입니다.
본 포스팅에서는 MHA에서 MQA를 거쳐 GQA에 이르기까지 Attention 메커니즘의 발전 과정을 추적하고, GQA가 어떻게 속도와 성능이라는 두 마리 토끼를 모두 잡을 수 있었는지 그 원리를 상세히 분석합니다. 또한, Llama 2, Mistral 7B 등 최신 LLM에 GQA가 어떻게 적용되어 실제 성능 향상을 이끌어냈는지 구체적인 실험 결과를 통해 확인합니다.
논문 링크: https://arxiv.org/pdf/2305.13245

포스팅 본문
1. 핵심 요약
거대 언어 모델(LLM)을 실제로 서비스하는 데 있어 가장 큰 장애물 중 하나는 추론 과정의 높은 메모리 사용량입니다. 특히, Transformer 모델의 핵심인 Attention 메커니즘은 매 토큰을 생성할 때마다 모든 이전 토큰들의 Key와 Value 값을 메모리에서 불러와야 하므로 심각한 병목 현상을 유발합니다.
이 문제를 해결하기 위해, 기존의 표준 방식인 Multi-Head Attention (MHA)의 높은 품질과, 추론 속도를 극단적으로 개선했지만 성능 저하의 위험이 있던 Multi-Query Attention (MQA)의 장점을 절충한 Grouped-Query Attention (GQA)이 제안되었습니다. GQA의 핵심 아이디어는 여러 개의 Query 헤드들을 몇 개의 그룹으로 묶고, 각 그룹이 단일 Key-Value 헤드를 공유하도록 하는 것입니다. 이러한 구조적 변경을 통해 GQA는 MQA처럼 메모리 사용량을 획기적으로 줄여 추론 속도를 높이면서도, MHA와 거의 근접한 높은 모델 성능을 유지하는 데 성공했습니다.
Llama 2 70B와 Mistral 7B 같은 최신 고성능 LLM들이 이 기술을 채택했으며, 실험 결과 GQA를 사용한 모델이 MHA 기반 모델보다 특히 높은 부하 상황에서 월등한 처리 속도를 보여주었습니다. 결과적으로 GQA는 LLM의 현실적인 배포와 확장성을 위한 필수적인 최적화 기술로 자리매김하고 있습니다.
2-1. 연구의 배경: LLM 추론 병목 현상과 Attention 메커니즘
2-1-1. Multi-Head Attention (MHA)의 등장과 메모리 한계
Transformer 모델의 Attention 메커니즘은 문장의 각 단어(Query)가 다른 모든 단어(Keys)들과 얼마나 관련이 있는지를 계산하여, 그 가중치에 따라 정보(Values)를 종합하는 방식으로 작동합니다. 이는 마치 데이터베이스에서 쿼리를 날려 가장 관련성 높은 정보를 찾아오는 것과 유사합니다.
Multi-Head Attention (MHA)는 이러한 Attention 과정을 여러 개의 "헤드(Head)"를 통해 병렬로 수행하는 방식입니다. 각 헤드는 독립적인 Query, Key, Value 가중치를 가져 "나는 형이 소파 옮기는 것을 도왔다"와 같은 문장에서 '나-형'의 관계와 '나-소파를 옮기다'라는 관계를 동시에 파악하는 등, 텍스트의 다채롭고 복잡한 관계를 효과적으로 학습할 수 있습니다. 하지만 MHA는 각 헤드가 자신만의 Key와 Value를 가지기 때문에, 모델이 새로운 토큰을 생성할 때마다 이전의 모든 토큰에 해당하는 방대한 양의 Key-Value 캐시를 메모리에서 불러와야 하는 큰 단점이 있습니다. 이로 인해 메모리 대역폭에 엄청난 부담을 주며, 이는 LLM 추론 성능을 저하하는 주된 병목 지점이 됩니다.
2-1-2. Multi-Query Attention (MQA)의 시도와 품질 저하 문제

MHA의 메모리 병목 문제를 해결하기 위해 Multi-Query Attention (MQA)이 등장했습니다. MQA의 아이디어는 매우 단순 명료합니다. 여러 개의 Query 헤드는 그대로 유지하되, 모든 Query 헤드가 단 하나의 Key-Value 헤드를 공유하도록 하는 것입니다.
이 방식을 통해 Key-Value 캐시의 크기가 획기적으로 줄어들었고, 메모리 로딩량이 감소하면서 추론 속도가 크게 향상되었습니다. 하지만 이러한 극단적인 단순화는 모델의 표현력을 감소시켜 성능 저하를 유발하거나 학습 과정을 불안정하게 만드는 부작용을 낳았습니다. 속도를 얻는 대신 품질을 일부 희생해야 하는 트레이드오프가 발생한 것입니다.
2-2. Grouped-Query Attention (GQA): 개념과 작동 원리
2-2-1. GQA의 개념: MHA와 MQA의 영리한 절충안

Grouped-Query Attention (GQA)는 MHA의 높은 성능과 MQA의 빠른 속도 사이에서 최적의 균형점을 찾은 아키텍처입니다. GQA는 MHA처럼 모든 Query 헤드가 독립적인 K-V 헤드를 갖지도 않고, MQA처럼 모든 Query 헤드가 단 하나의 K-V 헤드를 공유하지도 않습니다. 대신, 전체 Query 헤드를 여러 그룹(G)으로 나누고, 각 그룹 내의 Query 헤드들이 하나의 Key-Value 헤드를 공유하는 방식을 사용합니다.
예를 들어, 8개의 Query 헤드가 있고 2개의 그룹을 사용한다면, 1~4번 Query 헤드가 첫 번째 K-V 헤드를 공유하고, 5~8번 Query 헤드가 두 번째 K-V 헤드를 공유하는 식입니다. 이처럼 GQA는 Key-Value 헤드의 수를 1개(MQA)와 전체 Query 헤드 수(MHA) 사이의 중간 값으로 설정합니다. 이러한 구조 덕분에 GQA는 MHA와 MQA를 모두 포함하는 일반화된 개념으로 볼 수 있습니다.
그룹의 수가 1이면 MQA와 동일하고, 그룹의 수가 전체 Query 헤드 수와 같으면 MHA와 동일해집니다.
2-2-2. GQA의 장점: 속도와 성능의 균형
GQA는 MHA와 MQA의 장점을 모두 취하는 효과적인 절충안입니다.
- 메모리 효율성 및 속도 향상: Key-Value 헤드의 수를 줄임으로써 메모리 사용량과 계산 복잡도를 모두 감소시킵니다. 이는 MQA와 유사하게 빠른 추론 속도로 이어집니다.
- 처리량 증가: Attention 캐시를 위한 메모리 공간이 줄어들기 때문에, 남는 공간을 활용해 더 큰 배치 사이즈(batch size)로 한 번에 더 많은 요청을 처리할 수 있어 전체적인 처리량(throughput)이 향상됩니다.
- 높은 성능 유지: MQA와 달리 여러 개의 Key-Value 헤드를 유지함으로써 모델의 표현력 손실을 최소화하고, MHA에 가까운 높은 품질을 달성합니다.
2-3. GQA의 성능 검증 및 실제 적용 사례
2-3-1. 실험 결과: MHA와 MQA 대비 GQA의 우수성
GQA의 효과는 다양한 실험을 통해 입증되었습니다.

- T5 모델 실험: 구글이 T5 모델을 기반으로 실험한 결과, GQA는 MQA와 비슷한 수준의 빠른 추론 속도(Time per sample)를 보이면서도, MHA와 거의 대등한 성능(Performance)을 기록했습니다. Figure 3는 GQA-XXL이 MQA-XXL처럼 빠르면서도 MHA-XXL만큼 성능이 좋다는 것을 명확히 보여줍니다.
- Llama 2 vs. Mistral 7B 비교: 동일한 7B 파라미터 크기를 가진 두 모델을 비교한 실험에서도 GQA의 우수성이 드러났습니다. MHA를 사용하는 Llama 2 7B와 GQA를 사용하는 Mistral 7B를 동일한 GPU에서 테스트한 결과, 요청량이 적을 때는 성능이 비슷했지만, 부하가 증가할수록 GQA를 사용한 Mistral이 훨씬 빠른 처리 속도를 보였습니다. 가장 부하가 높은 상황에서는 Mistral이 24배 더 빠른 성능을 기록했습니다.
2-3-2. 기존 MHA 모델을 GQA로 전환: Uptraining
GQA의 또 다른 강력한 장점은 완전히 새로운 모델을 처음부터 학습시킬 필요 없이, 기존에 MHA로 학습된 모델을 GQA 구조로 변환할 수 있다는 점입니다. 이를 '업트레이닝(Uptraining)'이라고 하며, 원본 모델 학습에 사용된 계산량의 약 5% 정도만으로도 기존 MHA 모델 체크포인트를 GQA 모델로 성공적으로 전환할 수 있습니다. 이는 막대한 시간과 자원을 절약하며 고품질의 빠른 추론 모델을 얻을 수 있는 매우 비용 효율적인 방법입니다.
3. Group Query Attention 코드 설명
Group Query Attention을 잘 설명한 코드(https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb)가 있어, 정리할 겸 소개합니다.
# GQA 모델의 레이어를 초기화하는 __init__ 메서드
def __init__(
self, d_in, d_out, num_heads,
num_kv_groups, # [GQA 핵심] 키-값 헤드 그룹의 수를 지정하는 새로운 파라미터
dtype=None
):
super().__init__()
# num_heads는 num_kv_groups로 나누어떨어져야 함
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
# [GQA 핵심] MHA와 달리, 키(W_key)와 값(W_value)의 출력 차원을 줄여 파라미터 수를 감소시킴
# 전체 헤드 차원(d_out)이 아닌, (키-값 그룹 수 * 헤드 차원) 만큼만 가중치를 생성
self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
# 쿼리(W_query)는 MHA와 동일하게 d_out 차원을 유지
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
# 최종 출력을 위한 프로젝션 레이어
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
self.num_kv_groups = num_kv_groups
# 하나의 키-값 그룹을 몇 개의 쿼리 헤드가 공유할 것인지를 계산
self.group_size = num_heads // num_kv_groups
위 코드는 GQA를 구성하는 __init__ 함수입니다. 주목할 부분은 W_key와 W_value인데요. 기존 멀티 헤드 어텐션(Multi-Head Attention)은 동일한 크기의 쿼리(Query), 키(Key), 값(Value)를 가졌습니다. 하지만, GQA에서는 num_kv_groups라는 인자를 도입해서 key와 value의 출력 차원을 num_kv_groups * head_dim으로 줄입니다. 바로 여기서 모델의 총 파라미터 수가 크게 감소하는 효과가 나타나게 됩니다.
# GQA의 순방향 계산을 수행하는 forward 메서드
def forward(self, x, mask=None, cos=None, sin=None):
b, num_tokens, d_in = x.shape
# 1. 쿼리, 키, 값 프로젝션
# W_query, W_key, W_value 가중치를 곱해 쿼리, 키, 값 텐서 생성
queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
# 2. 헤드/그룹별로 텐서 분리 (Reshape)
# 쿼리는 num_heads 기준으로, 키/값은 num_kv_groups 기준으로 차원을 변경
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
# 계산을 위해 차원 축 순서 변경 (Transpose)
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
# (RoPE 적용 등 추가 연산) ...
if cos is not None:
keys = compute_rope(keys, cos, sin)
queries = compute_rope(queries, cos, sin)
# 3. [GQA 핵심] 키와 값 확장
# num_kv_groups 개수만큼 있는 키/값 헤드를 num_heads 개수에 맞게 복제
# 이를 통해 모든 쿼리 헤드가 자신과 짝을 이룰 키/값 헤드를 가질 수 있게 됨
# 예: group_size=2, [K1, K2] -> [K1, K1, K2, K2]
keys = keys.repeat_interleave(self.group_size, dim=1)
values = values.repeat_interleave(self.group_size, dim=1)
# 4. Scaled Dot-Product Attention
# 이제 쿼리와 키/값의 헤드 수가 동일해졌으므로, 표준 어텐션 계산 수행
attn_scores = queries @ keys.transpose(2, 3) # (b, num_heads, num_tokens, num_tokens)
# 마스킹 및 Softmax
if mask is None:
mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
attn_scores.masked_fill_(mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
# 최종 컨텍스트 벡터 계산
context_vec = (attn_weights @ values).transpose(1, 2) # (b, num_tokens, num_heads, head_dim)
# 5. 최종 출력
# 모든 헤드의 결과를 하나로 합치고(reshape) 최종 프로젝션 레이어를 통과
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec)
return context_vec
실제 어텐션에 대한 수행은 forward 함수에 나와있습니다. 여기서 흐름은 아래와 같습니다.
- 프로젝션 및 reshape: 입력값 x로부터 query, key, value를 생성합니다. 여기서 key와 value는 __init__에서 정의한 대로 더 작은 값을 가지게 됩니다. 이후 view, transpose를 통해 각 텐서를 헤드 별로 연산하기 좋은 형태로 만들어 줍니다.
- Key, value 확장: GQA의 가장 특이한 부분입니다. 현재 쿼리 헤드의 개수가 key-value 헤드의 개수보다 많기 때문에 어텐션 계산을 바로 할 수 없는데요. 이를 위해서 repeat_interleave 함수를 사용합니다. 이 함수는 group_size만큼 각 key-value 그룹을 복제하여 쿼리 헤드의 수와 동일하게 맞춰줍니다.
- 어텐션 계산: key, value의 헤드 수가 query와 동일하게 확장되었으므로, 이제 Multi-Head Attention과 동일한 방식으로 Scaled dot product attention을 수행합니다.
- 최종 출력: 각 헤드별로 계산된 벡터를 하나로 다시 합치고, out_proj 레이어를 통과시켜 최종 결과를 반환합니다.
마무리
본 포스팅에서는 LLM 추론의 핵심적인 병목 현상을 해결하는 Grouped-Query Attention(GQA)에 대해 알아보았습니다. GQA는 MHA의 성능과 MQA의 속도라는 두 가지 장점을 효과적으로 결합하여, 오늘날 LLM을 현실 세계에 배포하고 확장하는 데 필수적인 기술로 자리 잡았습니다. 이 기술에 대한 더 깊이 있는 내용이 궁금하신 분들은 원본 논문을 직접 읽어보시길 추천합니다.
감사합니다.