논문

하루에도 수만개의 글자를 읽고 있습니다. 하루에도 수백장의 종이를 들춰 읽습니다.
이것은 그 읽기에 대한 일기입니다.

Does Knowledge Distillation Really Work?

Problem

Knowledge distillation을 사용하여 학습한 student는 때로는 teacher와는 매우 다른 prediction 결과를 보여준다. 심지어 teacher와 student가 동일한 capacity를 가지는 self-distillation의 경우라도 말이다. 이러한 실험에서 역설적으로 teacher와의 일치(match)에 실패하더라도 오히려 student라도 오히려 더 일반화(high generalization)된 성능을 보여준다는 연구도 존재한다.

Essence

때문에 fidelity와 generalization를 분리해서 보고자 한다. Fidelity는 student가 얼마나 teacher의 prediction 일치하는 결과를 내는지를 의미하고, generalization은 student가 in-distribution 데이터의 보지 못했던 데이터에 대한 성능을 의미한다.

논문에서는 student가 good fidelity를 얻기는 매우 힘든 문제임을 보였다. Low fidelity는 identifiability problem으로 distillation에 사용하는 dataset을 augmentation으로 풀 수 있음을 보이고, 또한 이는 최적화 문제로 실패할 경우 teacher와 매칭되지 않게 되어버리는 것을 보였다.

Knowledge distillation이 정말 동작하는가? 라고 물었을 때, student의 generalization을 높이는 관점에서는 그렇다 라고 이야기 할 수 있겠으나, 용어 자체의 의미로 보았을 때는 매우 teacher의 매우 한정된 지식만을 전달한다는 면에서는 아니오라고 이야기할 수 있을 것이다.

Detail

Preliminaries

Knowledge Distillation

Hinton이 knowledge distillation 로스이다.
\mathcal{L}_s := \alpha \mathcal{L}_{\text{NLL}} + (1-\alpha)\mathcal{L}_{\text{KD}}, \text{where} \alpha \in [0, 1) \ \mathcal{L}_{\text{NLL}} := -\sum^c_{j=1}y_j \log \sigma_j (\mathbb{z}_s) \mathcal{L}_{\text{KD}}(\mathbb{z]_s, \mathbb{z}_t}) := -\tau^2\sum_{j=1}^c \sigma_j (\frac{\mathbb{z}_t}{\tau}) \log \sigma_j \frac{\mathbb{z}_s}{\tau}
첫번째 항은 task에 대한 supervised cross-entropy 이다.

두번째 항은 teacher와 student를 매칭하기 위해 teacher와 student의 predictive distribution에 대한 cross-entropy 이다.\tau​는 temperature hyperparameter로 1을 사용할 경우\text{KL}(\hat{p}_t || \hat{p}_s)​​와 거의 같아진다. 이 값은 teacher label을 softness의 역할을 한다.

논문의 실험에서는 fidelity에 집중하기 위해\alpha=0​을 사용하였다.

Metric and Evaluation

Generalization을 측정하기 위하여 top-1 accuracy, negative log-likelihood (NLL), expected calibration error (ECE)를 이용하였다.

Fidelity를 측정하기 위해서는 다음 2개 식을 사용하였다.
\text{Average Top-1 Agreement} := \frac{1}{n} \sum_{i=1}^n \mathbb{1} { \arg\max_j \sigma_j (\mathbb{z}_{t, i}) = \arg\max_j \sigma_j (\mathbb{z}_{s, i})}\ \text{Average Predictive KL} := \frac{1}{n} \sum_{i=1}^n \text{KL} ( \hat{p}_t (\mathbb{y}|\mathbb{x}_i) || \hat{p}_s (\mathbb{y}|\mathbb{x}_i))
첫번째 식은 top-1 label의 average agreement를 의미하고 두번째 식은 모든 label에 대해서 predictive distribution의 fidelity를 측정하기 위한 것이다.

Knowledge Distillation Transfers Knowledge Poorly

When is knowledge transfer successful?

Self-distillation 실험을 두가지 케이스에 대해서 해보았다.

  1. MNIST로 학습된 LeNet-5을 EMNIST로 distillation
  2. CIFAR-100으로 학습된 ResNet-56을 동일한 DB로 학습된 GAN으로 만들어낸 샘플로 distillation

