이번에는 지난번에 알아본 GNN 중 GraphSAGE 방법을 활용해서 노드 분류(Node Classification)을 진행해보겠습니다.
실습에 활용할 데이터는 Cora 입니다.
1. Cora Dataset 설명
Cora 데이터셋은 그래프 데이터 분석에서 널리 사용되는 표준 데이터 중 하나입니다. 특히, 논문 간의 인용 관계를 나타내는 정보와 함께 그래프 신경망(GNN)을 학습하고 평가하는데 자주 사용됩니다.
Cora 데이터셋의 구성은 아래와 같습니다.
1) 노드 : Cora 데이터셋에서의 각 노드는 개별 논문을 의미합니다.
2) 엣지 : 노드 간의 엣지는 논문 간의 인용 관계를 나타냅니다. 예를 들어, 논문 A가 논문 B를 인용했다면 A와 B 사이에 엣지가 존재합니다.
3) 노드 특징(Node Features) : 각 노드는 Bag-of-Words 표현으로 논문에 사용된 단어 정보를 표현하고 있습니다.
총 1,433개의 고유 단어가 특징으로 사용되며, 각 노드의 특징 벡터는 해당 논문에서 사용된 단어의 등장 여부를 나타냅니다.
[머신러닝 with Python] Bag of Words란? (BoW)
4) 클래스 : 총 7개의 클래스가 존재하며, 각 논문은 특정 주제(예 : 머신러닝, 컴퓨터비전 등)에 속합니다.
총 노드의 수는 2,708개이고, 엣지수는 5,429개이며 클래스는 7개입니다.
2. Graph SAGE를 활용한 Node Classification
먼저, 실습에 활용할 torch_geometric이라는 라이브러리를 다운로드 받아줍니다.
!pip install torch_geometric
그리고 다음으로 데이터 셋을 다운로드 받아줍니다.
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
# 데이터셋 로드 (Cora 데이터셋 사용)
# Cora 데이터셋은 논문 간의 인용 관계를 그래프로 표현한 데이터셋입니다.
# 노드는 논문을 나타내며, 각 노드의 특징은 논문에서 사용된 단어의 Bag-of-Words 표현입니다.
# 엣지는 논문 간의 인용 관계를 나타냅니다.
# 총 7개의 클래스로 논문 주제를 분류하는 작업을 포함합니다.
dataset = Planetoid(root="./data", name="Cora")
data = dataset[0]
이후 GraphSAGE 모델을 활용해서 노드 prediction을 진행해줍니다.
[개념정리] Graph SAGE란? Graph SAmple & aggreGatE)
이를 바탕으로 학습을 진행해줍니다. 정해진 train test 데이터 세팅을 그대로 사용해주며, 모델은 간단한 2층의 GraphSAGE입니다.
# Graph SAGE 모델 정의
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GraphSAGE, self).__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = torch.relu(x)
x = self.conv2(x, edge_index)
return x
# 모델 초기화
in_channels = dataset.num_node_features
hidden_channels = 16
out_channels = dataset.num_classes
model = GraphSAGE(in_channels, hidden_channels, out_channels)
# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 학습 및 평가 함수 정의
def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
def test():
model.eval()
with torch.no_grad():
out = model(data.x, data.edge_index)
pred = out.argmax(dim=1)
train_correct = (pred[data.train_mask] == data.y[data.train_mask]).sum()
train_acc = int(train_correct) / int(data.train_mask.sum())
val_correct = (pred[data.val_mask] == data.y[data.val_mask]).sum()
val_acc = int(val_correct) / int(data.val_mask.sum())
test_correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
test_acc = int(test_correct) / int(data.test_mask.sum())
return train_acc, val_acc, test_acc
# 학습 루프
num_epochs = 100
for epoch in range(1, num_epochs + 1):
loss = train()
train_acc, val_acc, test_acc = test()
print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")
100회의 학습결과는 아래와같이 76.8%의 노드가 잘 분류된 것을 알 수있습니다.
학습된 모델을 바탕으로 테스트 데이터에 적용된 결과를 시각화 하면 아래와 같습니다.
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
# 모델 학습 완료 후, 결과 시각화 준비
model.eval()
with torch.no_grad():
out = model(data.x, data.edge_index)
pred = out.argmax(dim=1)
# 그래프 데이터 PyTorch Geometric -> NetworkX 변환
graph = to_networkx(data, to_undirected=True)
# 시각화를 위한 노드 색상 지정 (모델의 예측값 기반)
node_colors = [pred[node].item() for node in range(data.num_nodes)]
# 그래프 시각화
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(graph, seed=42) # 노드 배치 (Spring Layout 사용)
nx.draw(graph, pos, node_color=node_colors, cmap="Set3", node_size=50, with_labels=False)
plt.title("Graph SAGE Classification Results", fontsize=16)
plt.show()
댓글