GNN

GNN(Graph Neural Network) - 10.MPNN(Message Passing Neural Network)

jsmak 2025. 2. 16. 17:23

그래프 구조를 학습하기 위해 메시지 전달 방식(Message Passing)을 사용하는 방식입니다.

GNN은 나와 이웃의 특징을 집계(Aggregation) 및 그래프 구조를 결합(Conbine)하여 특징을 다시 업데이트 하는데, 이 과정을 Message Passing이라고 합니다.

GCN, GAT, GraphSAGE등의 그래프 데이터를 집계하고 업데이트 하는 과정을 일반화 한것으로 이해할 수 있습니다.

 

MPNN은 2017년 발표되었습니다.

논문: "Neural Message Passing for Quantum Chemistry"
저자: Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, George E. Dahl
발표 연도: 2017년
논문 링크: https://arxiv.org/abs/1704.01212

 

Neural Message Passing for Quantum Chemistry

Supervised learning on molecules has incredible potential to be useful in chemistry, drug discovery, and materials science. Luckily, several promising and closely related neural network models invariant to molecular symmetries have already been described i

arxiv.org

 

MPNN 논문에서 메시지 구성을 3단계로 구성됩니다.

1. 메시지 생성(Message Function, M)

2. 업데이트(Update Function, U)

3. 읽기(Readout Function, R)

 

내용이 어려운데, 이해된 내용으로 정리하면, 

GNN은 여러 Layer를 쌓는데, Layer가 커질 수록 네이버가 1hop, 2hop, 3hop등으로 커집니다. message란 용어가 어려운데 그래프 데이터를 다음 단계로 전달하기 위한 함수 정도로 이해할수 있을것 같습니다. 현재 노드와 이웃 노드의 정보를 조합하여, 다음 단계에서 사용할 데이터를 생성하는 과정으로 이해됩니다.

 

MPNN의 동작과정

1-1 Message Function(메시지 생성)

그래프의 노드들이 이웃 노드들과 정보를 교환합니다.

이웃노드들로 부터 받을 메시지를 생성하는 과정입니다.

메시지는 자기자신의 노드의 특성만으로만 구성할 수도 있고, 이웃 노드들의 특성 및 간선의 특성값들도 사용합니다.

  • $M_t$ : 학습 가능한 메시지 함수
  • $h_v^t$ : 현재 노드 v의 t번째 상태
  • $h_w^t$ : 이웃 노드 w의 t번째 상태
  • $e_{vw}$ : 엣지 (v,w)의 속성

1-2 Update Function( 메시지 집계 및 업데이트 U)

나와 이웃의 메시지를 집계(Aggregate)하고 노드의 상태를 업데이트 합니다.

 

  • $U_t$: 노드 상태 업데이트 함수
  • $m_v^{(t+1)}$ : 집계된 메시지

1-3 Read Function (출력함수)

노드의 정보를 모아 최종적인 그래프 표현을 생성합니다.

GCN, GAT, GraphSAGE에서는 노드의 임베딩을 목표로해서 별도로 Read Function은 사용하지 않습니다.

MPNN이 분자 속성 예측등을 위해 설계되어, 최종 그래프 표현을 얻기위해 사용됩니다. 즉, MPNN은 전체 그래프는 하나의 벡터로 표현이 필요해서 별도로 정의됩니다. 

 

 

torch_geometric.nn.MessagePassing

MessagePassing는 torch_geometric에서 GNN을 구현하는 핵심 클래스로 

MPNN의 메시지 전달(M)과 업데이트(U)를 통합하여 제공합니다.
propagate()를 호출하면 message() → aggregate() → update() 순서로 동작합니다.

  • message() :  이웃 노드 정보를 기반으로 메시지를 생성.
  • aggregate() : 이웃 노드의 메시지를 집계.
  • update(): 집계된 메시지를 이용하여 노드 상태를 업데이트.
메서드 설명
propagate(edge_index, x=x, edge_attr=...) forward가 호출되면 propagate()가 호출되어 메시지 전달 과정을 시작합니다.
message(x_j, edge_attr, ...) 이웃 노드의 정보를 받아 메시지를 생성
aggregate(message, index, ...) 메시지를 집계 (기본적으로 sum, mean, max 중 선택)
update(aggr_out, x, ...) 집계된 메시지를 이용해 노드 상태를 업데이트

 

GCN & GAT & GraphSAGE message passing

 

  • Message Function (M)
    • GCN: Degree-normalized 메시지 전달 ($\frac{1}{\sqrt{d_i d_j}} W h_j$)
    • GAT: Attention 가중치 ($\alpha_{ij} W h_j$)를 곱해 중요도 반영
    • GraphSAGE: 단순 선형 변환 ($W h_j$) 후 집계
  • Aggregation Function ()
    • GCN: 단순 합산 또는 평균 사용 가능
    • GAT: Attention 가중치가 적용된 가중 합계 방식
    • GraphSAGE: Mean, LSTM, Pooling 등 다양한 집계 방식 지원
  • Update Function (U)
    • GCN & GAT: 단순 비선형 변환 $h_v^{t+1} = \sigma(m_v)$
    • GraphSAGE: 현재 노드 정보 $h_v$ 와 이웃 정보 $m_v$concat 후 가중치 적용 ($h_v^{t+1} = \sigma(W [h_v, m_v])$)

3개 단계를 다음가 같은 식으로 요약할 수 있습니다.

