Gihak111 Navbar

FlashMLA-main

딥시크가 무료로 공개한 알고리즘 중 하나이다.
코드를 분해하여 이해하는 시간을 가져보자.

딥시크가 무료로 공개한 FlashMLA는 Hopper GPU를 위한 효율적인 MLA 디코딩 커널로, 가변 길이 시퀀스를 최적화하여 처리할 수 있다.

현재 공개된 버전에서는 다음 기능을 지원한다.

고성능 설정에서 메모리 대역폭 3000GB/s, 연산 성능 580 TFLOPS를 달성할 수 있다.


설치

다음 명령어로 설치할 수 있다.

python setup.py install

이후, 성능 테스트를 위해 다음을 실행하면 된다.

python tests/test_flash_mla.py

💡 CUDA 12.3 이상에서 동작하지만, 12.8 이상을 권장한다.
💡 VRAM은 프레임당 최소 3GB가 필요하다.


사용법

아래 예제는 FlashMLA를 활용한 kvcache 기반 연산을 수행하는 코드이다.

from flash_mla import get_mla_metadata, flash_mla_with_kvcache

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

for i in range(num_layers):
    ...
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    ...

scaled_dot_product_attention 함수

FlashMLA의 핵심 연산 중 하나인 Scaled Dot-Product Attention을 구현한 예제이다.

import torch

def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    """
    Scaled Dot-Product Attention 구현.

    query: 쿼리 행렬 (Batch, Head, Seq_len_q, Dim)
    key: 키 행렬 (Batch, Head, Seq_len_k, Dim)
    value: 값 행렬 (Batch, Head, Seq_len_v, Dim)
    h_q: 쿼리 헤드 개수
    h_kv: 키-값 헤드 개수
    is_causal: 미래 토큰 마스킹 여부 (기본값: False)
    """
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / d_k**0.5

    if is_causal:
        seq_len = scores.shape[-1]
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(query.device)
        scores = scores.masked_fill(mask == 1, float('-inf'))

    attn = torch.nn.functional.softmax(scores, dim=-1)
    return torch.matmul(attn, value)

이 함수는 FlashMLA 내부의 flash_mla_with_kvcache 연산과 유사한 방식으로 작동한다.


코드 분석

FlashMLA의 연산 로직은 brench_flash_mla에 구현되어 있다.
주요 연산 흐름을 이해하려면 get_mla_metadata, flash_mla_with_kvcache 함수 분석이 필요하다.


관련 프로젝트

FlashMLA는 FlashAttention 2&3NVIDIA Cutlass 프로젝트의 영향을 받았다.

플랫폼 공식 사이트 관련 FlashMLA 버전
MetaX MetaX MetaX-MACA/FlashMLA
Moore Threads Moore Threads MooreThreads/MT-flashMLA
Hygon DCU Hygon Developer OpenDAS/MLAttention
Intellifusion Intellifusion Intellifusion/tyllm
Iluvatar Corex Iluvatar Corex Deep-Spark/FlashMLA

인용

FlashMLA를 연구 또는 논문에서 활용할 경우 아래 BibTeX을 사용할 수 있다.

@misc{flashmla2025,
      title={FlashMLA: Efficient MLA decoding kernels},
      author={Jiashi Li},
      year={2025},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}

주요함수

1. get_mla_metadata

FlashMLA의 타일 스케줄링 정보를 생성하는 함수이다.
쉽게 말해서, MLA 연산을 최적화하기 위해 입력 데이터를 어떻게 쪼갤지(Tiling) 결정하는 역할이다.

코드 흐름:

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

FlashMLA는 연산을 한 번에 다 하는 게 아니라, 메모리 효율을 극대화하기 위해 타일 단위로 나눠서 실행한다.
이걸 결정하는 게 get_mla_metadata 함수고, 여기서 num_splits가 얼마나 적절하게 설정되는지가 속도와 성능에 직접적인 영향을 준다.


2. flash_mla_with_kvcache

FlashMLA의 핵심 연산이 돌아가는 함수이다.
쿼리(query), 키-값 캐시(kvcache), 그리고 블록 테이블(block_table)을 받아서 최적화된 MLA 연산을 수행한다.

코드 흐름:

o_i, lse_i = flash_mla_with_kvcache(
    q_i, kvcache_i, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)

FlashMLA의 가장 핵심적인 로직이 여기 들어있어.

3. scaled_dot_product_attention (기본 Attention 비교용)

이 함수는 FlashMLA랑 직접적으로 관련은 없다.
하지만, 기존 Attention이 어떻게 작동하는지 알고 있어야 FlashMLA가 얼마나 최적화됐는지 비교할 수 있다.

코드 흐름:

def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / d_k**0.5

    if is_causal:
        seq_len = scores.shape[-1]
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(query.device)
        scores = scores.masked_fill(mask == 1, float('-inf'))

    attn = torch.nn.functional.softmax(scores, dim=-1)
    return torch.matmul(attn, value)

FlashMLA는 이 과정을 타일링과 Paged kvcache 방식으로 최적화해서, 메모리 절약 + 빠른 연산을 동시에 가능하게 만든다.

함수 역할 중요한 이유유
get_mla_metadata 타일 스케줄링 & 병렬 연산 설정 MLA 연산을 최적화하는 핵심 요소
flash_mla_with_kvcache FlashMLA의 핵심 연산 수행 메모리 효율성과 속도를 극대화하는 핵심 로직
scaled_dot_product_attention 기존 Attention과 비교 FlashMLA 최적화 효과를 이해하는 데 필수

결국, flash_mla_with_kvcacheget_mla_metadata 이게 제일 중요하다.

마무리

다음에는, 위 저 두 함수에 대해서 알아보자.