Gihak111 Navbar

FlashMLA-main_flash_mla_with_kvcache

오늘은 주요 로직 중, flash_mla_with_kvcache 함수에 대해서 알아보자.
메모리 효율성과 속도를 극대화하는 핵심 로직으로 중요한 로직이다.
코드가 짧으므로, 코드 전문을 먼저 보자.

from typing import Optional, Tuple

import torch

import flash_mla_cuda


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse

위 코드는 딥시크가 공개한 코드의 전문이다.
단계별로 로직을 분석해 보자.

1. get_mla_metadata 함수 분석

코드 전체

def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)

구문별 분석

  1. 함수 시그니처와 타입 힌트
    def get_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_heads_per_head_k: int,
        num_heads_k: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    
    • 목적: 이 함수는 MLA 계산에 필요한 메타데이터를 생성한다.
    • 입력:
      • cache_seqlens: 배치 크기(batch_size)에 해당하는 1차원 텐서로, 각 샘플의 캐시된 시퀀스 길이를 나타낸다. 데이터 타입은 torch.int32.
      • num_heads_per_head_k: 쿼리 헤드 수(num_heads_q)와 시퀀스 길이(seq_len_q)를 키 헤드 수(num_heads_k)로 나눈 값. 즉, 키 헤드당 처리해야 할 쿼리 헤드의 개수를 의미한다.
      • num_heads_k: 키(Key) 벡터의 헤드 수.
    • 출력: 두 개의 텐서를 포함한 튜플:
      • tile_scheduler_metadata: 타일 스케줄링을 위한 메타데이터.
      • num_splits: 각 배치에 대한 분할 수.
  2. Docstring
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.
    
    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    
    • 설명: 입력과 출력의 구조를 명확히 설명한다.
    • tile_scheduler_metadata: GPU의 SM(Streaming Multiprocessor) 단위로 작업을 나누기 위한 메타데이터로, (num_sm_parts, TileSchedulerMetaDataSize) 크기를 가진다.
    • num_splits: 배치별로 작업을 몇 개의 조각으로 나눌지 나타내는 텐서이다. 크기는 (batch_size + 1).
  3. 함수 본문
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
    
    • 역할: 실제 계산은 CUDA 확장 모듈인 flash_mla_cuda.get_mla_metadata에서 수행된다.

요약

2. flash_mla_with_kvcache 함수 분석

코드 전체

def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse

