GNN、GCN与GAT详解

前言

图神经网络(GNN)专门处理图结构数据,能够捕捉节点之间的关系。本文介绍GNN的基本概念、GCN和GAT等经典架构。


图数据基础

图的表示

import numpy as np
import matplotlib.pyplot as plt

# Fix the RNG seed so the random node features generated below are reproducible.
np.random.seed(42)

# 图的基本表示
class Graph:
    """Minimal undirected/directed graph backed by a dense adjacency matrix.

    Nodes are integers 0..num_nodes-1; edges are stored as 0/1 entries in
    ``adj_matrix``. Optional per-node features live in ``node_features``.
    """

    def __init__(self, num_nodes):
        self.num_nodes = num_nodes
        # Dense (num_nodes x num_nodes) 0/1 adjacency matrix.
        self.adj_matrix = np.zeros((num_nodes, num_nodes))
        # Feature matrix, set later via set_features().
        self.node_features = None

    def add_edge(self, i, j, bidirectional=True):
        """Add edge i -> j; also j -> i when bidirectional (undirected)."""
        self.adj_matrix[i, j] = 1
        if bidirectional:
            self.adj_matrix[j, i] = 1

    def set_features(self, features):
        """Attach a (num_nodes, feature_dim) feature matrix."""
        self.node_features = features

    def get_neighbors(self, node):
        """Return the indices of all direct neighbors of ``node``."""
        return np.flatnonzero(self.adj_matrix[node] == 1)

    def get_degree(self, node):
        """Return the (out-)degree of ``node``."""
        return int(self.adj_matrix[node].sum())

# Build the demo graph: 6 nodes connected by 8 undirected edges.
graph = Graph(6)
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4), (3, 5), (4, 5)]
for u, v in edges:
    graph.add_edge(u, v)

# Attach random node features: 6 nodes, 4 dimensions each.
features = np.random.randn(6, 4)
graph.set_features(features)

print("邻接矩阵:")
print(graph.adj_matrix)
print(f"\n节点特征形状: {graph.node_features.shape}")
print(f"节点0的邻居: {graph.get_neighbors(0)}")

可视化图

def visualize_graph(adj_matrix, node_labels=None):
    """Plot the graph with nodes on a circle and edges as gray lines.

    Falls back to the node index as label when ``node_labels`` is not given.
    """
    num_nodes = adj_matrix.shape[0]

    # Circular layout: evenly spaced angles on the unit circle.
    angles = np.linspace(0, 2 * np.pi, num_nodes, endpoint=False)
    positions = np.column_stack([np.cos(angles), np.sin(angles)])

    fig, ax = plt.subplots(figsize=(8, 8))

    # Draw each undirected edge once (upper triangle of the adjacency).
    rows, cols = np.triu_indices(num_nodes, k=1)
    for i, j in zip(rows, cols):
        if adj_matrix[i, j] == 1:
            xs = [positions[i, 0], positions[j, 0]]
            ys = [positions[i, 1], positions[j, 1]]
            ax.plot(xs, ys, 'gray', linewidth=2, alpha=0.5)

    # Draw the nodes on top of the edges (zorder=5).
    ax.scatter(positions[:, 0], positions[:, 1], s=500, c='lightblue',
               edgecolors='black', linewidths=2, zorder=5)

    # Label every node; default to its index when no labels were supplied.
    labels = node_labels if node_labels else [str(k) for k in range(num_nodes)]
    for i in range(num_nodes):
        ax.text(positions[i, 0], positions[i, 1], labels[i],
                ha='center', va='center', fontsize=12, fontweight='bold')

    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('图结构可视化')

    plt.tight_layout()
    plt.show()

visualize_graph(graph.adj_matrix)

消息传递机制

核心思想

GNN的核心是消息传递(Message Passing):

  1. 聚合(Aggregate):收集邻居信息
  2. 更新(Update):结合自身特征更新
def simple_message_passing(adj_matrix, features, iterations=2):
    """Run a naive message-passing scheme and return the updated features.

    Each iteration updates every node as: own feature + mean of its
    neighbors' features. Prints the feature value range after each step.
    """
    num_nodes = features.shape[0]
    h = features.copy()

    for step in range(iterations):
        nxt = np.zeros_like(h)

        for v in range(num_nodes):
            # Aggregate: sum the current features of v's neighbors.
            nbrs = np.flatnonzero(adj_matrix[v] == 1)
            if nbrs.size:
                agg = h[nbrs].sum(axis=0)
            else:
                agg = np.zeros(h.shape[1])

            # Update: self feature plus neighbor average (guard degree 0).
            nxt[v] = h[v] + agg / max(nbrs.size, 1)

        h = nxt
        print(f"迭代 {step+1}: 特征范围 [{h.min():.3f}, {h.max():.3f}]")

    return h

