[본 포스팅은 만들면서 배우는 생성 AI 2탄 을 참조했습니다]
DC GAN은 GAN모델을 Convolution 신경망을 활용해 학습 능력과 그 성능을 향상시켰지만, 학습을 시키는것이 매우 어렵다는 단점이 있었습니다.
이를 해결하기 위해 제안된 것인 와서스테인 GAN 입니다.
와서스테인 GAN은 안정적인 GAN 훈련을 위해 와서스테인 손실함수를 제안하는 GAN 모델입니다.
이 와서스테인 손실함수를 사용하면 기존에 사용하던 손실함수인 이진 크로스 엔트로피 손실보다 GAN 모델의 학습결과가 더 안정적으로 수렴할 수 있다고 합니다.
- 기존의 이진 크로스 엔트로피 손실(Binaray Cross Entropy Loss)는 아래와 같습니다.
* y는 실제 레이블 (0 또는 1)
* y_hat은 모델의 예측 값 (0과 1 사이의 확률)
- 기존의 GAN의 Discriminator의 손실함수는 아래와 같습니다.
* D(x)는 진짜 데이터 x에 대한 판별자의 예측 값
* G(z)는 잠재 공간 z에서 생성된 가짜 데이터
* p_data는 진짜 데이터의 분포
* p_z는 잠재 공간의 분포
* E : 기대값(Expectation)
- 생성자의 손실 함수는 아래와 같습니다.
- 이를 와서스테인 손실 함수와 비교해보겠습니다.
- 와서스테인 손실(Wasserstein loss)은 1과 0 대신, 1과 -1을 사용합니다.
- 또한, 판별자의 마지막 층에서 시그모이드 활성호 함수를 제거하여 예측 확률이 [0,1] 범위에 제한되지 않고 무한대 범위의 어떤 숫자도 될 수 있도록 만들어 줍니다
* 이 때문에 WGAN의 판별자는 보통 비평자(crictic)이라 부르며 확률 점수를 반환해줍니다.
- 와서스테인 손실 함수의 식은 아래와 같습니다
* 와서스테인 손실 함수
* 이진 크로스엔트로피 손실 함수
* 위 두 손실함수를 비교해보면, 와서스테인 손실함수에서는 확률에 대해서 로그를 취하지 않은 형태입니다.
* 로그를 취하지 않았기에 실제 타겟을 1 또는 0이 아닌 / 1 또는 -1로 설정할 수 있습니다.
-WGAN 비평자의 손실 함수 최소화 식은 아래와 같습니다.
* 이는, 진짜 이미지와 생성된 이미지에 대한 예측 사이의 차이를 최대화하는 것입니다.
- WGAN 생성자의 손실 함수 최소화 식은 아래와 같습니다.
* 이는 비평자로부터 가능한 한 높은 점수를 받는 이미지를 생성하려고 하는 것입니다.
- 다음은 립시츠 제약에 대해서 알아보겠습니다.
- WGAN의 비평자는 시그모이드 함수를 활성화 함수 부분에서 제거했기때문에, 와서스테인 손실 값은 아주 큰 값이 될 수도 있습니다. 이는 학습에 부정적인 영향을 줄 수 있는데요. 이를 보완하기 위해 립시츠 제약을 사용합니다.
- 다음은, WGAN에서 사용되는 1-Lipschitz 제약입니다.
* ㅣx1-x2ㅣ는 두 이미지 픽셀의 평균적인 절대값 차이를 의미합니다.
* ㅣD(x1) - D(x2)ㅣ는 비평자 예측 간의 절댓값 차이를 의미합니다.
* 기본적으로 두 이미지 사이에서 비평자의 예측이 변화하는 비율을 제한할 필요가 있습니다. (즉, 기울기의 절댓값이 어디에서나 최대 1이어야 합니다)
* 즉, 립시츠 연속 함수는 두 점 사이의 거리를 일정 비 이상으로 증가시키지 않는 함수를 말합니다. 이를 그림으로 표현해보면 아래와 같습니다.
* 즉, 어느때에도 해당 제약 값(위 식에서 우변의 값) 의 + 또는 - 기울기안에 있도록 해당 값을 제한하는 것입니다.
(이를 적용하면 해당 선은 어느 지점에서나 상승하거나 하강하는 비율이 한정(하얀색이 아닌 부분만 가능)되는 것입니다.
-WGAN 모델에서는 가중치를 -0.01과 0.01 사이로 제한하도록 합니다. 즉, 해당 범위보다 더 큰 기울기에 해당하는 연속하는 값이 나올수가 없게 제한하는 겁니다.
* 하지만, 이 방식에 대한 비판 중 하나는 가중치에 제한을 두었기 때문에 학습 속도가 크게 감소한다는 것입니다. 강한 비평자는 WGAN의 성공의 중심이기 때문입니다. 이를 보완하기 위해 와서스테인 GAN - Gradient Penalty 입니다.
- WGAN - Gradient Penalty는 그레이디언트 Norm이 1에서 벗어날 경우 모델에 불이익을 주는 페널티 항을 비판자의 손실 함수에 포함시켜 립시츠 제약 조건을 직접 강제하는 방법을 보여줍니다.
* 여기서 Gradient Penalty 손실은 아래와 같습니다.
이를 활용한 WGAN-GP의 전체 손실 함수 식은 아래와 같습니다.
이 중 lamda는 Gradient Penalty의 강도를 조절하는 하이퍼 파라미터를 의미합니다.
- 아래는 WGAN-GP의 비평자의 훈련과정을 나타냅니다.
이를 활용해서 WGAN-GP 모델을 파이썬으로 구현해보겠습니다.
먼저 깃허브 저장소에서 util 파일을 다운로드 받아줍니다.
import sys
# 코랩의 경우 깃허브 저장소로부터 utils.py를 다운로드 합니다.
if 'google.colab' in sys.modules:
!wget https://raw.githubusercontent.com/rickiepark/Generative_Deep_Learning_2nd_Edition/main/notebooks/utils.py
!mkdir -p notebooks
!mv utils.py notebooks
다음은 학습간 활용할 라이브러리를 로드해줍니다.
import numpy as np
import tensorflow as tf
from tensorflow.keras import (
layers,
models,
callbacks,
utils,
metrics,
optimizers,
)
from notebooks.utils import display, sample_batch
다음은 학습 간 활용할 하이퍼 파라미터를 정해줍니다.
IMAGE_SIZE = 64
CHANNELS = 3
BATCH_SIZE = 512
NUM_FEATURES = 64
Z_DIM = 128
LEARNING_RATE = 0.0002
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.999
EPOCHS = 200
CRITIC_STEPS = 3
GP_WEIGHT = 10.0
LOAD_MODEL = False
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9
이제 데이터를 준비해줍니다. 데이터는 CelebA 입니다.
# 코랩일 경우 노트북에서 celeba 데이터셋을 받습니다.
if 'google.colab' in sys.modules:
# # 캐글-->Setttings-->API-->Create New Token에서
# # kaggle.json 파일을 만들어 코랩에 업로드하세요.
# from google.colab import files
# files.upload()
# !mkdir ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# # celeba 데이터셋을 다운로드하고 압축을 해제합니다.
# !kaggle datasets download -d jessicali9530/celeba-dataset
# 캐글에서 다운로드가 안 될 경우 역자의 드라이브에서 다운로드할 수 있습니다.
import gdown
gdown.download(id='15gJhiDBkltMQz3T97xG-fO4gXTKAWkSB')
!unzip -q celeba-dataset.zip
# output 디렉토리를 만듭니다.
!mkdir output
# 데이터 로드
train_data = utils.image_dataset_from_directory(
"./img_align_celeba/img_align_celeba",
labels=None,
color_mode="rgb",
image_size=(IMAGE_SIZE, IMAGE_SIZE),
batch_size=BATCH_SIZE,
shuffle=True,
seed=42,
interpolation="bilinear",
)
데이터를 학습하기 좋게 전처리해줍니다.
# 데이터 전처리
def preprocess(img):
"""
이미지 정규화
"""
img = (tf.cast(img, "float32") - 127.5) / 127.5
return img
train = train_data.map(lambda x: preprocess(x))
Train 세트에 있는 데이터 몇개를 출력해서 데이터의 모습을 확인해봅니다.
# 훈련 세트에 있는 몇 개의 샘플 출력
train_sample = sample_batch(train)
display(train_sample, cmap=None)
이제 WGAN-GP 모델을 구축해봅니다.
먼저 생성자를 구축해봅니다. 생성자의 결과는 비평자에 들어가기에 crictic_input으로 표현해줍니다.
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(critic_input)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(256, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(512, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)
critic = models.Model(critic_input, critic_output)
critic.summary()
이제 생성자를 정의해줍니다.
generator_input = layers.Input(shape=(Z_DIM,))
x = layers.Reshape((1, 1, Z_DIM))(generator_input)
x = layers.Conv2DTranspose(
512, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
256, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
64, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
generator_output = layers.Conv2DTranspose(
CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model(generator_input, generator_output)
generator.summary()
위에서 정의한 것들을 바탕으로 WGAN-GP 클래스를 정의해줍니다.
class WGANGP(models.Model):
def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):
super(WGANGP, self).__init__()
self.critic = critic
self.generator = generator
self.latent_dim = latent_dim
self.critic_steps = critic_steps
self.gp_weight = gp_weight
def compile(self, c_optimizer, g_optimizer):
super(WGANGP, self).compile()
self.c_optimizer = c_optimizer
self.g_optimizer = g_optimizer
self.c_wass_loss_metric = metrics.Mean(name="c_wass_loss")
self.c_gp_metric = metrics.Mean(name="c_gp")
self.c_loss_metric = metrics.Mean(name="c_loss")
self.g_loss_metric = metrics.Mean(name="g_loss")
@property
def metrics(self):
return [
self.c_loss_metric,
self.c_wass_loss_metric,
self.c_gp_metric,
self.g_loss_metric,
]
def gradient_penalty(self, batch_size, real_images, fake_images):
alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
diff = fake_images - real_images
interpolated = real_images + alpha * diff
with tf.GradientTape() as gp_tape:
gp_tape.watch(interpolated)
pred = self.critic(interpolated, training=True)
grads = gp_tape.gradient(pred, [interpolated])[0]
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
gp = tf.reduce_mean((norm - 1.0) ** 2)
return gp
def train_step(self, real_images):
batch_size = tf.shape(real_images)[0]
for i in range(self.critic_steps):
random_latent_vectors = tf.random.normal(
shape=(batch_size, self.latent_dim)
)
with tf.GradientTape() as tape:
fake_images = self.generator(
random_latent_vectors, training=True
)
fake_predictions = self.critic(fake_images, training=True)
real_predictions = self.critic(real_images, training=True)
c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(
real_predictions
)
c_gp = self.gradient_penalty(
batch_size, real_images, fake_images
)
c_loss = c_wass_loss + c_gp * self.gp_weight
c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
self.c_optimizer.apply_gradients(
zip(c_gradient, self.critic.trainable_variables)
)
random_latent_vectors = tf.random.normal(
shape=(batch_size, self.latent_dim)
)
with tf.GradientTape() as tape:
fake_images = self.generator(random_latent_vectors, training=True)
fake_predictions = self.critic(fake_images, training=True)
g_loss = -tf.reduce_mean(fake_predictions)
gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
self.g_optimizer.apply_gradients(
zip(gen_gradient, self.generator.trainable_variables)
)
self.c_loss_metric.update_state(c_loss)
self.c_wass_loss_metric.update_state(c_wass_loss)
self.c_gp_metric.update_state(c_gp)
self.g_loss_metric.update_state(g_loss)
return {m.name: m.result() for m in self.metrics}
이제 해당 클래스를 인스턴스화 해줍니다.
# GAN 만들기
wgangp = WGANGP(
critic=critic,
generator=generator,
latent_dim=Z_DIM,
critic_steps=CRITIC_STEPS,
gp_weight=GP_WEIGHT,
)
그리고 학습간 활용할 체크포인트 파일 경로를 지정해줍니다.
if LOAD_MODEL:
wgangp.load_weights("./checkpoint/checkpoint.ckpt")
# GAN 컴파일
wgangp.compile(
c_optimizer=optimizers.Adam(
learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2
),
g_optimizer=optimizers.Adam(
learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2
),
)
# 모델 저장 체크포인트 만들기
model_checkpoint_callback = callbacks.ModelCheckpoint(
filepath="./checkpoint/checkpoint.ckpt",
save_weights_only=True,
save_freq="epoch",
verbose=0,
)
tensorboard_callback = callbacks.TensorBoard(log_dir="./logs")
class ImageGenerator(callbacks.Callback):
def __init__(self, num_img, latent_dim):
self.num_img = num_img
self.latent_dim = latent_dim
def on_epoch_end(self, epoch, logs=None):
if epoch % 10 != 0: # 출력 횟수를 줄이기 위해
return
random_latent_vectors = tf.random.normal(
shape=(self.num_img, self.latent_dim)
)
generated_images = self.model.generator(random_latent_vectors)
generated_images = generated_images * 127.5 + 127.5
generated_images = generated_images.numpy()
display(
generated_images,
save_to="./output/generated_img_%03d.png" % (epoch),
cmap=None,
)
# 깃허브 노트북 용량 제한 때문에 출력을 지웁니다.
wgangp.fit(
train,
epochs=EPOCHS,
steps_per_epoch=2,
callbacks=[
model_checkpoint_callback,
tensorboard_callback,
ImageGenerator(num_img=10, latent_dim=Z_DIM),
],
)
*첫번째 학습 결과
*매 10번째마다 학습 결과
.................................. 이후
점차 사람의 모습이 보입니다.
이제 이렇게 학습된 생성자를 활용해 랜던함 이미지로부터 이미지를 생성해보겠습니다.
z_sample = np.random.normal(size=(10, Z_DIM))
imgs = wgangp.generator.predict(z_sample)
display(imgs, cmap=None)
댓글