LeNet-5은 단순해서 distillation이 잘 이루어졌지만 ResNet-56은 조금 다른 결과를 보여주었다. Distillation에 사용한 DB가 증가할 수록 student의 accuracy는 teacher에 가까워진다. 하지만 오히려 DB 수가 적을 때 accuracy가 teacher보다 높았다가 적어지기 때문에, fidelity가 좋아지면서 오히려 generalization이 안 좋아졌음을 알 수 있다.

[Figure 1]

[Figure 2]

What can self-distillation tell us about knowledge distillation in general?

사실 후자의 ResNet-56의 실험은 결국 distillation 관점에서는 실패한 방법이라고 할 수 있다. 만약 student가 teacher를 완벽하게 일치했다면 student는 teacher의 성능을 능가할 수가 없다. 반면 teacher가 독립적으로 학습된 student 대비 generalization이 잘 되었다고 한다면 teacher와 일치하지 않는 부분고 관련된 regularization effect들을 상회하는 좋은 fidelity를 가지고 있다고 생각할 수 있다. 결국 이것은 큰 모델에서 student로 지식을 더 효과적으로 전달하고자 하는 knowledge distillation의 본래 목적에도 부합한다. 하지만 실제로는 높은 fidelity를 갖는 student라고 해서 좋은 generalization을 갖는 것만은 아니다.

If distillation improves generalization, why care about fidelity?

Fidelity와 generalization의 관계, 그리고 fidelity를 높이는 것이 중요한 이유를 살펴보면,

Distilling large teacher model : 본래 knowledge distillation이 생겨난 이유를 생각해볼 때 large teacher에서 small student 사이의 generalization disparity가 생겨난다. 만약 앙상블을 사용하는 것처럼 teacher를 더 키울 경우 이 disparity는 더욱 커진다. Student의 fidelity를 늘리는 것은 이러한 disparity를 줄이는 데에 도움이 된다.

Interpretability and reliability : knowledge distillation은 transfer representation으로 생각할 수 있는데, 이는 곧 큰 black box 모델을 좀 더 해석이 용이한 작은 모델로 전달할 수 있음을 의미한다. 이를 위해서는 결국 좋은 distillation fidelity가 필수적이다.

Understanding : fidelity와 generailization을 분리해서 보고 fidelity를 조사함으로써 knowledge distillation의 작동 방식을 이해하는데 도움이 된다.

Possible causes of low distillation fidelity

Poor distillation fidelity의 원인이 되는 것을 생각해보았다.

Architecture : ResNet 기반의 구조에서만 발생할 수도 있으나 VGG 기반 구조에서 같은 실험을 하였음으로 조사에서 제외한다.

Student Capacity : self-distillation 세팅에서도 low fidelity가 발생하였으므로 이 역시 조사에서 제외한다.

Identifiability : student가 high-fidelity와 low-fidelity인지 구분하기에 distiallation data가 부족할 수 있다. 다시 말하면 distillation dataset에서 teacher와 일치했더라도 test data에서는 그렇지 않을 수 있다.

Optimization : distillation은 optimization problem으로 풀기에 어려운 부분이 있다. 학습시에도 일치하지 않았다면 test 시에도 일치하지 않을 것이다.

Identifiability: Are We Using the Right Distillation Dataset?

Should we do more data augmentation?

Augmentation은 간단하면서도 많이 쓰이는 방법이다. 만약 identifiability이 주요한 원인이라면 augmentation을 통해 data를 더 늘리면 fidelity가 올라갈 것이다. ResNet-56 - CIFAR-100 조합으로 기본 augmentation을 적용한 것과 다른 augmentation 추가한 실험을 하였다.

결과적으로 가장 좋은 generalization을 보였던 MixUp, GAN을 적용했던 실험들이 가장 좋은 fidelity를 보이는 것은 아니었다. 사실 temprature를 4로 설정한 것에서 agreement가 굉장히 많이 올랐는데, 모든 augmentation에서\tau=4를 사용한 것에 근접하게 KL이 올랐음을 알 수 있다.

하지만 out-of-distribution augmentation을 적용한 경우는 성능이 매우 떨어졌다.

