본문 바로가기
딥러닝 with Python

[딥러닝 with Python] SRGAN이란? / Super Resolution GAN (2/2) / 파이썬으로 코딩

by CodeCrafter 2023. 10. 6.
반응형

[이번 포스팅은 "Must Have 텐초의 파이토치 딥러닝 특강" 의 내용을 참조하여 작성하였습니다]

 

이번에는 지난 시간에 알아본 SRGAN의 개념을 파이썬 코딩을 통해서 구현해보도록 하겠습니다. 

 

이번 모델 구현 간에는 CelebA라는 데이터셋을 활용할 건데요 

 

 

1. Celeb A 데이터 셋이란?

-  Celeb A 데이터 셋은 Celebrity Attributes의 약자로, 해당 데이터 셋은 유명인사들의 얼굴 이미지를 수집하고 주석(annotation)된 정보를 포함하고 있습니다. 이는 약 20만개 이상의 얼굴 이미지(10,177명의 개별인물)가 있으며, 다양한 인물, 표정, 포즈, 조명 조건 등을 포함해 다양한 상황에서 촬영된 이미지를 제공하고 있습니다. 또한, 각 얼굴 이미지에 대해 라벨링 되어 있으며 라벨링 된 정보에는 성별, 안경 착용여부, 머리카락 스타일, 표정 등이 포함되어 있습니다. 

 

- 특히, 꽤나 고해상도의 이미지를 가지고 있기에 해상도 관련 연구에 적합한 데이터 셋 중 하나입니다. 

 

해당 데이터를 다운받기 위해서는 아래 링크로 들어가서 downloads 부분에 들어가서 다운로드 받으시면 되겠습니다. 

 

https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

 

CelebA Dataset

Details CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. The images in this dataset cover large pose variations and background clutter. CelebA has larg

mmlab.ie.cuhk.edu.hk

 

 

2. 파이썬 코드를 활용한 SRGAN 구현하기

이번에는 파이썬 코드를 활용해서 SRGAN을 구현해보겠습니다.

 

먼저, 구성하려는 SRGAN의 구조를 간략히 만들어보면 아래와 같습니다.

 

 

 

- 먼저 데이터를 불러오겠습니다. 

 저는 구글 코랩(Google Colab)을 활용했기에 아래와 같은 경로가 나오게 되었는데요

1
2
!cp "/content/drive/MyDrive/img_align_celeba.zip" "."
!unzip "./img_align_celeba.zip" -d "./GAN/"
cs

* 위 코드에서 !cp 다음 "" 부분에서, 본인이 저장한 colab 경로 또는 로컬 경로를 입력하시면, zip으로 압축되어있던 img_align_celeba 파일의 압축이 해제가 되겠습니다.

 

- 이제 학습에 사용할 데이터셋을 입력용과 정답용으로 나누겠습니다. 이때 저화질 이미지는 학습을 위한 입력으로, 고화질 입력은 정답용으로 나누어 놓겠습니다. 위 그림에서 최초 "저화질 이미지" 부분과 "진짜 고화질 이미지" 부분이 되겠습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import glob
import torchvision.transforms as tf
 
from torch.utils.data.dataset import Dataset
from PIL import Image
 
 
class CelebA(Dataset):
   def __init__(self):
       self.imgs = glob.glob("./GAN/img_align_celeba/*.jpg")
 
       # ❶ 정규화에 이용할 평균과 공분산
       mean_std = (0.5, 0.5, 0.5)
 
       # ❷ 입력용 이미지 생성
       self.low_res_tf = tf.Compose([
           tf.Resize((32, 32)),
           tf.ToTensor(),
           tf.Normalize(mean_std, mean_std)
       ])
 
       # ❸ 정답용 이미지 생성
       self.high_res_tf = tf.Compose([
           tf.Resize((64, 64)),
           tf.ToTensor(),
           tf.Normalize(mean_std, mean_std)
       ])
   def __len__(self):
       return len(self.imgs) # ❶
 
   def __getitem__(self, i):
       img = Image.open(self.imgs[i])
 
       # ❷ 저화질 이미지는 입력으로
       img_low_res = self.low_res_tf(img)
       # ❸ 고화질 입력은 정답으로
       img_high_res = self.high_res_tf(img)
 
       return [img_low_res, img_high_res]
cs

 

 

