이번에 알아볼 것은 KL Divergence Loss 입니다.
1. KL Divergence란?
- KL Divergence는 Kullback-Leibler Divergence의 줄임말로, 두 확률 분포간의 차이를 측정하는 비대칭적인 척도를 의미합니다.
* 여기서 비대칭성이란, 두 확률분포 P와 Q의 순서를 바꾸어서 KL Divergence를 계산하면 다른 결과가 나올 수 있다는 것을 의미합니다.
- 이는 주로 모델이 예측한 분포와 실제 분포간의 차이를 측정하는데 사용됩니다.
- KL Divergence 계산식은 다음과 같습니다
1) 두 이산형 확률 분포 P와 Q의 KL Divergence
2) 두 연속형 확률 분포의 P와 Q의 KL Divergence
- KL Divergence의 특징
1) 비대칭성 : 두 확률 분포 P와 Q의 순서를 바꾸어서 계산하면 결과가 달라짐
2) 비음수성 : 항상 0 이상의 값을 가지며, 두 분포가 완전히 동일할때 0의 값을 가짐
3) 정보해석 : KL Divergence는 Q를 사용해 P를 근사할 때 발생하는 추가적인 정보량을 의미함. 즉, Q가 P에 얼마나 비효율적인 지를 나타냄
2. Python을 활용해 서로 다른 두 분포의 모습과 KL Divergence Loss 구하기
이번에는 서로 다른 두 분포 P와 Q를 활용해서 각각을 시각화하고 상호간의 KL Divergence Loss를 구해보겠습니다.
비교할 대상은
1) p_normal 과 q_normal (둘 다 normal distribution)
2) p_beta 와 q_beta (둘 다 beta distribution)
입니다.
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm, beta, entropy
# Define the distributions
x = np.linspace(-5, 5, 1000)
p_normal = norm.pdf(x, 0, 1) # Standard normal distribution
q_normal = norm.pdf(x, 1, 1.5) # Normal distribution with mean 1 and std 1.5
x_beta = np.linspace(0, 1, 1000)
p_beta = beta.pdf(x_beta, 2, 5) # Beta distribution with alpha=2, beta=5
q_beta = beta.pdf(x_beta, 5, 2) # Beta distribution with alpha=5, beta=2
# Calculate KL Divergence
kl_normal = entropy(p_normal, q_normal)
kl_beta = entropy(p_beta, q_beta)
# Plot the distributions and KL Divergence
fig, axs = plt.subplots(2, 1, figsize=(12, 10))
# Plot Normal distributions
axs[0].plot(x, p_normal, label='P: N(0, 1)')
axs[0].plot(x, q_normal, label='Q: N(1, 1.5)')
axs[0].fill_between(x, p_normal, q_normal, color='gray', alpha=0.3)
axs[0].set_title(f'Normal Distributions\nKL Divergence = {kl_normal:.4f}')
axs[0].legend()
# Plot Beta distributions
axs[1].plot(x_beta, p_beta, label='P: Beta(2, 5)')
axs[1].plot(x_beta, q_beta, label='Q: Beta(5, 2)')
axs[1].fill_between(x_beta, p_beta, q_beta, color='gray', alpha=0.3)
axs[1].set_title(f'Beta Distributions\nKL Divergence = {kl_beta:.4f}')
axs[1].legend()
# Show the plots
plt.tight_layout()
plt.show()
kl_normal, kl_beta
위 첫번째 그림은 두 normal distribution의 시각화와 KL Divergence Loss를 구한것이고
아래 두번째 그림은 두 beta distribution의 시각화와 KL Divegence Loss를 구한 것입니다.
이때, 회색 부분은 두 분포간의 차이를 의미합니다.
위 결과를 해석해보면
Normal distirubtion P에 대한 Q의 KL Divergence는 0.3461이고
Beta distribution P에 대한 Q의 KL Divergence는 3.2499 입니다.
KL Divergence가 0에 가까울수록 서로 더 비슷한 분포임을 의미하고, 이를 고려할때 normal distirbution끼리의 모습이 더 유사함을 정량적으로도 확인할 수 있습니다.
3차원으로도 이를 확인해볼 수 있습니다.
P와 Q 라는 3차원 정규 분포를 활용해보겠습니다.
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from scipy.linalg import det, inv
# Define the distributions
mean1 = np.array([0, 0, 0])
cov1 = np.diag([1, 1, 1])
mean2 = np.array([1, 1, 1])
cov2 = np.diag([1.5, 1.5, 1.5])
# Define the grid
x, y, z = np.mgrid[-3:3:.5, -3:3:.5, -3:3:.5]
pos = np.empty(x.shape + (3,))
pos[:, :, :, 0] = x
pos[:, :, :, 1] = y
pos[:, :, :, 2] = z
# Calculate the probability densities
p = multivariate_normal(mean1, cov1).pdf(pos)
q = multivariate_normal(mean2, cov2).pdf(pos)
# Calculate KL Divergence
k = 3
kl_div = 0.5 * (np.log(det(cov2) / det(cov1)) - k + np.trace(inv(cov2).dot(cov1)) + (mean2 - mean1).T.dot(inv(cov2)).dot(mean2 - mean1))
# Plot the distributions
fig = plt.figure(figsize=(12, 6))
# Plot P
ax1 = fig.add_subplot(121, projection='3d')
ax1.contour3D(x[:, :, 0], y[:, :, 0], p[:, :, 0], 50)
ax1.set_title('P: N([0, 0, 0], diag([1, 1, 1]))')
# Plot Q
ax2 = fig.add_subplot(122, projection='3d')
ax2.contour3D(x[:, :, 0], y[:, :, 0], q[:, :, 0], 50)
ax2.set_title('Q: N([1, 1, 1], diag([1.5, 1.5, 1.5]))')
plt.suptitle(f'KL Divergence = {kl_div:.4f}')
plt.show()
kl_div
위 그림에서 만든것처럼 서로 다른 두 3차원 정규분포에 대한 KL Divergence (P given Q)는 1.1082로 계산됩니다.
서로 유사하지만 다름을 시각적으로, 그리고 KL Divergence를 통해 수치적으로 알 수 있습니다.
댓글