본문 바로가기
딥러닝 with Python

[개념정리] 그래프 신경망(Graph Neural Network / GNN) (1)

by CodeCrafter 2024. 7. 31.
반응형

 

 

1. 그래프 신경망이란? (Graph Neural Network / GNN)

- 그래프 신경망, 즉 GNN은 그래프 구조의 데이터를 처리하고 분석하기 위한 딥러닝 모델을 말합니다.

- 이는, 주로 노드, 엣지 그리고 전체 그래프에 대한 표현을 학습하는데 사용되어, 소셜 네트워크, 분자 구조, 교통 네트워크 등 그래프 형태의 데이터가 자연스럽게 발생하는 여러 분야에서 활용되며, 기타 다른 분야에서도 그 사용범위를 확장하고 있습니다.

 * 이때, 노드(Node)란, 그래프의 개별 객체를 나타내며, 각 노드는 특정 특성을 가질 수 있습니다.

 * 또한, 엣지(Edge) 는 노드 간의 관계나 연결을 나타내며, 엣지 역시 특정 특성을 가질 수 있습니다.

 ** 쉽게 생각해보면, 노드는 점이고 엣지는 점을 연결하는 선을 의미한다고 생각하면 되겠으며, 아래 파이썬 코딩을 통해 그 예시를 확인해보겠습니다.

 

networkx 라는 라이브러리를 활용해 간단하게 그래프를 그려보겠습니다. 4명의 사람들과 그들을 연결하는 모습을 그리려고하며 그 결과는 아래와 같습니다.

import networkx as nx
import matplotlib.pyplot as plt

# 그래프 생성
G = nx.Graph()

# 노드 추가 (사람들을 나타냄)
G.add_node("Alice")
G.add_node("Bob")
G.add_node("Carol")
G.add_node("Dave")

# 엣지 추가 (친구 관계를 나타냄)
G.add_edges_from([("Alice", "Bob"), ("Alice", "Carol"), ("Bob", "Carol"), ("Carol", "Dave")])

# 그래프 그리기
plt.figure(figsize=(8, 6))
nx.draw(G, with_labels=True, node_color='skyblue', font_weight='bold', font_size=12, node_size=2000)
plt.title("Social Network Graph")
plt.show()

** 여기서, 노드는 각 사람들을 의미하며, 엣지는 사람들을 연결하는 선을 의미합니다.

 

 

- 위와 같은 간단한 이해를 바탕으로, 이제 GNN의 개념에 대해 조금 더 구체적으로 알아보겠습니다. 

 

1) 노드와 엣지

 * 위에서 설명드린 대로 노드는 개별 객체를, 엣지는 개별 객체들의 관계나 연결성을 의미합니다.

2) GNN 기본 아이디어

 * 각 노드의 특징 벡터와 이웃 노드들의 정보를 사용하여 각 노드의 새로운 표현을 생성합니다. 이 과정은 반복적으로 수행되어, 각 노드들이 정보를 교환하게 되어 최종적으로는 표현이 더 풍부해지게 됩니다.

3) 메시지 전달

 * 메시지(Message) 전달이란, 각 노드가 이웃 노드로부터 정보를 보내는 과정을 말하며, 이때 이웃 노드들로부터 받은 메시지들을 집계하여 새로운 노드 표현을 만드는 것을 집계(Aggregation)라 합니다.

4) 갱신(Update)

 * 집계된 정보를 사용해 노드의 특징을 업데이트 합니다.

5) 출력 레이어

 * 이와 같은 과정이 반복되어, 최종적으로 각 노드의 표현을 사용해 분류, 회귀, 클러스터링 등 목적을 달성하게 됩니다.

 

- 이를 바탕으로 GNN의 주요 과정( 그래프 구조 생성 -> 메시지 전달 -> 집계 -> 갱신 -> 최종출력 )을 파이썬 코드로 표현해보면  아래와 같습니다.

# 1. 그래프 구조 시각화 (노드와 엣지, 노드의 특성)

def plot_graph_structure_detailed():
    G = nx.DiGraph()
    G.add_nodes_from([
        ("A", {"feature": "Feature_A"}),
        ("B", {"feature": "Feature_B"}),
        ("C", {"feature": "Feature_C"}),
        ("D", {"feature": "Feature_D"}),
        ("E", {"feature": "Feature_E"}),
        ("F", {"feature": "Feature_F"}),
        ("G", {"feature": "Feature_G"}),
        ("H", {"feature": "Feature_H"})
    ])
    G.add_edges_from([
        ("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("C", "E"),
        ("D", "F"), ("E", "F"), ("E", "G"), ("F", "H"), ("G", "H")
    ])
   
    plt.figure(figsize=(10, 8))
    pos = nx.spring_layout(G, seed=42)
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=12, arrows=True)
    labels = {node: f"{node}\n{data['feature']}" for node, data in G.nodes(data=True)}
    nx.draw_networkx_labels(G, pos, labels, font_size=9)
    plt.title("1. Graph Structure: Nodes with Features and Directed Edges")
    plt.show()

# 2. 메시지 전달 시각화 (특정 노드에서 이웃 노드로 메시지 전송)

