오늘은 주요 로직 중, flash_mla_with_kvcache 함수에 대해서 알아보자.
메모리 효율성과 속도를 극대화하는 핵심 로직으로 중요한 로직이다.
코드가 짧으므로, 코드 전문을 먼저 보자.
from typing import Optional, Tuple
import torch
import flash_mla_cuda
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
)
return out, softmax_lse
위 코드는 딥시크가 공개한 코드의 전문이다.
단계별로 로직을 분석해 보자.
get_mla_metadata
함수 분석def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
cache_seqlens
: 배치 크기(batch_size
)에 해당하는 1차원 텐서로, 각 샘플의 캐시된 시퀀스 길이를 나타낸다. 데이터 타입은 torch.int32
.num_heads_per_head_k
: 쿼리 헤드 수(num_heads_q
)와 시퀀스 길이(seq_len_q
)를 키 헤드 수(num_heads_k
)로 나눈 값. 즉, 키 헤드당 처리해야 할 쿼리 헤드의 개수를 의미한다.num_heads_k
: 키(Key) 벡터의 헤드 수.tile_scheduler_metadata
: 타일 스케줄링을 위한 메타데이터.num_splits
: 각 배치에 대한 분할 수."""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
tile_scheduler_metadata
: GPU의 SM(Streaming Multiprocessor) 단위로 작업을 나누기 위한 메타데이터로, (num_sm_parts, TileSchedulerMetaDataSize)
크기를 가진다.num_splits
: 배치별로 작업을 몇 개의 조각으로 나눌지 나타내는 텐서이다. 크기는 (batch_size + 1)
.return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
flash_mla_cuda.get_mla_metadata
에서 수행된다.flash_mla_with_kvcache
함수 분석def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
)
return out, softmax_lse
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
q
: 쿼리 텐서, 크기는 (batch_size, seq_len_q, num_heads_q, head_dim)
.k_cache
: 캐시된 키 텐서, 크기는 (num_blocks, page_block_size, num_heads_k, head_dim)
.block_table
: 캐시 블록의 인덱스를 나타내는 테이블, 크기는 (batch_size, max_num_blocks_per_seq)
.cache_seqlens
: 각 배치의 캐시된 시퀀스 길이, 크기는 (batch_size)
.head_dim_v
: 값(Value) 벡터의 헤드 차원.tile_scheduler_metadata
와 num_splits
: get_mla_metadata
에서 생성된 메타데이터.softmax_scale
: 소프트맥스前の 스케일링 값(기본값은 1/sqrt(head_dim)
).causal
: 인과적 어텐션 마스크 적용 여부.out
: 어텐션 출력, 크기는 (batch_size, seq_len_q, num_heads_q, head_dim_v)
.softmax_lse
: 소프트맥스 로그 합계, 크기는 (batch_size, num_heads_q, seq_len_q)
."""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
k_cache
와 block_table
을 사용해 키-값 캐시를 효율적으로 관리하며, softmax_lse
를 반환해 후속 계산(예: 로그 확률)에서 재사용 가능.if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
softmax_scale
이 명시되지 않은 경우, 헤드 차원의 제곱근 역수를 기본값으로 설정한다.1/sqrt(d_k)
)으로, 쿼리와 키의 내적이 너무 커지는 것을 방지해 소프트맥스 출력의 분포를 안정화 한다.out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
)
None
이 전달되는데, 이는 v_cache
(값 캐시)가 없음을 의미할 수 있다. 코드가 k_cache
만 사용하므로 값 벡터는 동적으로 계산되거나 별도로 처리되는 것으로 보인다.q
)와 캐시된 키(k_cache
)를 사용해 어텐션 스코어를 계산.tile_scheduler_metadata
와 num_splits
를 활용해 작업을 타일 단위로 분할 및 병렬 처리.causal=True
일 경우, 인과적 마스크를 적용해 미래 토큰을 참조하지 않음.out
)과 로그 합계(softmax_lse
)를 계산.return out, softmax_lse
k_cache
)와 블록 테이블(block_table
)을 사용해 메모리 접근을 최적화.softmax_lse
를 반환해 후속 계산에서 재사용 가능.k_cache
와 block_table
을 활용해 키 데이터를 캐싱하고 재사용함으로써 메모리 사용량을 줄임.cache_seqlens
)를 처리해 불필요한 메모리 할당을 방지.tile_scheduler_metadata
와 num_splits
를 통해 GPU의 SM 단위로 작업을 분할, 병렬 처리 효율을 극대화.causal
옵션으로 인과적/비인과적 어텐션을 지원.softmax_scale
을 커스터마이징 가능해 다양한 어텐션 메커니즘에 적응 가능.여기까지 왔으니까, 내부동작, 타일 스케줄링의 세부 사항도 알아보자.
flash_mla_cuda
의 내부 동작이 헤더 파일은 두 개의 주요 기능을 정의합니다:
get_mla_metadata_func
: MLA 연산에 필요한 메타데이터를 생성.run_mha_fwd_splitkv_mla
: MLA의 순방향 연산을 수행.Flash_fwd_mla_params
struct Flash_fwd_mla_params {
// 주요 파라미터
int b, seqlen_q, d, d_v; // 배치 크기, 쿼리 시퀀스 길이, 헤드 차원(QK), 헤드 차원(V)
int h, h_h_k_ratio, ngroups; // 헤드 수(Q), 헤드 비율(h_q/h_k), 그룹 수
bool is_causal; // 인과적 어텐션 여부
float scale_softmax, scale_softmax_log2; // 소프트맥스 스케일링 값(일반 및 log2 기반)
int *__restrict__ cu_seqlens_k; // 캐시된 키의 시퀀스 길이
// 데이터 포인터
void *__restrict__ q_ptr, k_ptr, v_ptr, o_ptr, softmax_lse_ptr; // Q, K, V, 출력, 소프트맥스 로그 합계
// 스트라이드 (배치, 행, 헤드 단위 메모리 오프셋)
index_t q_batch_stride, k_batch_stride, v_batch_stride, o_batch_stride;
index_t q_row_stride, k_row_stride, v_row_stride, o_row_stride;
index_t q_head_stride, k_head_stride, v_head_stride, o_head_stride;
// 블록 테이블 (캐시 관리)
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// 타일 스케줄링 메타데이터
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
// 누적 버퍼
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
};
h_h_k_ratio
: num_heads_q / num_heads_k
로, 쿼리 헤드와 키 헤드 간의 비율을 나타냄 (PyTorch 코드의 num_heads_per_head_k
와 동일).cu_seqlens_k
: 키 캐시의 시퀀스 길이를 GPU 메모리에 저장.block_table
: 페이지 기반 KV 캐시를 관리하며, page_block_size
단위로 블록을 구성.softmax_lseaccum_ptr
와 oaccum_ptr
는 타일별 중간 결과를 누적하기 위한 임시 버퍼로 보임.Mla_metadata_params
struct Mla_metadata_params {
int *__restrict__ seqlens_k_ptr; // 키 시퀀스 길이
int *__restrict__ tile_scheduler_metadata_ptr; // 타일 스케줄링 메타데이터
int *__restrict__ num_splits_ptr; // 분할 수
int batch_size, block_size_n, fixed_overhead_num_blocks, num_sm_parts; // 배치 크기, 블록 크기, 오버헤드 블록 수, SM 파트 수
};
get_mla_metadata_func
에 전달되는 파라미터로, 타일 스케줄링과 작업 분할을 위한 입력을 정의.block_size_n
: 키 캐시의 블록 크기(아마 page_block_size
와 동일).fixed_overhead_num_blocks
: 고정된 오버헤드 블록 수로, 메타데이터 계산 시 고려되는 상수.get_mla_metadata_func
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
Mla_metadata_params
와 CUDA 스트림.tile_scheduler_metadata_ptr
와 num_splits_ptr
를 채움.seqlens_k_ptr
를 기반으로 각 배치의 작업량 계산.batch_size
, block_size_n
, num_sm_parts
를 고려해 작업을 타일로 나눔.tile_scheduler_metadata_ptr
에 타일별 정보(시작/끝 인덱스 등)를 기록.num_splits_ptr
에 배치별 분할 수를 저장.run_mha_fwd_splitkv_mla
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
Flash_fwd_mla_params
와 CUDA 스트림.T
: 데이터 타입(예: float
, half
).Headdim
: 헤드 차원(컴파일 타임 상수로 최적화).tile_scheduler_metadata_ptr
와 num_splits_ptr
를 기반으로 작업을 타일로 분할.q_ptr
와 k_ptr
를 사용해 어텐션 스코어 계산.scale_softmax
적용 후 소프트맥스 계산, is_causal
에 따라 마스크 적용.v_ptr
를 사용해 최종 출력(o_ptr
) 계산.softmax_lse_ptr
에 결과 기록.softmax_lseaccum_ptr
와 oaccum_ptr
를 사용해 타일 간 결과 집계.TileSchedulerMetaDataSize
정의static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
int
로 구성.begin_idx
: 타일이 시작하는 배치 또는 블록 인덱스.begin_seqlen
: 타일의 시작 시퀀스 위치.end_idx
: 타일이 끝나는 배치 또는 블록 인덱스.end_seqlen
: 타일의 끝 시퀀스 위치.begin_n_split_idx
: 해당 타일이 속한 분할 그룹의 시작 인덱스.(num_sm_parts, 8)
로, GPU의 SM 수에 따라 타일 메타데이터가 생성됨.seqlens_k_ptr
와 batch_size
를 기반으로 전체 시퀀스를 분석.num_sm_parts
에 따라 작업을 SM별로 나눔.begin_idx
~ end_idx
와 begin_seqlen
~ end_seqlen
으로 정의.block_table
과 page_block_size
를 사용해 캐시 블록 단위로 매핑.num_splits_ptr
는 배치별로 몇 개의 타일로 나뉘는지를 기록.seqlen_k > page_block_size
)는 여러 타일로 분할.tile_scheduler_metadata_ptr
를 통해 각 SM에 타일을 배정.num_sm_parts
는 GPU의 SM 수와 작업 부하에 따라 동적으로 조정될 가능성.q_ptr
, k_ptr
, v_ptr
의 타일 부분을 SRAM에 로드.q_row_stride
, k_head_stride
등)를 사용해 메모리 접근 최적화.scale_softmax_log2
를 활용해 로그 도메인에서 안정적으로 계산.oaccum_ptr
와 softmax_lseaccum_ptr
에 타일별 결과 저장.block_table
과 page_block_size
를 통해 KV 캐시를 페이지 단위로 관리, 메모리 효율성 향상.cudaStream_t
를 활용해 데이터 로드와 연산을 오버랩.get_mla_metadata
→ get_mla_metadata_func
:
cache_seqlens
→ seqlens_k_ptr
.tile_scheduler_metadata
→ tile_scheduler_metadata_ptr
.num_splits
→ num_splits_ptr
.flash_mla_with_kvcache
→ run_mha_fwd_splitkv_mla
:
q
, k_cache
, block_table
→ q_ptr
, k_ptr
, block_table
.head_dim_v
→ d_v
.out
, softmax_lse
→ o_ptr
, softmax_lse_ptr
.차이점:
v_cache
가 생략되었지만, 헤더 파일에는 v_ptr
가 포함되었다다.flash_mla_with_kvcache
가 값 벡터를 동적으로 생성하거나 별도로 처리할 가능성을 시사한다.flash_mla_cuda
는 페이지 기반 KV 캐시와 타일 스케줄링을 결합해 MLA를 최적화하며, SRAM 중심 연산과 온라인 소프트맥스를 통해 메모리와 속도를 극대화한다.tile_scheduler_metadata_ptr
는 타일의 시작/끝 위치와 분할 정보를 정의하며, num_splits_ptr
와 함께 SM 단위 병렬 처리를 관리한다.아무튼, 메모리 효율성과 속도를 극대화 하기 위해 타일 스케줄링 형식을 활용하는 방법이다.