# Run two rounds of message passing over the demo graph's features.
updated_features = simple_message_passing(graph.adj_matrix, graph.node_features)

图卷积网络(GCN)

GCN层

GCN的传播规则:

\[H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)})\]

其中 $\tilde{A} = A + I$ 是加入自环的邻接矩阵。

class GCNLayer:
    """Single GCN layer (NumPy implementation).

    Implements the propagation rule
        H' = ReLU(D^{-1/2} (A + I) D^{-1/2} H W)
    where A is the input adjacency matrix and I adds self-loops.
    """

    def __init__(self, input_dim, output_dim):
        # Small random weight init; no bias term by design.
        self.W = np.random.randn(input_dim, output_dim) * 0.01

    def forward(self, adj_matrix, features):
        """Propagate ``features`` over ``adj_matrix`` and apply ReLU.

        adj_matrix: (N, N) 0/1 adjacency; features: (N, input_dim).
        Returns (N, output_dim).
        """
        # Add self-loops so every node keeps its own features.
        A_hat = adj_matrix + np.eye(adj_matrix.shape[0])

        # Degree vector of A_hat; always >= 1 thanks to the self-loops,
        # so the inverse square root below is well-defined.
        deg = A_hat.sum(axis=1)

        # Build D^{-1/2} directly from the degree vector. (The previous
        # version used np.linalg.inv on the full diagonal matrix, which
        # is O(n^3) and numerically less stable for the same result.)
        D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
        A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt

        # Normalized aggregation followed by the linear transform.
        out = A_norm @ features @ self.W

        # ReLU activation.
        return np.maximum(0, out)