def plot_message_passing_detailed():
    G = nx.DiGraph()
    G.add_edges_from([
        ("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("C", "E"),
        ("D", "F"), ("E", "F"), ("E", "G"), ("F", "H"), ("G", "H")
    ])
    pos = nx.spring_layout(G, seed=42)
   
    plt.figure(figsize=(10, 8))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=12, arrows=True)
   
    # "C"가 이웃 노드들로 메시지 전달
    nx.draw_networkx_edges(G, pos, edgelist=[("C", "D"), ("C", "E")],
                           width=2.5, edge_color='orange', arrows=True)
    plt.title("2. Message Passing: Node C Sending Messages to Neighbors")
    plt.show()

# 3. 집계 시각화 (노드로부터 받은 정보를 집계)

def plot_aggregation_detailed():
    G = nx.DiGraph()
    G.add_edges_from([
        ("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("C", "E"),
        ("D", "F"), ("E", "F"), ("E", "G"), ("F", "H"), ("G", "H")
    ])
    pos = nx.spring_layout(G, seed=42)
   
    plt.figure(figsize=(10, 8))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=12, arrows=True)
   
    # "F"가 이웃 노드들로부터 정보 집계
    nx.draw_networkx_edges(G, pos, edgelist=[("D", "F"), ("E", "F")],
                           width=2.5, edge_color='green', arrows=True)
    plt.title("3. Aggregation: Information Aggregation to Node F")
    plt.show()

# 4. 갱신 시각화 (노드의 상태 업데이트)

def plot_update_detailed():
    G = nx.DiGraph()
    G.add_edges_from([
        ("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("C", "E"),
        ("D", "F"), ("E", "F"), ("E", "G"), ("F", "H"), ("G", "H")
    ])
    pos = nx.spring_layout(G, seed=42)
   
    plt.figure(figsize=(10, 8))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=12, arrows=True)
   
    # "F" 노드의 상태 갱신
    nx.draw_networkx_nodes(G, pos, nodelist=["F"], node_color='green', node_size=3000)
    plt.title("4. Update: Node F's State Update")
    plt.show()

# 5. 출력 레이어 시각화 (최종 상태 출력)

def plot_output_layer_detailed():
    G = nx.DiGraph()
    G.add_edges_from([
        ("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("C", "E"),
        ("D", "F"), ("E", "F"), ("E", "G"), ("F", "H"), ("G", "H")
    ])
    pos = nx.spring_layout(G, seed=42)
   
    plt.figure(figsize=(10, 8))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=12, arrows=True)
   
    # "H" 노드의 최종 출력
    nx.draw_networkx_nodes(G, pos, nodelist=["H"], node_color='red', node_size=3000)
    plt.title("5. Output Layer: Final Representation of Node H")
    plt.show()

# 시각화 함수 호출
plot_graph_structure_detailed()
plot_message_passing_detailed()
plot_aggregation_detailed()
plot_update_detailed()
plot_output_layer_detailed()

 

1. 그래프 구조 : 각 노드가 특정 특성(Feautre A,B,C ...)을 가지고 있으며, 엣지는 정보전달의 방향성 또한 보여줍니다.

 

2. 메시지 전달 : C 노드가 이웃 노드들에게 메시지를 전달하게됩니다.

 

3. Aggregation(집계) : F 노드는 이웃 노드들로부터 받은 메시지들을 Aggregation 하여 정보 업데이트를 준비합니다.

4. Update(갱신) : 집계된 정보를 바탕으로 노드의 상태가 업데이트 되게 됩니다. 아래 그림에서는 기존에 파란색이었는데, D와 E로부터 온 초록색의 정보로 인해 노드의 상태가 업데이트 됨을 색깔변화로 표현합니다.

5. 2~4의 과정을 반복하여 최종적으로 얻고자 하는 출력 레이어(H)는 처음엔 파란색이었지만 수많은 정보들의 Aggregation과 Update 과정을 거쳐 아래와 같은 표현이 얻어지게됩니다.

 

 

2. GNN 모델의 활용사례 

1) 약물 발견 및 분자 특성 예측

 * 분자를 그래프 구조로 표현하여 각 원자는 노드, 화학 결합은 엣지로 나타내어 신약을 발견하고 특정 분자의 특성을 예측할 수 있습니다.

2) 소셜 네트워크 분석

 * 위의 예시에서 보았던 것처럼 각 사람들의 관계를 통해 어떤 관계가 있는지를 판단할 수 있습니다.

 

3) 지식 그래프와 자연어 처리

 * 지식 그래프는 개념과 개념 간의 관계를 표현하는 그래프입니다. 실제 사용의 예로는 Google Knowledge Graph가 있으며, 검색 쿼리와 관련된 정보 추출에 GNN을 활용합니다. 또한, 문장 내 단어간 관계를 분석하거나, 질의 응답 시스템에서 정확한 답을 찾기위해 사용됩니다.

 

 

4) 추천 시스템

* 사용자와 아이템 간의 관계를 학습하고 표현할 수 있기에 이를 활용해서 OTT 서비스, 전자 상거래 등에서 GNN을 활용해 고객에게 상품을 추천해주는 시스템을 개발하고 있습니다.

 

이외에도 자율주행, Fraud Detection(사기 탐지) 등 관계가 형성되는 다양한 분야에서 활발히 활용되고 있습니다.

 

 

다음 시간에는 GNN을 모델적으로 접근해보겠습니다.

반응형

댓글