본문 바로가기
딥러닝 with Python

[딥러닝 with Python] 이미지 리트리버(Image Retrieval) / CLIP 으로 구현

by CodeCrafter 2024. 7. 30.
반응형

 

1. 이미지 리트리버(Image Retrieval)

- 이미지 리트리버는 주어진 입력(쿼리 / Query)와 유사한 이미지를 대규모 이미지 데이터베이스에서 검색해내는 시스템을 말합니다.

 

- 이때 입력(쿼리)는 텍스트, 이미지 등 다양하게 활용할 수 있습니다.

 

- 이미지 리트리버의 주요 구성요소는 아래와 같습니다.

 1) 쿼리 입력(Query Input)

  * 텍스트 쿼리 : 사용자가 텍스트로 이미지의 설명을 제공하면, 시스템은 이 텍스트를 기반으로 검색을 합니다. 

  ex. "강아지가 뛰어노는 사진" 이라는 텍스트 쿼리를 입력하면 아래와 같이 쿼리에 해당하는 이미지가 반환되게 됩니다.

 

 * 이미지 쿼리 : 사용자가 이미지를 제공하면, 시스템은 이 이미지와 유사한 이미지를 검색하게 됩니다. 

   (이를, 콘텐츠 기반 이미지 검색(Content-Based Image Retrieval / CBIR)이라고 부릅니다.)

 

 

2) 특징 추출(Feature Extraction)

 * 이미지나 텍스트에서 중요한 정보를 추출하여, 이를 벡터 형태의 고차원 공간에 매핑하는 과정을 말합니다.

 * 이미지 특징은 색상 히스토그램, 텍스처, 모양, 또는 고차원 임베딩을 의미하고, 텍스트 특징의 경우 텍스트 임베딩을 위한 모델(Word2Vec, BERT, CLIP)을 사용하여 텍스트의 의미적 표현을 벡터로 변환하는 것을 말합니다.

 

3) 유사도 계산(Similarity Calculation)

 * 쿼리와 데이터베이스 내의 이미지 간의 유사도를 계산합니다. 일반적으로, 코사인 유사도 또는 유클리디안 거리 등을 활용합니다. 이를 기준으로 결과를 정렬해 유사도가 높은 순서대로 출력합니다.

 

4) 결과 반환(Result Retrieval)

 * 유사도가 높은 상위 N개의 이미지를 변환합니다. 이는 사용자의 쿼리에 가장 적합한 이미지들로 구성됩니다.

 

- 이러한  이미지 리트리버는 전자상거래 / 소셜 미디어 / 디지털 자산 관리 / 의료 이미지 분석 등에 활발하게 활용되고 있습니다.

 

 

2. CLIP 을 활용해 이미지 리트리버 구현

- 이번에는 CLIP이라는 방법론을 활용해 이미지 리트리버를 구현해보겠습니다.

- CLIP에 대한 설명

 * CLIP은 Contrastive Language-Image Pre-training의 줄임말로, OpenAI에서 개발한 멀티모달 AI 모델로, 텍스트와 이미지를 동시에 학습하여 두 가지 입력 간의 연관성을 학습하는 모델을 말합니다. 

 * 텍스트 인코더와 이미지 인코더가 존재하며, 대조 학습(Contrastive Learning)을 활용해 텍스트 - 이미지 쌍을 학습하고 올바른 이미지 쌍을 구분하도록 해줍니다.

 

- 이러한 CLIP을 바탕으로 이미지 리트리버를 구현해보겠습니다. 

- 학습에 활용할 데이터는 COCO 데이터셋의 소규모 샘플입니다. 

 

먼저 학습에 필요한 라이브러리와 소규모 데이터 샘플을 다운로드 받아줍니다.

# 필요한 라이브러리 설치
!pip install torch torchvision ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install opencv-python-headless
!pip install pycocotools