class SimpleGCN:
    """Two-layer GCN: input -> hidden -> output, each layer ends in ReLU."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        self.layer1 = GCNLayer(input_dim, hidden_dim)
        self.layer2 = GCNLayer(hidden_dim, output_dim)

    def forward(self, adj_matrix, features):
        """Run both GCN layers over the same adjacency matrix."""
        hidden = self.layer1.forward(adj_matrix, features)
        return self.layer2.forward(adj_matrix, hidden)

# Smoke test: run the NumPy GCN on the 6-node demo graph (4 -> 16 -> 2 dims).
gcn = SimpleGCN(input_dim=4, hidden_dim=16, output_dim=2)
output = gcn.forward(graph.adj_matrix, graph.node_features)

print(f"输入特征: {graph.node_features.shape}")
print(f"输出特征: {output.shape}")

PyTorch实现

GCN完整实现

# PyTorch is optional for this article: everything inside this try block is
# skipped (with a message) when torch is not installed.
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class GCNConv(nn.Module):
        """Dense GCN convolution: D^{-1/2} (A + I) D^{-1/2} X W + b."""

        def __init__(self, in_features, out_features):
            super().__init__()
            # Weight via a bias-free Linear; the bias is a separate parameter
            # added after aggregation.
            self.linear = nn.Linear(in_features, out_features, bias=False)
            self.bias = nn.Parameter(torch.zeros(out_features))

        def forward(self, x, adj):
            # Self-loops guarantee every node has degree >= 1.
            n = adj.size(0)
            adj = adj + torch.eye(n, device=adj.device)

            # Degree-based symmetric normalization.
            deg = adj.sum(dim=1)
            deg_inv_sqrt = deg.pow(-0.5)
            # Defensive guard kept from the original: zero out inf entries
            # (cannot occur after self-loops are added).
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

            # Broadcasting computes D^{-1/2} A D^{-1/2} without diag matrices.
            adj_norm = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)

            # Transform features first, then aggregate over the graph.
            return torch.mm(adj_norm, self.linear(x)) + self.bias
    
    
    class GCN(nn.Module):
        """Two-layer GCN classifier returning per-node log-probabilities."""

        def __init__(self, in_features, hidden_features, out_features, dropout=0.5):
            super().__init__()
            self.conv1 = GCNConv(in_features, hidden_features)
            self.conv2 = GCNConv(hidden_features, out_features)
            self.dropout = dropout

        def forward(self, x, adj):
            # conv -> ReLU -> dropout -> conv -> log-softmax over classes.
            h = F.relu(self.conv1(x, adj))
            h = F.dropout(h, p=self.dropout, training=self.training)
            return F.log_softmax(self.conv2(h, adj), dim=1)
    
    # Smoke test: random features on the 6-node demo graph, 3 output classes.
    x = torch.randn(6, 4)
    adj = torch.tensor(graph.adj_matrix, dtype=torch.float32)
    
    model = GCN(in_features=4, hidden_features=16, out_features=3)
    out = model(x, adj)
    
    print(f"GCN输出: {out.shape}")
    
except ImportError:
    print("PyTorch未安装")

图注意力网络(GAT)

注意力机制

GAT使用注意力机制学习邻居的重要性:

\[\alpha_{ij} = \frac{\exp(\text{LeakyReLU}(a^T [Wh_i || Wh_j]))}{\sum_{k \in \mathcal{N}_i} \exp(\text{LeakyReLU}(a^T [Wh_i || Wh_k]))}\]
try:
    class GATLayer(nn.Module):
        """Graph attention layer (dense O(N^2) implementation).

        Computes multi-head attention coefficients over each node's 1-hop
        neighbourhood (given by the dense adjacency ``adj``), then aggregates
        the linearly transformed neighbour features with those coefficients.

        NOTE(review): no self-loops are added here, so a node attends only to
        entries with adj[i, j] != 0; the original GAT paper includes
        self-attention — confirm this omission is intentional.
        """
        
        def __init__(self, in_features, out_features, num_heads=1, concat=True):
            super().__init__()
            self.num_heads = num_heads
            self.out_features = out_features
            self.concat = concat
            
            # One shared linear map producing all heads at once:
            # in_features -> out_features * num_heads.
            self.W = nn.Linear(in_features, out_features * num_heads, bias=False)
            
            # Attention vector a (one row per head), applied to [Wh_i || Wh_j].
            self.a = nn.Parameter(torch.zeros(num_heads, 2 * out_features))
            nn.init.xavier_uniform_(self.a)
            
            self.leaky_relu = nn.LeakyReLU(0.2)
        
        def forward(self, x, adj):
            # x: [N, in_features]; adj: [N, N] dense adjacency (0 = no edge).
            N = x.size(0)
            
            # Per-head linear transform: [N, heads, out].
            h = self.W(x).view(N, self.num_heads, self.out_features)
            
            # Build all pairwise concatenations [Wh_i || Wh_j]:
            # [N, heads, out] -> [N, N, heads, 2*out]. This is O(N^2) memory —
            # fine for small demo graphs, prohibitive at scale.
            a_input = torch.cat([
                h.unsqueeze(1).repeat(1, N, 1, 1),
                h.unsqueeze(0).repeat(N, 1, 1, 1)
            ], dim=-1)
            
            # Raw attention logits e_ij per head: [N, N, heads].
            e = self.leaky_relu(torch.einsum('ijhd,hd->ijh', a_input, self.a))
            
            # Mask non-edges with -inf so softmax gives them zero weight.
            adj_3d = adj.unsqueeze(-1).repeat(1, 1, self.num_heads)
            e = e.masked_fill(adj_3d == 0, float('-inf'))
            
            # Normalize over the neighbour axis j (dim=1).
            attention = F.softmax(e, dim=1)
            # Isolated nodes have all -inf logits -> softmax yields NaN;
            # replace those with 0 so they contribute nothing.
            attention = attention.masked_fill(torch.isnan(attention), 0)
            
            # Attention-weighted aggregation of neighbour features:
            # [N, heads, out].
            h_prime = torch.einsum('ijh,jhd->ihd', attention, h)
            
            if self.concat:
                # Concatenate the heads: [N, heads * out].
                return h_prime.view(N, -1)
            else:
                # Average the heads: [N, out].
                return h_prime.mean(dim=1)
    
    
    class GAT(nn.Module):
        """Two-layer GAT: multi-head attention, then a single-head output layer."""

        def __init__(self, in_features, hidden_features, out_features,
                     num_heads=8, dropout=0.6):
            super().__init__()
            self.dropout = dropout
            # First layer concatenates its heads, so the second layer's input
            # width is hidden_features * num_heads; the output layer averages.
            self.gat1 = GATLayer(in_features, hidden_features, num_heads, concat=True)
            self.gat2 = GATLayer(hidden_features * num_heads, out_features, 1, concat=False)

        def forward(self, x, adj):
            # dropout -> GAT+ELU -> dropout -> GAT -> log-softmax.
            h = F.dropout(x, p=self.dropout, training=self.training)
            h = F.elu(self.gat1(h, adj))
            h = F.dropout(h, p=self.dropout, training=self.training)
            return F.log_softmax(self.gat2(h, adj), dim=1)
    
    # Smoke test: 4-head GAT on the same features/adjacency used above.
    gat_model = GAT(in_features=4, hidden_features=8, out_features=3, num_heads=4)
    gat_out = gat_model(x, adj)
    
    print(f"GAT输出: {gat_out.shape}")
    
# NameError fires here when the earlier PyTorch import failed
# (torch / nn / F are then undefined).
except NameError:
    print("需要先导入PyTorch")

使用PyG库

# PyTorch Geometric is optional; this whole section is skipped when absent.
try:
    import torch_geometric
    from torch_geometric.nn import GCNConv, GATConv
    from torch_geometric.data import Data
    
    # Build a PyG graph. edge_index is [2, num_edges] in COO format.
    # NOTE(review): each undirected edge is listed only once here, but PyG
    # treats edge_index as directed — messages would flow one way only.
    # Confirm whether both directions should be included.
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 3, 3, 4],
                               [1, 2, 2, 3, 4, 4, 5, 5]], dtype=torch.long)
    x = torch.randn(6, 4)
    
    data = Data(x=x, edge_index=edge_index)
    
    print(f"节点数: {data.num_nodes}")
    print(f"边数: {data.num_edges}")
    print(f"特征维度: {data.num_node_features}")
    
    # Same two-layer GCN, now built on PyG's sparse GCNConv.
    class PyGGCN(nn.Module):
        """Two-layer GCN (4 -> 16 -> 3) using torch_geometric's GCNConv."""

        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(4, 16)
            self.conv2 = GCNConv(16, 3)

        def forward(self, data):
            # conv -> ReLU -> dropout (default p) -> conv -> log-softmax.
            h = self.conv1(data.x, data.edge_index)
            h = F.dropout(F.relu(h), training=self.training)
            return F.log_softmax(self.conv2(h, data.edge_index), dim=1)
    
    # Smoke test of the PyG-based model.
    pyg_gcn = PyGGCN()
    out = pyg_gcn(data)
    print(f"PyG GCN输出: {out.shape}")
    
