지금까지의 트랜스 포머 모델에선 정규화가 필수였다.
트랜스포머는 LayerNorm 없으면 망할 거라 생각했는데, 이 논문은 Dynamic Tanh (DyT)라는 새로운 기술로 정규화 없이도 성능을 챙긴다고 주장한다.
논문 링크: https://arxiv.org/abs/2410.03646
요즘 AI 모델은 점점 커지면서 계산 비용이 점점 올라가고 있는데, 이 논문은 LayerNorm의 통계 계산 오버헤드를 없애고 더 간단한 방법으로 비슷한 성능을 낸다고 한다.
트랜스포머에서 LayerNorm(LN)이나 RMSNorm은 거의 신성한 존재로 여겨졌다. 이 녀석들은 왜 쓰냐면:
근데, 이 정규화 계층도 문제점이 있다:
그래서 이 논문은 정규화 없이도 할 수 있지 않을까? 라는 도발적인 질문을 던진다.
그리고 답은 Dynamic Tanh (DyT)라는 새로운 모듈이다.
먼저 LayerNorm이 실제로 뭘 하는지 알아보자.
논문의 Figure 2에서 ViT, wav2vec2.0, DiT 같은 모델의 LayerNorm 입출력을 시각화했는데:
Figure 4에서 더 파고들었는데:
결론? LayerNorm은 사실상 tanh 같은 비선형 함수처럼 작동하고, 값 스케일링도 같이 해준다는 거다.
이걸 알았으니, LayerNorm을 대체할 Dynamic Tanh를 알아보자
논문은 LayerNorm을 Dynamic Tanh (DyT)라는 모듈로 바꿔버린다 즉, Dynamic Tanh게 더 효율이 좋아야 한다.
DyT는 간단한 수식으로 정의된다:
DyT(x) = γ · tanh(α·x) + β
논문의 Figure 3에서 tanh, hardtanh, sigmoid를 비교했는데:
tanh가 이 둘의 장점을 잘 섞어서 최고의 선택지로 뽑혔다.
DyT는 트랜스포머에서 LayerNorm이나 RMSNorm이 있던 자리에 그냥 넣으면 된다. 별도의 통계 계산(평균, 분산) 필요 없고, 기존 활성화 함수(GELU, ReLU)나 모델 구조도 안 바꿔도 된다.
논문에 나온 DyT를 PyTorch로 구현해 봤다. 오픈소스 코드는 못 찾았지만, 논문 설명대로 따라해봤다.
import torch
import torch.nn as nn
class DyT(nn.Module):
def __init__(self, channels, alpha_init=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init) # Learnable scaling
self.gamma = nn.Parameter(torch.ones(channels)) # Per-channel scale
self.beta = nn.Parameter(torch.zeros(channels)) # Per-channel shift
def forward(self, x):
return self.gamma * torch.tanh(self.alpha * x) + self.beta
이 코드는 ViT나 LLaMA 같은 모델의 LayerNorm 자리에 바로 끼워넣을 수 있다. 실행은 안 해봐서, 교차검증은 필요하다.
논문은 DyT를 여러 태스크에서 테스트했다.
ViT, LLaMA, wav2vec2, DiT, DNA 시퀀스 모델에 대한 결과는 아래 표로 정리했다.
태스크 | 모델 | 정규화 | DyT | 차이 |
---|---|---|---|---|
Vision 분류 | ViT-B/L | 82.3%/83.1% | 82.5%/83.6% | +0.2/+0.5%p |
ConvNeXt-B/L | 83.7%/84.3% | 83.7%/84.4% | 0/+0.1%p | |
Self-Supervised | MAE, DINO | 83.2–85.5% | 83.2–85.5% | ±0.4%p |
Diffusion (DiT) | B/L/XL | FID 64.9/45.9/19.9 | 63.9/45.7/20.8 | -1.0/-0.2/+0.9 |
LLM (LLaMA 7–70B) | RMSNorm | loss 1.45–1.59 | 1.45–1.60 | ±0.01 |
Speech (wav2vec2) | Base/Large | loss 1.95/1.92 | 1.95/1.91 | -0/-0.01 |
DNA 시퀀스 | HyenaDNA, Caduceus | 85.2%/86.9% | 85.2%/86.9% | – |
놀라운 건, 하이퍼파라미터 거의 안 건드리고 LayerNorm을 DyT로 바꿨을 뿐인데 이 성능을 낸다는 거다.
논문은 tanh 대신 다른 함수를 썼을 때를 테스트했다:
tanh가 안정성과 성능 면에서 더 좋다.
정규화 없는 다른 방법들(Fixup, SkipInit, oReparam)과 비교해봤다
방법 | ViT-B | ViT-L | MAE ViT-B | MAE ViT-L |
---|---|---|---|---|
Fixup | 77.2% | 78.1% | 73.7% | 74.1% |
SkipInit | 74.1% | 75.6% | 73.1% | 74.0% |
oReparam | 82.5% | 83.0% | 83.2% | 85.4% |
DyT | 82.8% | 83.6% | 83.7% | 85.8% |
DyT가 최소 변경으로 최고 성능! Fixup이나 SkipInit은 학습률 낮춰야 해서 귀찮고, oReparam은 비슷하지만 DyT가 더 간단하다.
DyT는 통계 계산이 없어서 빠르다. H100 GPU에서 테스트 결과
설정 | RMSNorm | DyT | 전체 모델 | 개선율 |
---|---|---|---|---|
Uncompiled (BF16) | ||||
Inference (100단계) | 2.1s | 1.0s | 14.1s→13.0s | 15–52%↓ |
Train (100단계) | 8.3s | 4.8s | 42.6s→39.1s | 18–42%↓ |
Compiled (torch.compile) | ||||
Inference | 0.3s | 0.3s | 12.3s | ≒0% |
Train | 3.9s | 3.9s | 38.9s | ≒0% |
DyT는 트랜스포머에선 쓸모 있지만, 한계도 있다
Dynamic Tanh는 LayerNorm의 본질을 “비선형 스쿼싱 + 스케일링”으로 재정의했다. 통계 계산 없이도:
이 논문은 트랜스포머의 정규화 계층을 다시 생각하게 만든다.
DyT는 간단하면서도 강력해서 앞으로 더 많은 모델에 적용될 가능성이 크다.
근데, 솔직히 말하면, 이거 다른 정규화 대체 방법 oReparam 같은 거와 섞어서 쓰면 더 쎌 것 같기도 하다.
예를 들어, oReparam의 스펙트럼 제어랑 DyT의 비선형을 합치면 더 좋지 않을까.
아직 테스트 안 해봐서 모르겠지만, 뭔가 터질 것 같은 느낌이다.