- 다음은 SRGAN에서 사용할 기본 합성곱 블록(Convolutional Block)을 정의하겠습니다. 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.nn as nn
 
 
class ResidualBlock(nn.Module):
   def __init__(self, in_channels, out_channels):
       super(ResidualBlock, self).__init__()
 
       # 생성자의 구성요소 정의
       self.layers = nn.Sequential(
           nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=1, padding=1),
           nn.BatchNorm2d(out_channels),
           nn.PReLU(),
           nn.Conv2d(out_channels, out_channels,
                     kernel_size=3, stride=1, padding=1),
           nn.BatchNorm2d(out_channels)
       )
 
   def forward(self, x):
       x_ = x
       x = self.layers(x)
 
       # 합성곱층을 거친 후 원래의 입력 텐서와 더해줌
       x = x_ + x
 
       return x
cs

* Class의 이름은 ResidualBlock으로 정의하고, Neural Network를 활용할 것이므로 nn.Module을 활용합니다.

 

* 이때 생서자의 구성요소는 nn.Sequential의 형태로 한개의 층씩 순차적으로 정의하며, 먼저 Conv2d(2차원 이미지에 대한 합성곱층)의 커널 사이즈는 3, stride는 1, padding은 1롤 정의하고 다음으로 BatchNormalization2d(2차원 이미지에 대한 Batch Normalization)으로 배치단위로 처리된 데이터들을 정규화하고, 이후 PReLU를 활성화 함수로 정의하여 처리한 이후 다시 Conv2d, 그리고 BatchNormalization 2d를 하겠습니다.

 (Residual Block의 순차구성 : Conv2d -> BatchNormalizaiont2d -> PReLU -> Conv2d -> BatchNormalizaiont2d)

 

* 이후 순전파를 시키며, skip connection을 위해 입력값 x를 x_로 저장한 후, 위에서 정의한 ResidualBlock layer들에 x값을 넣어 출력을 만들어내고, 이 출력과 앞서 x를 복사하여 저장한 x_를 더해주어 순전파를 완료합니다.

 

 

- 다음은 화질 향상을 위해 픽셀수를 늘리는 Upsampling 층을 정의합니다.

 

1
2
3
4
5
6
7
8
9
# 업샘플링층의 정의
class UpSample(nn.Sequential):
   def __init__(self, in_channels, out_channels):
       super(UpSample, self).__init__(
           nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=1, padding=1),
           nn.PixelShuffle(upscale_factor=2),
           nn.PReLU()
       )
cs

* 이 UpSample 층도 순차적으로 층을 쌓아가기 위해 nn.Sequential을 쓰고, Conv2d로 특징을 추출한 뒤 PixelShuffle을 거치고 이후 PReLU층을 거쳐 결과를 업샘플링을 마무리 합니다. 여기서 nn.PixelShuffle에서 upscale_factor가 2라는 말은 입력 이미지의 공간 해상도를 2배로 증가시킨다는 말로, 입력 이미지의 높이와 너비를 2배로 늘린다는 것을 말합니다.

 즉, 해당 업샘플링층을 통과하게되면 픽셀수가 가로 2배, 세로 2배씩 늘어난다는 것을 말합니다.

 

 

- 다음은 생성자(Generator)를 정의하겠습니다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Generator(nn.Module):
   def __init__(self):
       super(Generator, self).__init__()
 
       # ➊ 첫 번째 합성곱층
       self.conv1 = nn.Sequential(
           nn.Conv2d(364,
                     kernel_size=9, stride=1, padding=4),
           nn.PReLU()
       )
 
       # ➋ 합성곱 블록
       self.res_blocks = nn.Sequential(
           ResidualBlock(in_channels=64, out_channels=64),
           ResidualBlock(in_channels=64, out_channels=64),
           ResidualBlock(in_channels=64, out_channels=64),
       )
 
       self.conv2 = nn.Conv2d(6464,
                           kernel_size=3, stride=1, padding=1)
       self.bn2 = nn.BatchNorm2d(64)
 
       # ➌ 업샘플링층
       self.upsample_blocks = nn.Sequential(
           UpSample(in_channels=64, out_channels=256)
       )
 
       # ➍ 마지막 합성곱층
       self.conv3 = nn.Conv2d(643,
                           kernel_size=9, stride=1, padding=4)
   def forward(self, x):
       # ➊ 첫 번째 합성곱층
       x = self.conv1(x)
       # ➋ 합성곱 블록을 거친 결과와 더하기 위해
       # 값을 저장
       x_ = x
 
       # ➌ 합성곱 블록
       x = self.res_blocks(x)
       x = self.conv2(x)
       x = self.bn2(x)
       # ➍ 합성곱 블록과 첫 번째 합성곱층의 결과를 더함
       x = x + x_
 
       # ➎ 업샘플링 블록
       x = self.upsample_blocks(x)
       # ➏ 마지막 합성곱층
       x = self.conv3(x)
 
       return x