구문별 분석

  1. 함수 시그니처와 타입 힌트
    def flash_mla_with_kvcache(
        q: torch.Tensor,
        k_cache: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        head_dim_v: int,
        tile_scheduler_metadata: torch.Tensor,
        num_splits: torch.Tensor,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    
    • 목적: 키-값 캐시를 활용한 멀티헤드 어텐션 연산을 수행.
    • 입력:
      • q: 쿼리 텐서, 크기는 (batch_size, seq_len_q, num_heads_q, head_dim).
      • k_cache: 캐시된 키 텐서, 크기는 (num_blocks, page_block_size, num_heads_k, head_dim).
      • block_table: 캐시 블록의 인덱스를 나타내는 테이블, 크기는 (batch_size, max_num_blocks_per_seq).
      • cache_seqlens: 각 배치의 캐시된 시퀀스 길이, 크기는 (batch_size).
      • head_dim_v: 값(Value) 벡터의 헤드 차원.
      • tile_scheduler_metadatanum_splits: get_mla_metadata에서 생성된 메타데이터.
      • softmax_scale: 소프트맥스前の 스케일링 값(기본값은 1/sqrt(head_dim)).
      • causal: 인과적 어텐션 마스크 적용 여부.
    • 출력: 두 개의 텐서를 포함한 튜플:
      • out: 어텐션 출력, 크기는 (batch_size, seq_len_q, num_heads_q, head_dim_v).
      • softmax_lse: 소프트맥스 로그 합계, 크기는 (batch_size, num_heads_q, seq_len_q).
  2. Docstring
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
    
    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    
    • 설명: 입력과 출력 텐서의 차원 및 역할에 대한 상세 설명이다.
    • 특이점: k_cacheblock_table을 사용해 키-값 캐시를 효율적으로 관리하며, softmax_lse를 반환해 후속 계산(예: 로그 확률)에서 재사용 가능.
  3. 소프트맥스 스케일링 초기화
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    
    • 역할: softmax_scale이 명시되지 않은 경우, 헤드 차원의 제곱근 역수를 기본값으로 설정한다.
    • 의미: 이는 어텐션 메커니즘에서 표준 스케일링 방식(1/sqrt(d_k))으로, 쿼리와 키의 내적이 너무 커지는 것을 방지해 소프트맥스 출력의 분포를 안정화 한다.
  4. CUDA 호출
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    
    • 역할: CUDA 확장 모듈을 호출해 실제 MLA 연산을 수행한다.
    • 특이점: 세 번째 인자로 None이 전달되는데, 이는 v_cache(값 캐시)가 없음을 의미할 수 있다. 코드가 k_cache만 사용하므로 값 벡터는 동적으로 계산되거나 별도로 처리되는 것으로 보인다.
    • 과정:
      1. 쿼리(q)와 캐시된 키(k_cache)를 사용해 어텐션 스코어를 계산.
      2. tile_scheduler_metadatanum_splits를 활용해 작업을 타일 단위로 분할 및 병렬 처리.
      3. causal=True일 경우, 인과적 마스크를 적용해 미래 토큰을 참조하지 않음.
      4. 소프트맥스 적용 후 출력(out)과 로그 합계(softmax_lse)를 계산.
  5. 출력 반환
    return out, softmax_lse
    
    • 역할: 계산된 어텐션 출력과 소프트맥스 로그 합계를 반환한다.

결론

  1. 메모리 효율성:
    • k_cacheblock_table을 활용해 키 데이터를 캐싱하고 재사용함으로써 메모리 사용량을 줄임.
    • 동적 시퀀스 길이(cache_seqlens)를 처리해 불필요한 메모리 할당을 방지.
  2. 속도 최적화:
    • tile_scheduler_metadatanum_splits를 통해 GPU의 SM 단위로 작업을 분할, 병렬 처리 효율을 극대화.
    • CUDA 확장을 사용해 CPU-GPU 간 데이터 전송 오버헤드를 최소화하고, GPU 네이티브 연산을 활용.
  3. 유연성:
    • causal 옵션으로 인과적/비인과적 어텐션을 지원.
    • softmax_scale을 커스터마이징 가능해 다양한 어텐션 메커니즘에 적응 가능.

여기까지 왔으니까, 내부동작, 타일 스케줄링의 세부 사항도 알아보자.

1. flash_mla_cuda의 내부 동작

이 헤더 파일은 두 개의 주요 기능을 정의합니다:

  1. get_mla_metadata_func: MLA 연산에 필요한 메타데이터를 생성.
  2. run_mha_fwd_splitkv_mla: MLA의 순방향 연산을 수행.

1.1 구조체: Flash_fwd_mla_params

struct Flash_fwd_mla_params {
    // 주요 파라미터
    int b, seqlen_q, d, d_v;  // 배치 크기, 쿼리 시퀀스 길이, 헤드 차원(QK), 헤드 차원(V)
    int h, h_h_k_ratio, ngroups;  // 헤드 수(Q), 헤드 비율(h_q/h_k), 그룹 수
    bool is_causal;  // 인과적 어텐션 여부
    float scale_softmax, scale_softmax_log2;  // 소프트맥스 스케일링 값(일반 및 log2 기반)
    int *__restrict__ cu_seqlens_k;  // 캐시된 키의 시퀀스 길이

    // 데이터 포인터
    void *__restrict__ q_ptr, k_ptr, v_ptr, o_ptr, softmax_lse_ptr;  // Q, K, V, 출력, 소프트맥스 로그 합계

    // 스트라이드 (배치, 행, 헤드 단위 메모리 오프셋)
    index_t q_batch_stride, k_batch_stride, v_batch_stride, o_batch_stride;
    index_t q_row_stride, k_row_stride, v_row_stride, o_row_stride;
    index_t q_head_stride, k_head_stride, v_head_stride, o_head_stride;

    // 블록 테이블 (캐시 관리)
    int *__restrict__ block_table;
    index_t block_table_batch_stride;
    int page_block_size;

    // 타일 스케줄링 메타데이터
    int *__restrict__ tile_scheduler_metadata_ptr;
    int num_sm_parts;
    int *__restrict__ num_splits_ptr;

    // 누적 버퍼
    void *__restrict__ softmax_lseaccum_ptr;
    void *__restrict__ oaccum_ptr;
};

1.2 구조체: Mla_metadata_params

struct Mla_metadata_params {
    int *__restrict__ seqlens_k_ptr;  // 키 시퀀스 길이
    int *__restrict__ tile_scheduler_metadata_ptr;  // 타일 스케줄링 메타데이터
    int *__restrict__ num_splits_ptr;  // 분할 수
    int batch_size, block_size_n, fixed_overhead_num_blocks, num_sm_parts;  // 배치 크기, 블록 크기, 오버헤드 블록 수, SM 파트 수
};

1.3 함수: get_mla_metadata_func

void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);

1.4 함수: run_mha_fwd_splitkv_mla

template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);

2. 타일 스케줄링의 세부 사항

2.1 TileSchedulerMetaDataSize 정의

static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]

2.2 타일 스케줄링의 설계

2.3 타일 단위 연산

2.4 최적화 기법

3. PyTorch 코드와의 매핑

차이점:

4. 결론

진짜 결론

아무튼, 메모리 효율성과 속도를 극대화 하기 위해 타일 스케줄링 형식을 활용하는 방법이다.