# COCO 데이터셋의 소규모 샘플 다운로드
!mkdir coco
!wget http://images.cocodataset.org/zips/val2017.zip -P coco/
!unzip coco/val2017.zip -d coco/
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P coco/
!unzip coco/annotations_trainval2017.zip -d coco/

 

 

이후 사전 학습된 clip을 받아준 뒤, 사용할 데이터를 활용해 Inference만 진행해주겠습니다.

CLIP은 ViT-B/32를 Backbone으로 하겠습니다.

이를 활용해 Inference에 활용될 10개의 이미지와 설명(annotation)에 대해 진행한 결과는 아래와 같습니다.

import torch
import clip
from PIL import Image
import os
import json
import random
import matplotlib.pyplot as plt

# CLIP 모델과 프로세서를 불러옵니다.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# COCO 이미지 폴더와 주석 파일 경로를 지정합니다.
image_folder = "coco/val2017"
annotation_file = "coco/annotations/captions_val2017.json"

# COCO 주석 파일을 로드합니다.
with open(annotation_file, 'r') as f:
    annotations = json.load(f)

# 샘플로 사용할 이미지와 텍스트 캡션을 선택합니다.
num_samples = 10
selected_annotations = random.sample(annotations['annotations'], num_samples)

# 이미지와 캡션을 저장할 리스트를 초기화합니다.
images = []
captions = []
image_filenames = []

# 선택된 샘플을 로드하고 전처리합니다.
for ann in selected_annotations:
    image_id = ann['image_id']
    caption = ann['caption']
    image_path = os.path.join(image_folder, f"{image_id:012d}.jpg")

    if os.path.exists(image_path):
        image_filenames.append(image_path)
        image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
        images.append(image)
        captions.append(caption)

# 텍스트를 토큰화합니다.
text_tokens = clip.tokenize(captions).to(device)

# 이미지와 텍스트의 특징을 추출합니다.
with torch.no_grad():
    image_features = torch.cat([model.encode_image(image) for image in images])
    text_features = model.encode_text(text_tokens)

# 특징 벡터를 정규화합니다.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# 이미지와 텍스트 간의 유사도를 계산합니다.
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# 유사도가 가장 높은 이미지를 찾습니다.
for i, caption in enumerate(captions):
    print(f"Top matches for caption: '{caption}'")
    _, indices = similarity[i].topk(3)  # 상위 3개의 이미지를 찾습니다.
    for idx in indices:
        print(f"Image: {image_filenames[idx]}")

    # 유사도 상위 이미지를 시각화합니다.
    top_indices = indices.cpu().numpy()
    plt.figure(figsize=(10, 10))
    for j, index in enumerate(top_indices):
        image = Image.open(image_filenames[index])
        plt.subplot(1, 3, j + 1)
        plt.imshow(image)
        plt.title(f"Rank {j + 1}")
        plt.axis('off')
    plt.show()

 

 

먼저 "High stone tower with windows in an old village" 라는 Query에 대한 이미지를 가져온 결과 입니다.

Rank는 쿼리인 텍스트와 결과인 이미지 간의 유사도 순위를 의미하며, 가장 높은 순위를 보여준 이미지가 실제 텍스트와 유사한 결과를 가지고 있습니다.

 

 

- 다음은 A brown and whit vase with foliage on a small table 이라는 쿼리에 대한 결과입니다.

 꽤나 정확한 예측을 한 것을 알 수 있습니다.

 

 

* 다음은 A display of shoes and umbrellas are in a window 에 대한 결과입니다.

이번에는 결과가 좋지 못한 것을 알 수 있습니다.

 

 

- 사전 학습된 모델을 바탕으로 Inference만을 했기에 부정확한 결과들도 일부 섞여있음을 확인할 수 있습니다.

- 보다 더 많은 이미지 및 텍스트 Annotation이 있다면 더 좋은 결과가 나오지 않을까 합니다.

반응형

댓글