Gihak111 Navbar

Transformers without Normalization

지금까지의 트랜스 포머 모델에선 정규화가 필수였다.
트랜스포머는 LayerNorm 없으면 망할 거라 생각했는데, 이 논문은 Dynamic Tanh (DyT)라는 새로운 기술로 정규화 없이도 성능을 챙긴다고 주장한다.
논문 링크: https://arxiv.org/abs/2410.03646

요즘 AI 모델은 점점 커지면서 계산 비용이 점점 올라가고 있는데, 이 논문은 LayerNorm의 통계 계산 오버헤드를 없애고 더 간단한 방법으로 비슷한 성능을 낸다고 한다.

1. 정규화 계층, 왜 문제일까?

트랜스포머에서 LayerNorm(LN)이나 RMSNorm은 거의 신성한 존재로 여겨졌다. 이 녀석들은 왜 쓰냐면:

근데, 이 정규화 계층도 문제점이 있다:

그래서 이 논문은 정규화 없이도 할 수 있지 않을까? 라는 도발적인 질문을 던진다.
그리고 답은 Dynamic Tanh (DyT)라는 새로운 모듈이다.

2. LayerNorm

먼저 LayerNorm이 실제로 뭘 하는지 알아보자.

2.1 LayerNorm의 입출력 패턴

논문의 Figure 2에서 ViT, wav2vec2.0, DiT 같은 모델의 LayerNorm 입출력을 시각화했는데:

2.2 토큰 vs 채널 분석

Figure 4에서 더 파고들었는데:

결론? LayerNorm은 사실상 tanh 같은 비선형 함수처럼 작동하고, 값 스케일링도 같이 해준다는 거다.
이걸 알았으니, LayerNorm을 대체할 Dynamic Tanh를 알아보자

3. Dynamic Tanh

논문은 LayerNorm을 Dynamic Tanh (DyT)라는 모듈로 바꿔버린다 즉, Dynamic Tanh게 더 효율이 좋아야 한다.

3.1 DyT의 수식

DyT는 간단한 수식으로 정의된다:

DyT(x) = γ · tanh(α·x) + β

3.2 왜 tanh?

논문의 Figure 3에서 tanh, hardtanh, sigmoid를 비교했는데:

tanh가 이 둘의 장점을 잘 섞어서 최고의 선택지로 뽑혔다.

4. DyT 적용 방법

DyT는 트랜스포머에서 LayerNorm이나 RMSNorm이 있던 자리에 그냥 넣으면 된다. 별도의 통계 계산(평균, 분산) 필요 없고, 기존 활성화 함수(GELU, ReLU)나 모델 구조도 안 바꿔도 된다.

DyT 코드 구현

논문에 나온 DyT를 PyTorch로 구현해 봤다. 오픈소스 코드는 못 찾았지만, 논문 설명대로 따라해봤다.

import torch
import torch.nn as nn

class DyT(nn.Module):
    def __init__(self, channels, alpha_init=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init)  # Learnable scaling
        self.gamma = nn.Parameter(torch.ones(channels))       # Per-channel scale
        self.beta = nn.Parameter(torch.zeros(channels))       # Per-channel shift

    def forward(self, x):
        return self.gamma * torch.tanh(self.alpha * x) + self.beta

이 코드는 ViT나 LLaMA 같은 모델의 LayerNorm 자리에 바로 끼워넣을 수 있다. 실행은 안 해봐서, 교차검증은 필요하다.

5. 실험 결과

논문은 DyT를 여러 태스크에서 테스트했다.
ViT, LLaMA, wav2vec2, DiT, DNA 시퀀스 모델에 대한 결과는 아래 표로 정리했다.

태스크 모델 정규화 DyT 차이
Vision 분류 ViT-B/L 82.3%/83.1% 82.5%/83.6% +0.2/+0.5%p
  ConvNeXt-B/L 83.7%/84.3% 83.7%/84.4% 0/+0.1%p
Self-Supervised MAE, DINO 83.2–85.5% 83.2–85.5% ±0.4%p
Diffusion (DiT) B/L/XL FID 64.9/45.9/19.9 63.9/45.7/20.8 -1.0/-0.2/+0.9
LLM (LLaMA 7–70B) RMSNorm loss 1.45–1.59 1.45–1.60 ±0.01
Speech (wav2vec2) Base/Large loss 1.95/1.92 1.95/1.91 -0/-0.01
DNA 시퀀스 HyenaDNA, Caduceus 85.2%/86.9% 85.2%/86.9%

놀라운 건, 하이퍼파라미터 거의 안 건드리고 LayerNorm을 DyT로 바꿨을 뿐인데 이 성능을 낸다는 거다.

6. Ablation Study

6.1 스쿼싱 함수 비교

논문은 tanh 대신 다른 함수를 썼을 때를 테스트했다:

tanh가 안정성과 성능 면에서 더 좋다.

6.2 α의 역할

6.3 α 초기화

7. 기존 방법들과 비교

정규화 없는 다른 방법들(Fixup, SkipInit, oReparam)과 비교해봤다

방법 ViT-B ViT-L MAE ViT-B MAE ViT-L
Fixup 77.2% 78.1% 73.7% 74.1%
SkipInit 74.1% 75.6% 73.1% 74.0%
oReparam 82.5% 83.0% 83.2% 85.4%
DyT 82.8% 83.6% 83.7% 85.8%

DyT가 최소 변경으로 최고 성능! Fixup이나 SkipInit은 학습률 낮춰야 해서 귀찮고, oReparam은 비슷하지만 DyT가 더 간단하다.


8. 효율성

DyT는 통계 계산이 없어서 빠르다. H100 GPU에서 테스트 결과

설정 RMSNorm DyT 전체 모델 개선율
Uncompiled (BF16)        
Inference (100단계) 2.1s 1.0s 14.1s→13.0s 15–52%↓
Train (100단계) 8.3s 4.8s 42.6s→39.1s 18–42%↓
Compiled (torch.compile)        
Inference 0.3s 0.3s 12.3s ≒0%
Train 3.9s 3.9s 38.9s ≒0%

9. 한계와 앞으로의 방향

DyT는 트랜스포머에선 쓸모 있지만, 한계도 있다

10. 결론

Dynamic Tanh는 LayerNorm의 본질을 “비선형 스쿼싱 + 스케일링”으로 재정의했다. 통계 계산 없이도:

이 논문은 트랜스포머의 정규화 계층을 다시 생각하게 만든다.
DyT는 간단하면서도 강력해서 앞으로 더 많은 모델에 적용될 가능성이 크다.

근데, 솔직히 말하면, 이거 다른 정규화 대체 방법 oReparam 같은 거와 섞어서 쓰면 더 쎌 것 같기도 하다.
예를 들어, oReparam의 스펙트럼 제어랑 DyT의 비선형을 합치면 더 좋지 않을까.
아직 테스트 안 해봐서 모르겠지만, 뭔가 터질 것 같은 느낌이다.