cs

 

* 생성자를 만들때는 위에서 정의한 Residual Block과 Upsampling층을 활용하겠습니다.

* 먼저 첫번째 합성곱층은 nn.Sequential을 활용해 nn.Conv2d -> nn.PReLU의 순으로 입력 데이터가 순차적으로 처리되며, 이후 위에서 정의한 Residual Block, 즉 합성곱 블록을 활용합니다. 이때, 입력 채널크기와 출력 채널 크기를 64로 맞춰주고, 3번의 합성곱 블록을 거치게 정의합니다. 다음으로 업샘플링 층의 입력 채널크기와 출력 채널 크기를 각 64와 256으로 맞춰주고, 마지막으로 합성곱층을 정의해줍니다. 마지막 합성곱 층에서는 입력받은 64개 채널의 데이터를 3개의 채널로 출력하게 정의하며 이를 위한 커널 크기와 stride, 그리고 padding을 정의해줍니다.

 

* 이제 정의된 층들을 pass forward, 즉 순전파 알고리즘에 순서대로 넣어보겠습니다. 

첫번째 합성곱층 -> 해당 결과를 활용해 skip connection을 하기 위해 저장 -> 합성곱 블록 -> conv2d -> bn2 -> 결과를 아까 저장한 첫번째합성곱층의 결과와 더해주기 -> 해당 값을 업샘플링 블록에 넣기 -> 마지막 합성곱층 거치기  의 순서입니다.

 

 

- 이제 구별자 기본 블록을 정의하겠습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 한 번 거칠 때마다 이미지 크기가 절반이 되는 합성곱층
class DiscBlock(nn.Module):
   def __init__(self, in_channels, out_channels):
       super(DiscBlock, self).__init__()
 
       self.layers = nn.Sequential(
           nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=2, padding=1),
           nn.BatchNorm2d(out_channels),
           nn.LeakyReLU()
       )
 
   def forward(self, x):
       return self.layers(x)
cs

* nn.Sequential로 순차적으로 정의해주며, Conv2d를 거친 뒤 Batchnormalization2d를 해주고 마지막으로 LeakyReLu 함수를 거쳐 값을 도출하며, 이 과정 그대로 순전파가 이루어지게 정의합니다.

 

 

- 이제 위에서 정의한 구별자 블록을 바탕으로 구별자를 정의합니다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Discriminator(nn.Module):
   def __init__(self):
       super(Discriminator, self).__init__()
 
       self.conv1 = nn.Sequential(
           nn.Conv2d(364,
                     kernel_size=3, stride=1, padding=1),
           nn.LeakyReLU()
       )
 
       self.blocks = DiscBlock(in_channels=64, out_channels=64)
 
       self.fc1 = nn.Linear(655361024)
       self.activation = nn.LeakyReLU()
       self.fc2 = nn.Linear(10241)
       self.sigmoid = nn.Sigmoid()
 
   def forward(self, x):
       # ➊ 컨볼루션 층
       x = self.conv1(x)
       x = self.blocks(x)
 
       # ➋ 1차원으로 펼쳐줌
       x = torch.flatten(x, start_dim=1)
 
       # ➌ 이진분류 단계
       x = self.fc1(x)
       x = self.activation(x)
       x = self.fc2(x)
       x = self.sigmoid(x)
 
       return x
cs

* 여기서 특징적인 것은 fc1, fc2 즉 Fully Connected Layer 전 연결층을 정의하고 중간에 Leaky ReLU를 거쳐 비선형성을 추가한뒤 마지막으로 이진 분류, 즉 1이면 진짜로 0이면 가짜로 구별하는 층을 정의하고, 아래 순전파 함수에 정의된 대로 층을 구성합니다. 

 

 

 

- 이번에는 feature map을 추출하기 위한 특징 추출기를 만들어보겠습니다. 추출기는 Imagenet으로 학습된 VGG-19를 활용하겠습니다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from torchvision.models.vgg import vgg19
 
 
# VGG19 특징 추출기
class FeatureExtractor(nn.Module):
   def __init__(self):
       super(FeatureExtractor, self).__init__()
       # ➊ 사전 학습된 vgg19 모델 정의
       vgg19_model = vgg19(pretrained=True)
 
       # ➋ VGG19의 9개 층만을 이용
       self.feature_extractor = nn.Sequential(
           *list(vgg19_model.features.children())[:9])
 
   def forward(self, img):
       return self.feature_extractor(img)
