본문 바로가기
딥러닝 with Python

[딥러닝 with Python] 비전 트랜스포머(Vision Transformer / ViT) (2/2)

by CodeCrafter 2024. 11. 8.
반응형

 

 

지난 포스팅에서 최초 제시된 ViT에 대해서만 알아보았다면, 이번에는 ViT의 활용 가능성에 대해서 알아보겠습니다.

[딥러닝 with Python] 비전 트랜스포머(Vision Transformer / ViT) (1/2)

 

 

1. ViT vs ResNets

- ViT  이전에는 이미지 관련 Task에서 기본 Backbone 네트워크로 주로 ResNets을 활용했었는데요. 

 

- 아래 그림처럼 데이터 셋의 크기가 3억장이 넘는 데이터로 학습이 되어야지 비로서 ResNets의 성능을 이길 수 있게 되었습니다. ViT 모델의 크기도 가장 큰 버전인 Huge를 활용해서야 말이죠

 

- 그래서 이러한 ViT를 잘 활용해보기 위해 ImageNet-1K 데이터만을 가지고 Regularization과 Data Augmentation을 활용해 모델의 성능이 어떻게 변하는지 확인해보았습니다.

 

해당 실험은 "How to train your ViT? Data, Augmentation and Regularization in Vision Transformers
" 라는 논문에 나온 내용들입니다.

 

-실험은 300 epoch만 진행하였고, 파란색으로 표시된 것들은 ViT를 의미하며, RTi는 ResNet으로 먼저 Feature map을 추출한 뒤 이를 패치로 분할하여 활용하는 Hybrid 구조이고, R26 또는 R50은 ResNet에서 26개의 블록 또는 50개의 블록을 활용한 것이 되겠습니다. 

 * 이때 Regularization은 Weight Decay, Stochastic Depth, 그리고 MLP에서 Dropout을, Data augmentation은 Mixup과 RandAugment입니다.

 

- 위 결과를 보시면, 전반적으로 Augmentation을 많이 가할 수록 성능이 향상되는 모습을 보이지만 Hybrid는 그렇지 않다는 것을 확인할 수 있습니다. 또한 가장 성능이 좋은 것은 ViT 모델임을 알 수 있습니다. 

 

 

2. Distillaion 활용 및 Hierarchical ViT

1) Distillation

 

- 이번에는 Knowledge Distillation을 활용한 방법입니다.

 

- 해당 방법은 "Training data-efficient image transformers & distillation through attention"이라는 논문에서 나온 DeiT(Data efficient Transformer)의 방식입니다.

 

- 아래와 같이 기존 토큰에 Distillation token을 추가하여 CNN 계열의 Teacher로부터 Knowledge Distillation을 하는 것입니다.

 

 

- 이 결과는 아래와 같이 기존 ViT보다 더 좋은 결과물을 보이고 있음을 알 수 있습니다. 특히 여기에서는 더 많은 epoch과 더 높은 resolution을 가진 데이터를 활용하면 결과가 좋아짐을 알 수 있습니다.

 

 

 

2) Hierarchical ViT

 

- 일반적으로 CNN을 층을 거듭할수록 feature map의 사이즈는 작아지고 channel수가 많아지며, feature map의 크기라는 관점에서 보았을때 Hierarchical한 구조라고 할 수 있습니다.

 

- 이러한 방식이 CNN에서 효과가 있었기에 ViT에도 이를 적용한 것 중 유명한 Backbone 네트워크 중 하나가 Swin Transformer입니다.

 

- 위와 같이 일반적인 ViT와는 달리 Hierarchical하게 구조를 만들기 위해 패치의 사이즈를 더 키웠으며, 아래와 같이 4개의 Block을 활용해 Hierarchical 한 구조를 만들었습니다.

-  이때, 토큰끼리 Interact 하지 못한다는 점을 고려하여 Shifted Window 방식으로 Normal Windows와 Shifted Windows를 번갈아 적용한 결과를 활용하고 있습니다. 즉, 각 블록에서 Normal Windows가 적용된에 Layer Normalization과 MLP 후 Shifted Windows가 적용된 결과를 만들어 내서 다음 블록을 보내는 것입니다.

 

 

 

반응형

댓글