$h^{'}_i = \gamma \left( h_i \oplus _{j \in \mathcal{N}_i} \phi(h_i, h_j, e_{j,i}) \right)$

$h_i$ : 노드임베딩

$e_{j,i}$ : 간선임베딩

$ \phi$ : 메시지 함수

$\oplus$ : 집계 함수

$\gamma$ : 업데이트 함수

 

GCN Message Passing

GCN을 MessagePassing 클래스를 상속하여 구현합니다.

add_self_loop는 인접행렬 A가 자기자신은 포함하지 않기 때문에 단위행렬 I를 더하는 과정입니다. (A+I)

$$ \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} $$

는 아래과 같이 구현됩니다.

row, col = edge_index
deg = degree(col, num_nodes=x.size(0)) # D : 차수
deg_inv_sqrt = deg.pow(-0.5)  # d(-1/2)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

전체 코드입니다.

import torch
import numpy as np

from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch.nn.functional as F

class CustomGCNConv(MessagePassing):
    def __init__(self, dim_in, dim_h):
        super().__init__(aggr='add')
        self.linear = Linear(dim_in, dim_h, bias=False)

    def forward(self, x, edge_index):
        # Self-loop 추가
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 

        # 정규화
        row, col = edge_index
        deg = degree(col, num_nodes=x.size(0)) # D : 차수
        deg_inv_sqrt = deg.pow(-0.5)  # d(-1/2)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # norm : 정규화 계수
        return norm.view(-1,1)*self.linear(x_j)

    # init에 aggr을 add로 설정해서 생략해도 무방
    def update(self, aggr_out):
        return aggr_out

CustomConV를 사용해서 Cora를 작성해보겠습니다.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

# 데이터셋 로드
dataset = Planetoid(root='.', name='Cora')
data = dataset[0]

# PyTorch Geometric GCN 모델 정의
class MyGGCN(nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.conv1 = CustomGCNConv(dim_in, dim_h)
        self.conv2 = CustomGCNConv(dim_h, dim_out)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = self.conv2(h, edge_index)
        return F.log_softmax(h, dim=1)

    def fit(self, data, epochs=100):
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        self.train()

        for epoch in range(epochs+1):
            optimizer.zero_grad()
            out = self(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            acc = self.accuracy(out[data.train_mask].argmax(dim=1), data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            if epoch % 20 == 0:
                val_loss = criterion(out[data.val_mask], data.y[data.val_mask])
                val_acc = self.accuracy(out[data.val_mask].argmax(dim=1), data.y[data.val_mask])
                print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.3f} | Val Acc: {val_acc*100:>5.2f}%')

    @torch.no_grad()
    def test(self, data):
        self.eval()
        out = self(data.x, data.edge_index)
        acc = self.accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc

    @staticmethod
    def accuracy(y_pred, y_true):
        return (y_pred == y_true).sum().item() / len(y_true)

# GCN 모델 학습 및 테스트
model = MyGGCN(dataset.num_features, 16, dataset.num_classes)
model.fit(data, epochs=100)
test_acc_pyg = model.test(data)
print(f"Test Accuracy (PyTorch Geometric): {test_acc_pyg*100:.2f}%")
Epoch   0 | Train Loss: 1.946 | Train Acc: 16.43% | Val Loss: 1.944 | Val Acc: 17.80%
Epoch  20 | Train Loss: 0.197 | Train Acc: 100.00% | Val Loss: 0.839 | Val Acc: 77.60%
Epoch  40 | Train Loss: 0.021 | Train Acc: 100.00% | Val Loss: 0.735 | Val Acc: 77.60%
Epoch  60 | Train Loss: 0.020 | Train Acc: 100.00% | Val Loss: 0.708 | Val Acc: 77.60%
Epoch  80 | Train Loss: 0.023 | Train Acc: 100.00% | Val Loss: 0.703 | Val Acc: 76.80%
Epoch 100 | Train Loss: 0.020 | Train Acc: 100.00% | Val Loss: 0.707 | Val Acc: 76.40%
Test Accuracy (PyTorch Geometric): 79.70%

 

 

GAT message Passing

import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax

class CustomGAT(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1, concat=True):
        super(CustomGAT, self).__init__(aggr='add')  # Attention 기반 집계
        self.heads = heads
        self.concat = concat
        self.W = torch.nn.Linear(in_channels, out_channels * heads, bias=False)  # W 변환
        self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))  # Attention 가중치

        torch.nn.init.xavier_uniform_(self.att.data)  # 가중치 초기화

    def forward(self, x, edge_index):
        """
        x: (N, in_channels) - 노드 특성
        edge_index: (2, E) - 엣지 정보
        """
        x = self.W(x).view(-1, self.heads, x.shape[-1])  # 다중 헤드 변환
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j, index, ptr, size_i):
        """
        x_i: 현재 노드 특징
        x_j: 이웃 노드 특징
        index: 메시지를 받는 노드 인덱스
        """
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)  # Attention 점수 계산
        alpha = F.leaky_relu(alpha, negative_slope=0.2)  # LeakyReLU 적용
        alpha = softmax(alpha, index, ptr, size_i)  # Softmax 정규화

        return x_j * alpha.view(-1, self.heads, 1)  # 가중치를 곱한 메시지 반환

    def update(self, aggr_out):
        if self.concat:
            return aggr_out.view(aggr_out.shape[0], -1)  # 다중 헤드 결합
        else:
            return aggr_out.mean(dim=1)  # 평균 풀링

# ✅ 테스트 실행
gat = CustomGAT(in_channels=5, out_channels=2, heads=2)
x = torch.randn(4, 5)  # 4개의 노드, 5차원 특징
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # 그래프 연결 정보

output = gat(x, edge_index)
print(output)  # (4, 4) 크기의 새로운 노드 임베딩 출력 (heads=2 이므로 2*2=4)