cs

* 이를 통해, 생성된 고화질 이미지의 특징과 진짜 고화질 이미지의 특징을 비교하는 Perceptual loss를 계산할 수 있습니다.

 

- 이제 학습에 필요한 요소들을 정의합니다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tqdm
 
from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam
 
device = "cuda" if torch.cuda.is_available() else "cpu"
 
# ➊ 데이터로더 정의
dataset = CelebA()
batch_size = 8
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
# ➋ 생성자와 감별자 정의
= Generator().to(device)
= Discriminator().to(device)
feature_extractor = FeatureExtractor().to(device)
feature_extractor.eval()
 
 
# ➌ 생성자와 감별자의 최적화 정의
G_optim = Adam(G.parameters(), lr=0.0001, betas=(0.50.999))
D_optim = Adam(D.parameters(), lr=0.0001, betas=(0.50.999))
cs

*GPU 사용을 위해 cuda를 정의하고, 데이터를 불러오며 학습당 사용할 배치의 크기를 정의하며, 생성자와 구별자를 정의하여 GPU를 활용한 빠른 학습을 정의합니다. 또한, 학습간 사용할 Optimizer는 Adam을 사용하고 이때 학습률인 Learning Rate와 Beta를 위와 같이 정의합니다.

 

 

- 기나긴 여정의 끝이 보입니다. 이제 학습 루프를 정의하고 학습을 시킵니다

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
for epoch in range(1):
   iterator = tqdm.tqdm(loader)
 
   for i, (low_res, high_res) in enumerate(iterator):
       # ❶기울기의 초기화
       G_optim.zero_grad()
       D_optim.zero_grad()
 
       # ➋ 진짜 이미지와 가짜 이미지의 정답
       label_true = torch.ones(batch_size, dtype=torch.float32).to(device)
       label_false = torch.zeros(batch_size, dtype=torch.float32).to(device)
 
       # ➌ 생성자 학습
       fake_hr = G(low_res.to(device))
       GAN_loss = nn.MSELoss()(D(fake_hr), label_true)
       # CNN 특징추출기로부터 추출된 특징의 비교
       # ➊ 가짜 이미지의 특징 추출
       fake_features = feature_extractor(fake_hr)
       # ➋ 진짜 이미지의 특징 추출
       real_features = feature_extractor(high_res.to(device))
       # ➌ 둘의 차이 비교
       content_loss = nn.L1Loss()(fake_features, real_features)
       # 생성자의 손실 정의
       loss_G = content_loss + 0.001*GAN_loss
       loss_G.backward()
       G_optim.step()
       # 감별자 학습
       # ➊ 진짜 이미지의 손실
       real_loss = nn.MSELoss()(D(high_res.to(device)), label_true)
       # ➋ 가짜 이미지의 손실
       fake_loss = nn.MSELoss()(D(fake_hr.detach()), label_false)
       # ➌ 두 손실의 평균값을 최종 오차로 설정
       loss_D = (real_loss + fake_loss) / 2
       # ➍ 오차 역전파
       loss_D.backward()
       D_optim.step()
 
       iterator.set_description(
           f"epoch:{epoch} G_loss:{GAN_loss} D_loss:{loss_D}")
 
torch.save(G.state_dict(), "SRGAN_G.pth")
torch.save(D.state_dict(), "SRGAN_D.pth")
cs

 

 

- 마지막으로 모델의 성능을 평가합니다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import matplotlib.pyplot as plt
 
G.load_state_dict(torch.load("SRGAN_G.pth", map_location=device))
 
with torch.no_grad():
   low_res, high_res = dataset[10]
 
   # ➊ 생성자의 입력
   input_tensor = torch.unsqueeze(low_res, dim=0).to(device)
 
   # ➋ 생성자가 생성한 고화질 이미지
   pred = G(input_tensor)
   pred = pred.squeeze()
   pred = pred.permute(120).cpu().numpy()
 
   # ➌ 저화질 이미지의 채널 차원을 가장 마지막으로
   low_res = low_res.permute(120).numpy()
 
   # ➍ 저화질 입력과 생성자가 만든 고화질 이미지의 비교
   plt.subplot(121)
   plt.title("low resolution image")
   plt.imshow(low_res)
   plt.subplot(122)
   plt.imshow(pred)
   plt.title("predicted high resolution image")
   plt.show()
cs

 

dataset의 11번째 저화질 및 고화질 데이터 쌍을 바탕으로 생성자가 생성한 고화질 이미지와 저화질 입력 이미지가 출력이 됩니다.

반응형

댓글