FlashMLA-main_flash_mla_with_kvcache
오늘은 주요 로직 중, flash_mla_with_kvcache 함수에 대해서 알아보자.
메모리 효율성과 속도를 극대화하는 핵심 로직으로 중요한 로직이다.
코드가 짧으므로, 코드 전문을 먼저 보자.
from typing import Optional, Tuple
import torch
import flash_mla_cuda
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
)
return out, softmax_lse
위 코드는 딥시크가 공개한 코드의 전문이다.
단계별로 로직을 분석해 보자.
1. get_mla_metadata 함수 분석
코드 전체
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
구문별 분석
- 함수 시그니처와 타입 힌트
def get_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, num_heads_k: int, ) -> Tuple[torch.Tensor, torch.Tensor]:- 목적: 이 함수는 MLA 계산에 필요한 메타데이터를 생성한다.
- 입력:
cache_seqlens: 배치 크기(batch_size)에 해당하는 1차원 텐서로, 각 샘플의 캐시된 시퀀스 길이를 나타낸다. 데이터 타입은torch.int32.num_heads_per_head_k: 쿼리 헤드 수(num_heads_q)와 시퀀스 길이(seq_len_q)를 키 헤드 수(num_heads_k)로 나눈 값. 즉, 키 헤드당 처리해야 할 쿼리 헤드의 개수를 의미한다.num_heads_k: 키(Key) 벡터의 헤드 수.
- 출력: 두 개의 텐서를 포함한 튜플:
tile_scheduler_metadata: 타일 스케줄링을 위한 메타데이터.num_splits: 각 배치에 대한 분할 수.
- Docstring
""" Arguments: cache_seqlens: (batch_size), dtype torch.int32. num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. num_heads_k: num_heads_k. Returns: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """- 설명: 입력과 출력의 구조를 명확히 설명한다.
tile_scheduler_metadata: GPU의 SM(Streaming Multiprocessor) 단위로 작업을 나누기 위한 메타데이터로,(num_sm_parts, TileSchedulerMetaDataSize)크기를 가진다.num_splits: 배치별로 작업을 몇 개의 조각으로 나눌지 나타내는 텐서이다. 크기는(batch_size + 1).
- 함수 본문
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)- 역할: 실제 계산은 CUDA 확장 모듈인
flash_mla_cuda.get_mla_metadata에서 수행된다.
- 역할: 실제 계산은 CUDA 확장 모듈인
요약
- 목적: MLA 계산에서 GPU의 병렬 처리를 최적화하기 위한 메타데이터를 생성.
- 핵심 로직: CUDA 확장을 통해 GPU에서 실행되며, 타일 단위로 작업을 분할하는 데 필요한 정보를 제공.
- 메모리 효율성: 메타데이터를 사전에 계산해 불필요한 연산을 줄이고, 캐시된 시퀀스 길이를 활용해 동적 메모리 할당을 최적화.
2. flash_mla_with_kvcache 함수 분석
코드 전체
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
)
return out, softmax_lse
구문별 분석
- 함수 시그니처와 타입 힌트
def flash_mla_with_kvcache( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]:- 목적: 키-값 캐시를 활용한 멀티헤드 어텐션 연산을 수행.
- 입력:
q: 쿼리 텐서, 크기는(batch_size, seq_len_q, num_heads_q, head_dim).k_cache: 캐시된 키 텐서, 크기는(num_blocks, page_block_size, num_heads_k, head_dim).block_table: 캐시 블록의 인덱스를 나타내는 테이블, 크기는(batch_size, max_num_blocks_per_seq).cache_seqlens: 각 배치의 캐시된 시퀀스 길이, 크기는(batch_size).head_dim_v: 값(Value) 벡터의 헤드 차원.tile_scheduler_metadata와num_splits:get_mla_metadata에서 생성된 메타데이터.softmax_scale: 소프트맥스前の 스케일링 값(기본값은1/sqrt(head_dim)).causal: 인과적 어텐션 마스크 적용 여부.
- 출력: 두 개의 텐서를 포함한 튜플:
out: 어텐션 출력, 크기는(batch_size, seq_len_q, num_heads_q, head_dim_v).softmax_lse: 소프트맥스 로그 합계, 크기는(batch_size, num_heads_q, seq_len_q).
- Docstring
""" Arguments: q: (batch_size, seq_len_q, num_heads_q, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). block_table: (batch_size, max_num_blocks_per_seq), torch.int32. cache_seqlens: (batch_size), torch.int32. head_dim_v: Head dimension of v. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """- 설명: 입력과 출력 텐서의 차원 및 역할에 대한 상세 설명이다.
- 특이점:
k_cache와block_table을 사용해 키-값 캐시를 효율적으로 관리하며,softmax_lse를 반환해 후속 계산(예: 로그 확률)에서 재사용 가능.
- 소프트맥스 스케일링 초기화
if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5)- 역할:
softmax_scale이 명시되지 않은 경우, 헤드 차원의 제곱근 역수를 기본값으로 설정한다. - 의미: 이는 어텐션 메커니즘에서 표준 스케일링 방식(
1/sqrt(d_k))으로, 쿼리와 키의 내적이 너무 커지는 것을 방지해 소프트맥스 출력의 분포를 안정화 한다.
- 역할:
- CUDA 호출
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale, causal, tile_scheduler_metadata, num_splits, )- 역할: CUDA 확장 모듈을 호출해 실제 MLA 연산을 수행한다.
- 특이점: 세 번째 인자로
None이 전달되는데, 이는v_cache(값 캐시)가 없음을 의미할 수 있다. 코드가k_cache만 사용하므로 값 벡터는 동적으로 계산되거나 별도로 처리되는 것으로 보인다. - 과정:
- 쿼리(
q)와 캐시된 키(k_cache)를 사용해 어텐션 스코어를 계산. tile_scheduler_metadata와num_splits를 활용해 작업을 타일 단위로 분할 및 병렬 처리.causal=True일 경우, 인과적 마스크를 적용해 미래 토큰을 참조하지 않음.- 소프트맥스 적용 후 출력(
out)과 로그 합계(softmax_lse)를 계산.
- 쿼리(
- 출력 반환
return out, softmax_lse- 역할: 계산된 어텐션 출력과 소프트맥스 로그 합계를 반환한다.
결론
- 목적: 키-값 캐시를 활용해 MLA 연산을 효율적으로 수행하며, GPU에서 타일 기반 병렬 처리를 최적화.
- 핵심 로직:
- 캐시(
k_cache)와 블록 테이블(block_table)을 사용해 메모리 접근을 최적화. - CUDA 확장을 통해 고속 연산을 구현.
softmax_lse를 반환해 후속 계산에서 재사용 가능.
- 캐시(
- 메모리 효율성: 캐시된 키를 재사용하고, 타일 단위로 작업을 분할해 GPU 메모리 사용을 최소화.
- 속도 최적화: 타일 스케줄링과 CUDA 기반 병렬 처리를 통해 연산 속도를 극대화.
- 메모리 효율성:
k_cache와block_table을 활용해 키 데이터를 캐싱하고 재사용함으로써 메모리 사용량을 줄임.- 동적 시퀀스 길이(
cache_seqlens)를 처리해 불필요한 메모리 할당을 방지.
- 속도 최적화:
tile_scheduler_metadata와num_splits를 통해 GPU의 SM 단위로 작업을 분할, 병렬 처리 효율을 극대화.- CUDA 확장을 사용해 CPU-GPU 간 데이터 전송 오버헤드를 최소화하고, GPU 네이티브 연산을 활용.
- 유연성:
causal옵션으로 인과적/비인과적 어텐션을 지원.softmax_scale을 커스터마이징 가능해 다양한 어텐션 메커니즘에 적응 가능.
여기까지 왔으니까, 내부동작, 타일 스케줄링의 세부 사항도 알아보자.
1. flash_mla_cuda의 내부 동작
이 헤더 파일은 두 개의 주요 기능을 정의합니다:
get_mla_metadata_func: MLA 연산에 필요한 메타데이터를 생성.run_mha_fwd_splitkv_mla: MLA의 순방향 연산을 수행.
1.1 구조체: Flash_fwd_mla_params
struct Flash_fwd_mla_params {
// 주요 파라미터
int b, seqlen_q, d, d_v; // 배치 크기, 쿼리 시퀀스 길이, 헤드 차원(QK), 헤드 차원(V)
int h, h_h_k_ratio, ngroups; // 헤드 수(Q), 헤드 비율(h_q/h_k), 그룹 수
bool is_causal; // 인과적 어텐션 여부
float scale_softmax, scale_softmax_log2; // 소프트맥스 스케일링 값(일반 및 log2 기반)
int *__restrict__ cu_seqlens_k; // 캐시된 키의 시퀀스 길이
// 데이터 포인터
void *__restrict__ q_ptr, k_ptr, v_ptr, o_ptr, softmax_lse_ptr; // Q, K, V, 출력, 소프트맥스 로그 합계
// 스트라이드 (배치, 행, 헤드 단위 메모리 오프셋)
index_t q_batch_stride, k_batch_stride, v_batch_stride, o_batch_stride;
index_t q_row_stride, k_row_stride, v_row_stride, o_row_stride;
index_t q_head_stride, k_head_stride, v_head_stride, o_head_stride;
// 블록 테이블 (캐시 관리)
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// 타일 스케줄링 메타데이터
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
// 누적 버퍼
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
};
- 주요 역할: MLA 연산에 필요한 모든 파라미터와 메모리 포인터를 정의.
- 특징:
h_h_k_ratio:num_heads_q / num_heads_k로, 쿼리 헤드와 키 헤드 간의 비율을 나타냄 (PyTorch 코드의num_heads_per_head_k와 동일).cu_seqlens_k: 키 캐시의 시퀀스 길이를 GPU 메모리에 저장.- 스트라이드: Q, K, V, 출력 텐서의 메모리 레이아웃을 정의하며, 배치/행/헤드 단위로 접근 가능.
block_table: 페이지 기반 KV 캐시를 관리하며,page_block_size단위로 블록을 구성.- 누적 버퍼:
softmax_lseaccum_ptr와oaccum_ptr는 타일별 중간 결과를 누적하기 위한 임시 버퍼로 보임.
1.2 구조체: Mla_metadata_params
struct Mla_metadata_params {
int *__restrict__ seqlens_k_ptr; // 키 시퀀스 길이
int *__restrict__ tile_scheduler_metadata_ptr; // 타일 스케줄링 메타데이터
int *__restrict__ num_splits_ptr; // 분할 수
int batch_size, block_size_n, fixed_overhead_num_blocks, num_sm_parts; // 배치 크기, 블록 크기, 오버헤드 블록 수, SM 파트 수
};
- 주요 역할:
get_mla_metadata_func에 전달되는 파라미터로, 타일 스케줄링과 작업 분할을 위한 입력을 정의. - 특징:
block_size_n: 키 캐시의 블록 크기(아마page_block_size와 동일).fixed_overhead_num_blocks: 고정된 오버헤드 블록 수로, 메타데이터 계산 시 고려되는 상수.
1.3 함수: get_mla_metadata_func
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
- 입력:
Mla_metadata_params와 CUDA 스트림. - 출력:
tile_scheduler_metadata_ptr와num_splits_ptr를 채움. - 내부 동작:
- 시퀀스 길이 분석:
seqlens_k_ptr를 기반으로 각 배치의 작업량 계산. - 작업 분할:
batch_size,block_size_n,num_sm_parts를 고려해 작업을 타일로 나눔. - 메타데이터 생성:
tile_scheduler_metadata_ptr에 타일별 정보(시작/끝 인덱스 등)를 기록. - 분할 수 기록:
num_splits_ptr에 배치별 분할 수를 저장.
- 시퀀스 길이 분석:
1.4 함수: run_mha_fwd_splitkv_mla
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
- 입력:
Flash_fwd_mla_params와 CUDA 스트림. - 템플릿 파라미터:
T: 데이터 타입(예:float,half).Headdim: 헤드 차원(컴파일 타임 상수로 최적화).
- 내부 동작:
- 타일 단위 연산:
tile_scheduler_metadata_ptr와num_splits_ptr를 기반으로 작업을 타일로 분할. - QK^T 계산:
q_ptr와k_ptr를 사용해 어텐션 스코어 계산. - 소프트맥스:
scale_softmax적용 후 소프트맥스 계산,is_causal에 따라 마스크 적용. - 출력 생성:
v_ptr를 사용해 최종 출력(o_ptr) 계산. - 로그 합계 저장*
softmax_lse_ptr에 결과 기록. - 누적 처리:
softmax_lseaccum_ptr와oaccum_ptr를 사용해 타일 간 결과 집계.
- 타일 단위 연산:
2. 타일 스케줄링의 세부 사항
2.1 TileSchedulerMetaDataSize 정의
static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
- 의미: 각 타일의 메타데이터는 8개의
int로 구성. - 구성:
begin_idx: 타일이 시작하는 배치 또는 블록 인덱스.begin_seqlen: 타일의 시작 시퀀스 위치.end_idx: 타일이 끝나는 배치 또는 블록 인덱스.end_seqlen: 타일의 끝 시퀀스 위치.begin_n_split_idx: 해당 타일이 속한 분할 그룹의 시작 인덱스.
6-8. 예약 필드: 현재 사용되지 않거나, 디버깅/최적화를 위한 추가 정보.
- 크기:
(num_sm_parts, 8)로, GPU의 SM 수에 따라 타일 메타데이터가 생성됨.
2.2 타일 스케줄링의 설계
- 목적: MLA 연산을 GPU의 SM 단위로 분할해 병렬 처리하며, SRAM을 활용해 HBM 접근을 최소화.
- 과정:
- 작업 분할:
seqlens_k_ptr와batch_size를 기반으로 전체 시퀀스를 분석.num_sm_parts에 따라 작업을 SM별로 나눔.
- 타일 정의:
- 각 타일은
begin_idx~end_idx와begin_seqlen~end_seqlen으로 정의. block_table과page_block_size를 사용해 캐시 블록 단위로 매핑.
- 각 타일은
- 분할 수 계산:
num_splits_ptr는 배치별로 몇 개의 타일로 나뉘는지를 기록.- 예: 긴 시퀀스(
seqlen_k > page_block_size)는 여러 타일로 분할.
- SM 할당:
tile_scheduler_metadata_ptr를 통해 각 SM에 타일을 배정.num_sm_parts는 GPU의 SM 수와 작업 부하에 따라 동적으로 조정될 가능성.
- 작업 분할:
2.3 타일 단위 연산
- 메모리 관리:
q_ptr,k_ptr,v_ptr의 타일 부분을 SRAM에 로드.- 스트라이드(
q_row_stride,k_head_stride등)를 사용해 메모리 접근 최적화.
- 연산 흐름:
- QK^T 계산: 타일 단위로 행렬 곱셈 수행.
- 소프트맥스:
scale_softmax_log2를 활용해 로그 도메인에서 안정적으로 계산. - V 적용: 값 벡터를 곱해 출력 생성.
- 누적:
oaccum_ptr와softmax_lseaccum_ptr에 타일별 결과 저장.
- 병렬성: 각 SM이 독립적으로 타일을 처리하며, 워프 단위로 세부 연산 병렬화.
2.4 최적화 기법
- 페이지 기반 캐시:
block_table과page_block_size를 통해 KV 캐시를 페이지 단위로 관리, 메모리 효율성 향상. - 온라인 소프트맥스: 타일별로 소프트맥스를 계산하고 누적 버퍼에 저장해 HBM 쓰기를 줄임.
- 비동기 처리:
cudaStream_t를 활용해 데이터 로드와 연산을 오버랩.
3. PyTorch 코드와의 매핑
get_mla_metadata→get_mla_metadata_func:- PyTorch의
cache_seqlens→seqlens_k_ptr. tile_scheduler_metadata→tile_scheduler_metadata_ptr.num_splits→num_splits_ptr.
- PyTorch의
flash_mla_with_kvcache→run_mha_fwd_splitkv_mla:q,k_cache,block_table→q_ptr,k_ptr,block_table.head_dim_v→d_v.- 출력
out,softmax_lse→o_ptr,softmax_lse_ptr.
차이점:
- PyTorch 코드에서는
v_cache가 생략되었지만, 헤더 파일에는v_ptr가 포함되었다다.
이는flash_mla_with_kvcache가 값 벡터를 동적으로 생성하거나 별도로 처리할 가능성을 시사한다.
4. 결론
- 내부 동작:
flash_mla_cuda는 페이지 기반 KV 캐시와 타일 스케줄링을 결합해 MLA를 최적화하며, SRAM 중심 연산과 온라인 소프트맥스를 통해 메모리와 속도를 극대화한다. - 타일 스케줄링:
tile_scheduler_metadata_ptr는 타일의 시작/끝 위치와 분할 정보를 정의하며,num_splits_ptr와 함께 SM 단위 병렬 처리를 관리한다.
진짜 결론
아무튼, 메모리 효율성과 속도를 극대화 하기 위해 타일 스케줄링 형식을 활용하는 방법이다.