본문 바로가기
머신러닝 with Python

[딥러닝 with 파이썬] GAN (Generative Adversarial Networks) / 생성적 적대 신경망 / MNIST 데이터로 구현

by CodeCrafter 2023. 9. 25.
반응형

이번에는  GAN, 생성적 적대 신경망에 대해서 알아보겠습니다.

 

 

1. GAN이란?

- GAN은 Generative Adversarial Network의 약자로, 생성적 적대 신경망으로 불립니다.

- 이는 딥러닝을 기반으로 한 모델로서, 이름에서 알 수 있듯이 생성, 즉 기존에 없던 것을 만들어내는 모델입니다.

 

- GAN의 핵심 아이디어

 * GAN의 핵심 아이디어는 생성자(Generator)와 구분자(Discriminator)라는 모델을 만들어 서로 경쟁시키는 것입니다. 

 * 생성기는 더 실제와 유사한 데이터를 생성하려고 노력하고, 구분자는 생성기가 생성한 데이터와 실제 데ㅣ터를 구분하려고 노력하는 것입니다. 이러한 경쟁을 통해 생성기는 점차 더 정교한 데이터를 생성하게 되며, 결과적으로 생성된 데이터는 실제 데이터와 거의 구별하기 어려울 정도로 좋아지게 됩니다.

 * 이렇게 학습된 생성자는 기존 데이터에는 존재하지 않지만, 실제로 여겨질만큼 유사하지만 동일하지 않은 데이터를 생성하게 되는데요

 

 

- 이를 위조지폐범과 경찰 사이의 상호작용으로 비유해보면 아래와 같습니다.

a) 위조지폐 생성자 (Generator):

* 위조지폐범은 "생성자" 역할을 합니다. 
* 그의 목표는 가능한 실제 지폐와 유사한 위조 지폐를 생성하는 것입니다.
* 생성자는 초기에는 매우 랜덤하고 실제와 다르게 생긴 위조 지폐를 만들지만, 학습과정에서 그 결과가 향상됩니다.
* 이 생성자는 무작위한 시도를 통해 위조 지폐를 만들고, 그것을 경찰에게 적발되지 않도록 잘 만들어야 합니다.

b) 경찰 / 구분자 (Discriminator):

* 경찰은 "구분자" 역할을 하며, 그의 목표는 실제 지폐와 위조 지폐를 구별합니다.
* 경찰은 실제 지폐와 위조 지폐를 비교하여 어느 것이 진짜인지 판단합니다.
* 초기에는 경찰도 경험이 부족하며 위조 지폐를 실제 지폐와 구별하기 어렵습니다.

c) 경쟁과 학습:

* 이제 위조지폐범과 경찰은 서로의 능력을 향상시키기 위한 경쟁에 들어갑니다.
* 위조지폐범은 더 실제와 비슷한 위조 지폐를 생성하려고 시도하며, 경찰은 더 정확하게 실제와 위조를 구별하려고 노력합니다.
* 시간이 지남에 따라, 생성자는 점점 더 실제 지폐와 유사한 위조 지폐를 만들게 되고, 경찰은 점점 정확하게 구별할 수 있게 됩니다.

e)결과:


* 이러한 과정은 계속 반복되며, 생성기와 경찰은 계속해서 능력을 향상시킵니다.
* 최종 결과로는 위조지폐범은 매우 현실적인 위조 지폐를 만들게 되며, 경찰은 더욱 정확하게 위조와 실제를 판별합니다.

 

 

- 위와 같은 과정을 거쳐 능력이 향상된 생성자는, 새로운 지폐(이미지 등)을 만들어 낼 수 있게됩니다.

 

 

 

2. 파이썬 코딩을 통해 알아보는 GAN - MNIST 데이터를 활용

- 이번에는 파이썬 코딩을 통해 GAN에 대해서 알아보겠습니다.

 

- 학습에 사용할 데이터는 MNIST 인데요

 * MINST 는 Modified National Institute of Stadards and Technology의 약어로, 기계 학습 및 컴퓨터 비전 연구에서 널리 사용되는 손으로 쓴 숫자 이미지 데이터 셋입니다.

 * 해당 데이터셋은 0~9까지 손으로 쓰여진 다양한 스타일의 숫자와 해당 이미지에 대한 레이블(해당 이미지가 어떤 숫자인지)이 기록되어 있습니다. 

 * MNIST 데이터의 특징 

  · 이미지 크기 : 28x28 픽셀 / 각 픽셀은 0~255까지의 값을 가짐

  · 클래스 : 0~9 (0이면 숫자 0을, 1이면 숫자 1을, ......... , 9이면 숫자 9를 의미)

  · 데이터셋의 크기 : 총 70,000개 / train set : 60,000, test set : 10,000

 

 

 

