원래의 모델은 높은 정확도로, 문제 없이 사용 가능한 BERT 모델이였다.
하지만, 이 모델을 모바일에 탑제하기 위해 모바일 벌트를 활용하고,
ONNX로 변환하여 사용하였다.
벌거 아닌 수정내용 같지만, 이 과정을 이행하는 중 엄청난 성능저하를 경험했다.
정확도가 50%아래로 내려오는, 지도학습을 한게 맞나 싶을 정도의 결과가 나왔다.
이러한 성능 저하가 발생한 원인을 추려보면, 당연히 베이스가 되는 모델을 BERT에서 MobileBERT로 교체한 점이 될 것이다.
이 모델은 모바일에 최적화 되어 있지만, 그만큼 성능도 BERT에 비해 떨어진다.
이러한 경우, 모델의 성능을 올리는데 있어 엄청 흥미로운 방법이 있다.
이 모델을 만드는 과정에서, 나는 모바일 버젼이 아닌, 그냥 버젼의 BERT 모델 또한 만들었다.
이 BERT 모델의 경험치를 MobileBERT에게 주어 모델의 성능을 올리는 독특한 방법이 있다.
지식 증류에서는 여러 손실 함수의 조합을 사용하여 학생 모델이 교사 모델의 지식을 학습하도록 만든다.
def distillation_loss(student_logits, teacher_logits, temperature): teacher_probs = F.softmax(teacher_logits / temperature, dim=1) student_probs = F.log_softmax(student_logits / temperature, dim=1) loss = F.kl_div(student_probs, teacher_probs, reduction=’batchmean’) * (temperature ** 2) return loss ```
def hidden_state_loss(student_hidden_states, teacher_hidden_states):
loss = sum(F.mse_loss(s, t) for s, t in zip(student_hidden_states, teacher_hidden_states))
return loss
def attention_loss(student_attentions, teacher_attentions):
loss = sum(F.mse_loss(s, t) for s, t in zip(student_attentions, teacher_attentions))
return loss
for epoch in range(epochs):
for batch in dataloader:
student_outputs = student_model(**batch)
teacher_outputs = teacher_model(**batch)
# Calculate losses
soft_label_loss = distillation_loss(student_outputs.logits, teacher_outputs.logits, temperature=4)
hidden_loss = hidden_state_loss(student_outputs.hidden_states, teacher_outputs.hidden_states)
attention_loss = attention_loss(student_outputs.attentions, teacher_outputs.attentions)
# Combine losses
total_loss = soft_label_loss + hidden_loss + attention_loss
# Backpropagation and optimization
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
MobileBERT 논문에서는 미리 Masked Language Modeling(MLM)과 Next Sentence Prediction(NSP)을 교사 모델에서 학생 모델로 전이하는 Pre-training Distillation을 강조한다.
MobileBERT 모델을 학습한 후 ONNX로 변환할 때, TensorRT와 같은 고성능 엔진을 사용하여 추론 속도와 효율성을 극대화 하자.
아래 글을 확인하자.
ONNX로 하여금 작동속도, 학습속도 올리기
이 글들도 확인하자. 도음이 된다.
MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices
Distilling the Knowledge in a Neural Network
TinyBERT: Distilling BERT for Natural Language Understanding