딥시크가 무료로 공개한 알고리즘 중 하나이다.
코드를 분해하여 이해하는 시간을 가져보자.
딥시크가 무료로 공개한 FlashMLA는 Hopper GPU를 위한 효율적인 MLA 디코딩 커널로, 가변 길이 시퀀스를 최적화하여 처리할 수 있다.
현재 공개된 버전에서는 다음 기능을 지원한다.
고성능 설정에서 메모리 대역폭 3000GB/s, 연산 성능 580 TFLOPS를 달성할 수 있다.
다음 명령어로 설치할 수 있다.
python setup.py install
이후, 성능 테스트를 위해 다음을 실행하면 된다.
python tests/test_flash_mla.py
💡 CUDA 12.3 이상에서 동작하지만, 12.8 이상을 권장한다.
💡 VRAM은 프레임당 최소 3GB가 필요하다.
아래 예제는 FlashMLA를 활용한 kvcache 기반 연산을 수행하는 코드이다.
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
for i in range(num_layers):
...
o_i, lse_i = flash_mla_with_kvcache(
q_i, kvcache_i, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=True,
)
...
scaled_dot_product_attention
함수FlashMLA의 핵심 연산 중 하나인 Scaled Dot-Product Attention을 구현한 예제이다.
import torch
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
"""
Scaled Dot-Product Attention 구현.
query: 쿼리 행렬 (Batch, Head, Seq_len_q, Dim)
key: 키 행렬 (Batch, Head, Seq_len_k, Dim)
value: 값 행렬 (Batch, Head, Seq_len_v, Dim)
h_q: 쿼리 헤드 개수
h_kv: 키-값 헤드 개수
is_causal: 미래 토큰 마스킹 여부 (기본값: False)
"""
d_k = query.shape[-1]
scores = torch.matmul(query, key.transpose(-2, -1)) / d_k**0.5
if is_causal:
seq_len = scores.shape[-1]
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(query.device)
scores = scores.masked_fill(mask == 1, float('-inf'))
attn = torch.nn.functional.softmax(scores, dim=-1)
return torch.matmul(attn, value)
이 함수는 FlashMLA 내부의 flash_mla_with_kvcache 연산과 유사한 방식으로 작동한다.
FlashMLA의 연산 로직은 brench_flash_mla
에 구현되어 있다.
주요 연산 흐름을 이해하려면 get_mla_metadata
, flash_mla_with_kvcache
함수 분석이 필요하다.
FlashMLA는 FlashAttention 2&3 및 NVIDIA Cutlass 프로젝트의 영향을 받았다.
플랫폼 | 공식 사이트 | 관련 FlashMLA 버전 |
---|---|---|
MetaX | MetaX | MetaX-MACA/FlashMLA |
Moore Threads | Moore Threads | MooreThreads/MT-flashMLA |
Hygon DCU | Hygon Developer | OpenDAS/MLAttention |
Intellifusion | Intellifusion | Intellifusion/tyllm |
Iluvatar Corex | Iluvatar Corex | Deep-Spark/FlashMLA |
FlashMLA를 연구 또는 논문에서 활용할 경우 아래 BibTeX을 사용할 수 있다.
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels},
author={Jiashi Li},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}
get_mla_metadata
FlashMLA의 타일 스케줄링 정보를 생성하는 함수이다.
쉽게 말해서, MLA 연산을 최적화하기 위해 입력 데이터를 어떻게 쪼갤지(Tiling) 결정하는 역할이다.
cache_seqlens
)와 쿼리-키 관계를 계산해서 타일 크기랑 병렬 연산 개수(num_splits
)를 정해줌.코드 흐름:
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
cache_seqlens
: 각 시퀀스의 길이 정보s_q * h_q // h_kv
: 쿼리 길이와 키-값 길이 조정h_kv
: 키-값 헤드 개수FlashMLA는 연산을 한 번에 다 하는 게 아니라, 메모리 효율을 극대화하기 위해 타일 단위로 나눠서 실행한다.
이걸 결정하는 게 get_mla_metadata
함수고, 여기서 num_splits
가 얼마나 적절하게 설정되는지가 속도와 성능에 직접적인 영향을 준다.
flash_mla_with_kvcache
FlashMLA의 핵심 연산이 돌아가는 함수이다.
쿼리(query
), 키-값 캐시(kvcache
), 그리고 블록 테이블(block_table
)을 받아서 최적화된 MLA 연산을 수행한다.
causal=True
설정 시, Auto-Regressive Transformer(GPT 같은 모델) 지원코드 흐름:
o_i, lse_i = flash_mla_with_kvcache(
q_i, kvcache_i, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=True,
)
q_i
: 쿼리 텐서kvcache_i
: Key-Value 캐시block_table
: 블록 구조 관리 테이블cache_seqlens
: 각 시퀀스 길이dv
: Value의 변화량tile_scheduler_metadata, num_splits
: 위에서 get_mla_metadata
로 얻은 타일 정보causal=True
: Decoder-Only Transformer(GPT 같은 모델)에서 필수 설정FlashMLA의 가장 핵심적인 로직이 여기 들어있어.
scaled_dot_product_attention
(기본 Attention 비교용)이 함수는 FlashMLA랑 직접적으로 관련은 없다.
하지만, 기존 Attention이 어떻게 작동하는지 알고 있어야 FlashMLA가 얼마나 최적화됐는지 비교할 수 있다.
코드 흐름:
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
d_k = query.shape[-1]
scores = torch.matmul(query, key.transpose(-2, -1)) / d_k**0.5
if is_causal:
seq_len = scores.shape[-1]
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(query.device)
scores = scores.masked_fill(mask == 1, float('-inf'))
attn = torch.nn.functional.softmax(scores, dim=-1)
return torch.matmul(attn, value)
query * key^T
→ softmax
→ value
연산을 거치는 전통적인 방식이다.FlashMLA는 이 과정을 타일링과 Paged kvcache 방식으로 최적화해서, 메모리 절약 + 빠른 연산을 동시에 가능하게 만든다.
함수 | 역할 | 중요한 이유유 |
---|---|---|
get_mla_metadata |
타일 스케줄링 & 병렬 연산 설정 | MLA 연산을 최적화하는 핵심 요소 |
flash_mla_with_kvcache |
FlashMLA의 핵심 연산 수행 | 메모리 효율성과 속도를 극대화하는 핵심 로직 |
scaled_dot_product_attention |
기존 Attention과 비교 | FlashMLA 최적화 효과를 이해하는 데 필수 |
결국, flash_mla_with_kvcache
랑 get_mla_metadata
이게 제일 중요하다.
다음에는, 위 저 두 함수에 대해서 알아보자.