본문 바로가기
딥러닝 with Python

Mean Teacher란?

by CodeCrafter 2024. 1. 28.
반응형

 

이번에는 Mean Teacher에 대해서 알아보도록 하겠습니다. 이는 지식증류(Knowledge Distillation)를 기반으로 한 semi supervised learning의 방법 중 하나입니다. 즉, 지식증류 방법을 semi supervsied learning에 사용할때 사용한 방법 중 하나로 생각하면 되는데요.

 

1. Mean Teacher란?

- "Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results"(A Tarvainen ,2017) 라는 논문에서 제시한 방법으로, 

 

 semi supervised learning 의 성능을 향상시키기 위한 temporal ensembling 방법에서 한 단계 더 나아간 방법이 되겠습니다.

 

- Mean Teacher의 핵심 아이디어는, 학생 모델(Student model)이 예측을 할 때마다, 교사 모델(Teacher model)의 예측과 일치하도록 모델을 강제하는 것을 말합니다. 이렇게 하여 student model이 teacher 모델과 같은 일반화된 패턴을 학습할 수 있도록 유도하는 것입니다. 

 * 이는 레이블이 없는 데이터에 대해서도 teacher model의 예측을 신뢰할 수 있는 soft target으로 사용하여, student model이 이를 통해 학습할 수 있게 하는 방법을 말합니다. 

 

- 여기서 mean 이라는 용어는, teacher model이 student modle의 가중치의 시간에 따른 이동 평균(moving average)를 추적한다는 것을 의미합니다. 

 * 즉, student model의 가중치가 각 학습 단계마다 업데이트될 때, teacher model의 가중치는 이러한 업데이트의 평균을 반영하게 되는 것을 말합니다. 

 

- 이를 수식으로 표현하면 다음과 같습니다.

 

 

(이때, alpha는 0과 1사이의 값으로 smoothing coefficient라고 불립니다)

 

즉, 시간 t에서 teacher model의 가중치는, 전 time step(t-1)에서의 teacher model의 가중치와 시간 t에서의 student model의 가중치의 exponential moving average를 의미합니다.

 

 

2. Mean Teacher 방법론

- Mean teacher 방법론을 아래 그림을 통해서 설명하겠습니다.

 

1) 입력 이미지 및 레이블 : "3"이라는입력 이미지는 label이 주어지고, 이는 지도학습에 사용됩니다.

2) student model의 예측 : student model은 변형된 입력 이미지를 받아 예측을 수행합니다. 이 예측은 model의 최종 출력 계층에서의 확률 분포로 나타나며, 여기서 가장 높은 확률을 가진 클래스가 예측값으로 선택됩니다.

3) classification cost : student model의 예측과 실제 레이블 간의 차이를 계산하는 비용함수를 의미합니다. 일반적으로 Cross Entropy Loss function을 주로 사용합니다.

4) teacher model의 예측 : 동시에 teacher model은 입력 이미지의 변형된 형태(student model과는 다른 변형방법을 적용)를 예측을 수행합니다. teacher model의 가중치는 앞서 설명했던 exponential moving average를 따라서 업데이트 됩니다.

 * 이때, teacher model이 이미지의 변형된 형태를 받는 이유는, 데이터의 다양성을 증가시키고 모델이 더 강건한(Robust) 특성을 학습하도록 하기 위함으로, 일반적인 Data Augmentation이라고 볼 수 있습니다. 

 * 이는, 모델의 일반화 능력을 향상 시키고, 과적합을 방지시키며, teacher 와 student라는 다른 두 모델의 일관적인 예측능력을 강화시키고, 모델의 강건성(Robustness)를 향상시켜줍니다. 이를 통해, 레이블이 없는 데이터를 활용하는 방법을 학습하는 것입니다.

5) Consistency Cost : Student model과 Teacher model의 예측 간의 일관성을 측정하기 위한 비용함수를 의미합니다. 논문에서는 아래와 같은 함수를 활용했습니다. 

* 즉, student model과 teacher model이 원본 데이터로부터 각기 다른 형태의 변형을 거친 데이터를 학습했을 때, 그 예측이 얼마나 서로 가까운지를 euclidean norm의 형태로 측정하는 것을 consistency로 본 것입니다.

6) 가중치의 평균 이동 : 앞서 이야기한 exponential moving average 함수를 활용해서 가중치를 평균 이동 시키는 것입니다. 

 

 

** 기존의 Semi supervised learning 방법론

- ⨿ model과 Temporal Ensembling

 

⨿ model은  간단한 도표로 표현하면 아래와 같습니다. 

*이는 input image에 stochasits augmentation을 적용하고, augmented image에 drop out을 적용한 모델을 통과하게 합니다. 이후 출력된 두 개의 서로 다른 output을 바탕으로 squared differenc와 하나의 output과 그 레이블을 활용해 cross etnropy loss를 계산합니다. 이후 이 두 계산값을 weighted summation하여 최종 loss를 도출하는 방법입니다.

* 이는, 하나의 network를 기반해 동작하기 때문에 노이즈가 심하다는 단점이 존재합니다.

 

Temporal Ensembling을 간단히 도표로 표현하면 다음과 같습니다.

* Input image에 stochastic augmentation과 dropout을 적용한 네트워크를 통과시켜 zi를 출력하고, 이렇게 출력된 zi와 hard label(yi)간의 cross entropy를 계산합니다. 또한, zi 와 zi~ 의 squared differnece를 계산하는데, zi~는 다음 Pseudo code를 통해서 확인할 수 있습니다.

 

* 그러나 이러한 Temporal Ensembling도, 각 target은 epoch당 한번 update 되므로 학습이 느리다는 단점이 존재합니다.

 

 

위 두 방법의 단점을 보완한 것이 mean teacher이며, 위 두 방법과 비교했을 때 가장 큰 특징이 바로 student 와 teacher를 활용한다는 것 / 즉 2가지의 서로 다른 네트워크를 활용한다는 것입니다. 

 

특히, Temporal Ensembling과 비교했을 때, Temporal Ensembling은 모든 mini batch에 대한 학습 epoch가 종료되어야지 1번 ensemling prediction이 업데이트 되지만, Mean teacher는 각 mini batch에 대한 학습이 종료될때마다 모델이 업데이트되므로 보다 빈번한 업데이트로 인해 더 안정적이고 일관된 예측을 제공하게 됩니다. 

반응형

댓글