본문 바로가기
딥러닝 with Python

[딥러닝 with Python] GraphSAGE를 활용한 논문 분류(Node Classification)

by CodeCrafter 2024. 12. 9.
반응형

 

이번에는 지난번에 알아본 GNN 중 GraphSAGE 방법을 활용해서 노드 분류(Node Classification)을 진행해보겠습니다.

 

실습에 활용할 데이터는 Cora 입니다.

 

1. Cora Dataset 설명

 

Cora 데이터셋은 그래프 데이터 분석에서 널리 사용되는 표준 데이터 중 하나입니다. 특히, 논문 간의 인용 관계를 나타내는 정보와 함께 그래프 신경망(GNN)을 학습하고 평가하는데 자주 사용됩니다.

 

Cora 데이터셋의 구성은 아래와 같습니다.

 

1) 노드 : Cora 데이터셋에서의 각 노드는 개별 논문을 의미합니다.

 

2) 엣지 : 노드 간의 엣지는 논문 간의 인용 관계를 나타냅니다. 예를 들어, 논문 A가 논문 B를 인용했다면 A와 B 사이에 엣지가 존재합니다.

 

3) 노드 특징(Node Features) : 각 노드는 Bag-of-Words 표현으로 논문에 사용된 단어 정보를 표현하고 있습니다.

 총 1,433개의 고유 단어가 특징으로 사용되며, 각 노드의 특징 벡터는 해당 논문에서 사용된 단어의 등장 여부를 나타냅니다.

[머신러닝 with Python] Bag of Words란? (BoW)

 

[머신러닝 with Python] Bag of Words란? (BoW)

Bag of Words는 텍스트 데이터를 벡터 형태로 변호나하여 머신러닝과 자연어 처리 모델에 사용할 수 있도록 하는 기본적인 텍스트 표현 기법입니다.  간단하면서도 다양한 텍스트 처리 작업에 유

jaylala.tistory.com

 

4) 클래스 : 총 7개의 클래스가 존재하며, 각 논문은 특정 주제(예 : 머신러닝, 컴퓨터비전 등)에 속합니다. 

 

총 노드의 수는 2,708개이고, 엣지수는 5,429개이며 클래스는 7개입니다. 

 

 

2. Graph SAGE를 활용한 Node Classification

먼저, 실습에 활용할 torch_geometric이라는 라이브러리를 다운로드 받아줍니다.

!pip install torch_geometric

 

 

그리고 다음으로 데이터 셋을 다운로드 받아줍니다.

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader

# 데이터셋 로드 (Cora 데이터셋 사용)
# Cora 데이터셋은 논문 간의 인용 관계를 그래프로 표현한 데이터셋입니다.
# 노드는 논문을 나타내며, 각 노드의 특징은 논문에서 사용된 단어의 Bag-of-Words 표현입니다.
# 엣지는 논문 간의 인용 관계를 나타냅니다.
# 총 7개의 클래스로 논문 주제를 분류하는 작업을 포함합니다.
dataset = Planetoid(root="./data", name="Cora")
data = dataset[0]


 

이후 GraphSAGE 모델을 활용해서 노드 prediction을 진행해줍니다.

 

[개념정리] Graph SAGE란? Graph SAmple & aggreGatE)

 

[개념정리] Graph SAGE란? Graph SAmple & aggreGatE)

GraphSAGE(Graph Sample and aggreGatE)는 "Inductive Representation Learning on Large Graphs"(NIPS 17)라는 논문에 소개된 모델로, GNN의 한 종류이며, 대규모 그래프 데이터에서 효율적으로 노드의 임베딩을 학습하기 위

jaylala.tistory.com

 

이를 바탕으로 학습을 진행해줍니다. 정해진 train test 데이터 세팅을 그대로 사용해주며, 모델은 간단한 2층의 GraphSAGE입니다. 

 

# Graph SAGE 모델 정의
class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x

# 모델 초기화
in_channels = dataset.num_node_features
hidden_channels = 16
out_channels = dataset.num_classes

model = GraphSAGE(in_channels, hidden_channels, out_channels)

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 학습 및 평가 함수 정의
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def test():
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)

        train_correct = (pred[data.train_mask] == data.y[data.train_mask]).sum()
        train_acc = int(train_correct) / int(data.train_mask.sum())

        val_correct = (pred[data.val_mask] == data.y[data.val_mask]).sum()
        val_acc = int(val_correct) / int(data.val_mask.sum())

        test_correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
        test_acc = int(test_correct) / int(data.test_mask.sum())

    return train_acc, val_acc, test_acc

# 학습 루프
num_epochs = 100
for epoch in range(1, num_epochs + 1):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")

 

100회의 학습결과는 아래와같이 76.8%의 노드가 잘 분류된 것을 알 수있습니다.

 

 

학습된 모델을 바탕으로 테스트 데이터에 적용된 결과를 시각화 하면 아래와 같습니다.

import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

# 모델 학습 완료 후, 결과 시각화 준비
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)

# 그래프 데이터 PyTorch Geometric -> NetworkX 변환
graph = to_networkx(data, to_undirected=True)

# 시각화를 위한 노드 색상 지정 (모델의 예측값 기반)
node_colors = [pred[node].item() for node in range(data.num_nodes)]

# 그래프 시각화
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(graph, seed=42)  # 노드 배치 (Spring Layout 사용)
nx.draw(graph, pos, node_color=node_colors, cmap="Set3", node_size=50, with_labels=False)
plt.title("Graph SAGE Classification Results", fontsize=16)
plt.show()

 

 

반응형

댓글