except ImportError:
    print("PyTorch Geometric未安装")
    print("安装命令: pip install torch-geometric")

GNN应用场景

应用领域

应用 图结构 任务
社交网络 用户-好友 节点分类、链接预测
分子预测 原子-键 属性预测
推荐系统 用户-物品 推荐
知识图谱 实体-关系 链接预测
交通预测 路口-道路 流量预测

常见问题

Q1: GCN和GAT的区别?

特性 GCN GAT
邻居权重 固定(度归一化) 学习(注意力)
表达能力 较弱 较强
计算复杂度 较低(固定归一化,可用稀疏矩阵乘法) 较高(需逐边计算注意力系数)

Q2: 如何处理大规模图?

  • 采样方法(GraphSAGE)
  • 小批量训练
  • 分布式计算

Q3: GNN的局限性?

  • 过平滑问题(层数增加)
  • 计算复杂度
  • 难以处理异构图

Q4: 如何选择GNN架构?

  • 节点分类:GCN、GAT
  • 图分类:加入池化层
  • 链接预测:编码节点对

总结

模型 特点 适用场景
GCN 简单高效 一般图任务
GAT 注意力加权 需要区分邻居重要性
GraphSAGE 采样聚合 大规模图

参考资料

  • Kipf, T. & Welling, M. (2017). “Semi-Supervised Classification with Graph Convolutional Networks”
  • Veličković, P. et al. (2018). “Graph Attention Networks”
  • Hamilton, W. et al. (2017). “Inductive Representation Learning on Large Graphs”
  • PyTorch Geometric官方文档

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:《 机器学习基础系列——图神经网络 》

本文链接:http://localhost:3015/ai/%E5%9B%BE%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C.html

本文距最后一次更新已有一段时间,文章中的某些内容可能已过时!