- 이제 파이썬 코딩을 통해 MNIST 데이터 셋을 학습하고 이를 모방하는 GAN을 만들어보겠습니다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
 
import numpy as np
import matplotlib.pyplot as plt
 
# 생성자 모델을 만듭니다.
generator = Sequential()
generator.add(Dense(128*7*7, input_dim=100, activation=LeakyReLU(0.2)))
generator.add(BatchNormalization())
generator.add(Reshape((77128)))
generator.add(UpSampling2D())
generator.add(Conv2D(64, kernel_size=5, padding='same'))
generator.add(BatchNormalization())
generator.add(Activation(LeakyReLU(0.2)))
generator.add(UpSampling2D())
generator.add(Conv2D(1, kernel_size=5, padding='same', activation='tanh'))
 
# 판별자 모델을 만듭니다.
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, input_shape=(28,28,1), padding="same"))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same"))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
discriminator.trainable = False
 
# 생성자와 판별자 모델을 연결시키는 gan 모델을 만듭니다.
ginput = Input(shape=(100,))
dis_output = discriminator(generator(ginput))
gan = Model(ginput, dis_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
gan.summary()
 
# 신경망을 실행시키는 함수를 만듭니다.
def gan_train(epoch, batch_size, saving_interval):
 
  # MNIST 데이터를 불러옵니다.
 
  (X_train, _), (_, _) = mnist.load_data()  # 앞서 불러온 적 있는 MNIST를 다시 이용합니다. 단, 테스트 과정은 필요 없고 이미지만 사용할 것이기 때문에 X_train만 불러왔습니다.
  X_train = X_train.reshape(X_train.shape[0], 28281).astype('float32')
  X_train = (X_train - 127.5/ 127.5  # 픽셀 값은 0에서 255 사이의 값입니다. 이전에 255로 나누어 줄때는 이를 0~1 사이의 값으로 바꾸었던 것인데, 여기서는 127.5를 빼준 뒤 127.5로 나누어 줌으로 인해 -1에서 1사이의 값으로 바뀌게 됩니다.
  # X_train.shape, Y_train.shape, X_test.shape, Y_test.shape
 
  true = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))
 
  for i in range(epoch):
          # 실제 데이터를 판별자에 입력하는 부분입니다.
          idx = np.random.randint(0, X_train.shape[0], batch_size)
          imgs = X_train[idx]
          d_loss_real = discriminator.train_on_batch(imgs, true)
 
          # 가상 이미지를 판별자에 입력하는 부분입니다.
          noise = np.random.normal(01, (batch_size, 100))
          gen_imgs = generator.predict(noise)
          d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
 
          # 판별자와 생성자의 오차를 계산합니다.
          d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
          g_loss = gan.train_on_batch(noise, true)
 
          print('epoch:%d' % i, ' d_loss:%.4f' % d_loss, ' g_loss:%.4f' % g_loss)
 
        # 이 부분은 중간 과정을 이미지로 저장해 주는 부분입니다. 이 장의 주요 내용과 관련이 없어
        # 소스 코드만 첨부합니다. 만들어진 이미지들은 gan_images 폴더에 저장됩니다.
          if i % saving_interval == 0:
              #r, c = 5, 5
              noise = np.random.normal(01, (25100))
              gen_imgs = generator.predict(noise)
 
              # Rescale images 0 - 1
              gen_imgs = 0.5 * gen_imgs + 0.5
 
              fig, axs = plt.subplots(55)
              count = 0
              for j in range(5):
                  for k in range(5):
                      axs[j, k].imshow(gen_imgs[count, :, :, 0], cmap='gray')
                      axs[j, k].axis('off')
                      count += 1
              fig.savefig("./gan_mnist_%d.png" % i)
 
gan_train(200132200)  # 2000번 반복되고, 배치 사이즈는 32,  200번마다 결과가 저장되게 하였습니다.
 
cs

 

출력된 결과는 아래와 같습니다.

 

 

*최초 학습 결과 : 아직은 노이즈 상태입니다.

 

* 600회 학습 결과 : 점차 글씨 같은 모습이 보입니다.

* 1200회 학습 결과 : 기존보다 흐릿한 부분이 사라졌습니다.

* 2,000회 학습 결과 : 이제 글씨 같은 모습을 가지는 것들이 종종보입니다. 

 

더 많은 학습을 시키면 진짜 글씨와 같은 결과가 나올 것 같습니다.

 

 

이때, 생성된 글씨는 기존에는 없지만 새로운 글씨를 의미합니다. 즉, 기존 데이터에서 글씨를 쓴 사람이 아닌 새로운 사람이 쓴 것과 같은 새로운 스타일의 숫자 글씨가 나오게 됨을 의미합니다.

반응형

댓글