딥러닝 모델은 본질적으로 미분 가능한 함수들의 조합입니다.
그렇기 때문에 신경망을 학습시킬 때는 대부분의 연산이 미분 가능해야 합니다.
하지만, 종종 이산적인 선택(discrete decision)이 필요한 상황이 필요하게 됩니다.
예를 들어
a) 특정 클래스를 선택할 때
b) 텍스트 생성 모델에서 다음 토큰을 선택할 때
c) 강화학습에서 이산적인 행동(action)을 선택할 때
d) 그래프 신경망에서 노드를 선택할 때
가 있습니다.
이러한 이산 적인 선택은 argmax 나 sampling 등으로 구현이 되며, 미분이 불가능합니다. 이로 인해 역전파(Back propagation)가 끊기고, 모델 학습이 불가능해집니다.
이 문제를 해결하기 위해 제안된 방법이 바로 오늘 포스팅의 주제인
Gumbel Softmax
입니다.
오늘 포스팅에서는 Gumbel Softmax의 개념과, 수식, Pytorch를 활용한 시각화를 통해 다각적으로 이해해보도록 하겠습니다.
1. Gumbel Softmax란?
1) Gumbel Softmax가 필요한 상황
예를 들어 아래와 같이 확률 분포에서 클래스를 하나 샘플링 한다고 생각해보겠습니다.
이때 sample은 단 하나의 클래스 인덱스를 반환합니다. 예를 들어 2라면, 이를 one-hot 벡터로 바꾸어 [0,0,1]이 됩니다.
이때 문제는 이 과정이 연속적이지 않다는 것입니다.
입력 확률이 조금만 바뀌어도 출력은 확 바뀌기기에, 기울기(Gradient)를 계산할 수 없습니다.
기존에는 이를 해결하기 위해
REINFORCE와 같은 policy gradient 방식을 사용하지만,
* gradient의 variance가 크고
* 학습이 느리고 불안정
하다는 문제가 있습니다.
그래서 나온 대안이 바로 Gumbel-softmax 입니다.
이는 이산 샘플링을 연속적인 softmax 형태로 근사하여 미분 가능성을 확보해주는 방법입니다.
2) Gumbel-Max Trick + Softmax 근사
a) Gumbel-Max Trick
- 카테고리 확률 분포인 phi로 부터 샘플링 하고 싶다고 해보겠습니다.
- 이때 Gumbel-Max Trick은 다음을 이용합니다.
- 여기서 gi 는 Gumbel(0,1)분포에 뽑은 noise 입니다.
* Gumbel 분포 샘플링 방법의 수식은 아래와 같으며
즉, 샘플링을 argmax 형태로 바꿔주는 트릭입니다.
b) Softmax 근사: argmax => softmax
- argmax는 여전히 미분이 불가능하므로, softmax로 근사하게 됩니다.
- 이렇게 하게 되면 입력인 phi에 대해 출력인 y는 미분이 가능해지는 것입니다.
Gumbel-Max Trick을 활용한 샘플링을 충분히 활용시 원본의 확률과 같아지는 것은 아래 파이썬 코드를 통해서 확인해보실 수 있습니다.
여기서는 Gumbel Maxtrick의 결과와 Multinomial의 결과가 충분한 iteration을 거쳤을때 원본의 확률과 거의 동일하게 나옴을 알 수 있습니다.
여기서 한 가지 의문이 들 수 있습니다.
신경망이 한 번의 epoch을 거칠 때마다 이렇게 충분히 많은 수의 샘플링이 이루어지지 않는데 그러면 샘플링에 따라 확률이 왜곡되는것이 아닌가?라고 말입니다.
- 실제 신경망 학습에서는 매 iteration 마다 batch 단위로 데이터가 처리됩니다. 이때 Gumbel-Softmax 샘플링은 배치 내의 각 데이터 포인트에 대해 한 번씩만 수행이 됩니다.
신경망 학습 과정에서의 샘플링은 다음과 같이 이루어 집니다.
a) 입력 데이터 배치 처리: 한 번에 N개의 데이터(N은 배치 크기)가 모델이 입력
b) 로짓 계산: 모델은 각 데이터 포인트에 대해 k개의 클래스에 대한 로짓을 출력. 즉 출력은 (N,k) 형태의 텐서
c) Gumbel Softmax 적용: 이 로짓 텐서에 대해 검벨 소프트 맥스 함수가 적용되며, 각 데이터 포인트(N개)마다 별 도의 검벨 노이즈가 추가되고 소프트 맥스 연산이 수행됨. 결과는 여전히 (N,k) 형태의 소프트한 확률 분포 텐서가 됨
d) 역전파: 이 소프트한 결과는 손실 계산에 사용되며, 기울기가 역전파 되어 모델의 파라미터 업데이트 됨
이때 대량의 샘플링 없이 적은 수의 샘플만으로도 모델이 잘 학습되는 이유는 다음과 같습니다.
a) 미분 가능한 근사: 검벨-소프트맥스의 핵심은 이산적인 선택을 미분 가능한 연속적인 값으로 근사한다는 점입니다. 이 연속적인 근사를 통해 매 스텝마다 비록 하나의 '샘플'에 해당하는 소프트 벡터를 얻지만, 이 벡터는 그 자체로 확률 분포에 대한 기대값 정보를 담고 있으며, 기울기가 끊기지 않고 흘러갈 수 있게 해줍니다.
b) 배치 처리의 이점: 딥러닝 학습은 여러 에폭에 걸쳐 수많은 배치를 처리합니다. 각 배치마다 새로운 랜덤 노이즈와 샘플링이 이루어지고, 이 과정이 누적되면서 전체적인 분포 정보가 모델 파라미터에 반영됩니다. 수많은 작은 '소프트'한 업데이트들이 모여 최종적으로는 원하는 분포를 학습하게 되는 거죠.
c) 온도 어닐링(Temperature Annealing): 앞서 설명드린 온도 어닐링(temperature annealing) 전략이 중요한 역할을 합니다. 학습 초기에는 온도가 높아 출력이 더 부드러워지므로, 모델이 다양한 가능성을 탐색하고 학습이 안정적으로 시작할 수 있습니다. 학습이 진행되면서 온도를 낮추면 출력은 점차 '하드'한 이산적인 선택에 가까워지지만, 이미 모델은 이전 학습을 통해 올바른 방향으로 수렴하고 있는 상태이기 때문에 매 스텝마다 대량의 샘플링이 필요하지 않습니다.
2. Pytorch를 통한 Gumbel Softmax 구현
다음은 Pytorch를 통해서 Gumbel Softmax를 구현해보겠습니다.
위와 같이 1,2,3 이라는 discrete 한 텐서도 gumbel softmax를 통해 미분이 가능해짐을 알 수 있습니다.
Temperature에 따른 Gumbel Softmax의 확률에 미치는 영향은 아래와 같습니다.
Temperature가 높아질수록 부드러워지고 낮을수록 날카로워지는 모습을 보입니다.
'딥러닝 with Python' 카테고리의 다른 글
[딥러닝 with Python] TGNN이란?(Temporal Graph Neural Network) (0) | 2025.06.29 |
---|---|
[딥러닝 with Python] Fourier Transform 비교: FFT, STFT, RFFT (0) | 2025.06.28 |
[딥러닝 with Python] 다변량 시계열 이상탐지(Multivariate Timeseries Anomaly Detection)벤치마크 데이터셋 정리 (0) | 2025.06.27 |
[딥러닝 with Python] 시계열 이상탐지 평가지표(TSAD Evaluation Metrics) 정리 (0) | 2025.06.26 |
[딥러닝 with Python] DYNOTEARS란? (Dynamic NOTEARS) (0) | 2025.06.23 |
댓글