한편 teacher와 그와 관련 없는 또 다른 teacher에서 distillation된 student 사이에서도 어느정도 좋은 agreement를 보여주는 결과를 보면 student의 fidelity가 정말 엉망인 것을 알 수 있다.

정리해보면 distillation dataset을 늘리는 것은 fidelity를 올리는데 도움이 되지만 tempering을 조정한 것보다는 효과가 적으며, 불충분한 label이 높은 fidelity를 내는데 방해가 될 가능성은 적다고 할 수 있다.

The data recycling hypothesis

앞에서는 label의 수만 살펴봤지만, 좋은 label인지는 살펴보지 않았다. Data augmentation을 적용하면 teacher가 훈련할 때의 distribution과 distillation에서 student가 보는 distribution 간의 shift가 생긴다. 만약 이러한 shift가 없다면 어떻게 될까?

CIFAR-100 데이터를 2개로D_0, D_1​ 두개로 나누어 treacher인 ResNet-56은D_0​으로 여러개 앙상블 모델을 만들었다. 그리고D_0으로 distillation한s_0,D_1로 distillation한s_1, 둘다 사용한s_{0 \cup 1}을 서로 비교하였다.

s_0은 높은 test accuracy를 보였지만, 낮은 ECE와 fidelity를 보였다. 또한s_1s_0보다는 높은 fidelity를 보이긴 했지만, teacher의 accuracy를 따라가진 못했다. 가장 좋은 결과는s_{0 \cup 1}이었다.

그 이유는 teacher의 좋은 fidelity가 student의 generalization을 좋게 만드는 것은 아니며, 역시 distillation data를 수정하더라 약간의 향상이 있을 뿐 나쁜 fidelity를 만드는 원인은 아니다 라고 이야기할 수 있다.

Optimization: Does the Student Match the Teacher on the Distillation Data?

지금까지는 student fidelity를 test set에서 측정하였는데, distillation에 사용한 data를 사용하면 어떻게 되는지 보자.

More distillation data lowers train agreement

이전에 ResNet-56 에서 GAN을 사용한 CIFAR-100 데이터 실험과 동일한 세팅에서 distillation dataset에서 agreement를 측정하였다.

결과는 이전과 반대의 성향을 나타냈는데, distillation dataset이 늘어날 수록 student가 teacher를 따라가는 것이 힘들어보였다. Data augmentation을 적용한 경우에는 agreement 저하 현상이 더 심해진다.

우리는 knowledge distillation을 사용할 때에, student가 distillation set을 사용했을 때의 teacher에 맞추기를 기대하는 작업이라고 생각하지만, 실험 결과로 보았을 때에는 optimization method가 심지어 distillation dataset에서도 높은 fidelity를 얻는 것이 불가능해 보인다.

또한 student가 teacher와 일치하기 위해서 많은 teacher의 label을 필요로하지만, techer가 학습할 때 사용하지 않은 data는 distillation을 어렵게 만든다.

Why is train agreement so low?

Simplified distillation experiment : ResNet-20으로 모델을 줄이고 BatchNomr을 LayerNorm으로 변경하였다. 이는 teacher와 student가 동일한 weight를 가지더라도 학습 상태에 따라 다른 activation을 만들어 낸다.

Can we solve the optimization problem better? : SGD 대신 Adam을 사용하였지만 오히려 조금 fidelity가 떨어졌다. 학습 epoch을 늘리는 경우 fidelity가 늘어났지만 그 변화가 너무 미미하여 현실성이 없다.

The distillation loss surface hypothesis : student를 초기화할 때 random initialization을 사용하는데, teacher의 학습된 weight와의 weight sum\theta_s = \lambda \theta_t + (1-\lambda) \theta_r으로 초기화하는 경우를 생각해보았다.

\lambda=0.375를 사용하였을 때 급격하게 agreement가 높아졌으며 점점\lambda를 늘릴 수록 agreement가 증가하는 것을 알 수 있었다. 만약 초기화값이 teacher보다 멀리 떨어져 있으면 \lambda \in { 0, 0.25 }) distillation loss는 sub optimal에 수렴하지만 teacher에 가까워질 수록 같은 basin에 수렴하는 것을 알 수 있었다.


Add a Comment Trackback